 python/pyspark/sql.py                                         | 22 +++++++++++++++++++++-
 sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 29 +++++++++++++++++++----------
 2 files changed, 40 insertions(+), 11 deletions(-)
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index e344610b1f..c31d49ce83 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -77,12 +77,25 @@ class SQLContext:
"""Infer and apply a schema to an RDD of L{dict}s.
We peek at the first row of the RDD to determine the fields names
- and types, and then use that to extract all the dictionaries.
+ and types, and then use that to extract all the dictionaries. Nested
+ collections are supported, including array, dict, list, set, and
+ tuple.
>>> srdd = sqlCtx.inferSchema(rdd)
>>> srdd.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"},
... {"field1" : 3, "field2": "row3"}]
True
+
+ >>> from array import array
+ >>> srdd = sqlCtx.inferSchema(nestedRdd1)
+ >>> srdd.collect() == [{"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
+ ... {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}]
+ True
+
+ >>> srdd = sqlCtx.inferSchema(nestedRdd2)
+ >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
+ ... {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}]
+ True
"""
if (rdd.__class__ is SchemaRDD):
raise ValueError("Cannot apply schema to %s" % SchemaRDD.__name__)
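For illustration, the inference rules the new docstring describes can be sketched in a few lines of plain Python. The type names below are placeholder strings standing in for the Catalyst types, and infer_type is a hypothetical helper, not pyspark's actual code path:

    from array import array

    def infer_type(obj):
        # Illustrative only: map a sample value to a schema-type *name*;
        # the real patch constructs Catalyst DataTypes on the Scala side.
        primitives = {str: "StringType", int: "IntegerType",
                      float: "DoubleType", bool: "BooleanType"}
        if type(obj) in primitives:
            return primitives[type(obj)]
        if isinstance(obj, (list, set, tuple, array)):
            # Like typeFor's c.head below: the element type is taken
            # from the first element alone.
            return "ArrayType(%s)" % infer_type(next(iter(obj)))
        if isinstance(obj, dict):
            key, value = next(iter(obj.items()))
            return "MapType(%s, %s)" % (infer_type(key), infer_type(value))
        raise ValueError("Object of type %s cannot be used" % type(obj))

    print(infer_type({"row1": 1.0}))       # MapType(StringType, DoubleType)
    print(infer_type(array('i', [1, 2])))  # ArrayType(IntegerType)
    print(infer_type([[1, 2], [2, 3]]))    # ArrayType(ArrayType(IntegerType))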
@@ -413,6 +426,7 @@ class SchemaRDD(RDD):
def _test():
import doctest
+ from array import array
from pyspark.context import SparkContext
globs = globals().copy()
# The small batch size here ensures that we see multiple batches,
@@ -422,6 +436,12 @@ def _test():
globs['sqlCtx'] = SQLContext(sc)
globs['rdd'] = sc.parallelize([{"field1" : 1, "field2" : "row1"},
{"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
+ globs['nestedRdd1'] = sc.parallelize([
+ {"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
+ {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}])
+ globs['nestedRdd2'] = sc.parallelize([
+ {"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
+ {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}])
(failure_count, test_count) = doctest.testmod(globs=globs,optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
if failure_count:
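To exercise the new path by hand, something along these lines should work in a pyspark shell, where sc is predefined; this is an illustrative session mirroring the doctests above, not part of the patch:

    from array import array
    from pyspark.sql import SQLContext

    sqlCtx = SQLContext(sc)   # `sc` comes from the shell
    nested = sc.parallelize([
        {"f1": array('i', [1, 2]), "f2": {"row1": 1.0}},
        {"f1": array('i', [2, 3]), "f2": {"row2": 2.0}}])
    srdd = sqlCtx.inferSchema(nested)
    print(srdd.collect())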
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 378ff54531..131c130bbb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -298,19 +298,28 @@ class SQLContext(@transient val sparkContext: SparkContext)
/**
* Peek at the first row of the RDD and infer its schema.
- * TODO: We only support primitive types, add support for nested types.
+ * TODO: consolidate this with the type system developed in SPARK-2060.
*/
private[sql] def inferSchema(rdd: RDD[Map[String, _]]): SchemaRDD = {
+ import scala.collection.JavaConversions._
+ def typeFor(obj: Any): DataType = obj match {
+ case c: java.lang.String => StringType
+ case c: java.lang.Integer => IntegerType
+ case c: java.lang.Long => LongType
+ case c: java.lang.Double => DoubleType
+ case c: java.lang.Boolean => BooleanType
+ case c: java.util.List[_] => ArrayType(typeFor(c.head))
+ case c: java.util.Set[_] => ArrayType(typeFor(c.head))
+ case c: java.util.Map[_, _] =>
+ val (key, value) = c.head
+ MapType(typeFor(key), typeFor(value))
+ case c if c.getClass.isArray =>
+ val elem = c.asInstanceOf[Array[_]].head
+ ArrayType(typeFor(elem))
+ case c => throw new Exception(s"Object of type ${c.getClass} cannot be used")
+ }
val schema = rdd.first().map { case (fieldName, obj) =>
- val dataType = obj.getClass match {
- case c: Class[_] if c == classOf[java.lang.String] => StringType
- case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
- case c: Class[_] if c == classOf[java.lang.Long] => LongType
- case c: Class[_] if c == classOf[java.lang.Double] => DoubleType
- case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType
- case c => throw new Exception(s"Object of type $c cannot be used")
- }
- AttributeReference(fieldName, dataType, true)()
+ AttributeReference(fieldName, typeFor(obj), true)()
}.toSeq
val rowRdd = rdd.mapPartitions { iter =>
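One caveat in the matcher above: every collection branch samples only the first element via c.head, so an empty nested collection in the sampled row leaves nothing to inspect and inference fails, and a heterogeneous collection is silently typed by its first element. A tiny self-contained Python illustration of that first-element sampling rule:

    # Homogeneous, non-empty: sampling the first element is enough.
    sample = next(iter([1, 2, 3]))   # -> 1, so elements are typed as int
    print(type(sample).__name__)     # int

    # Empty: nothing to sample; on the Scala side this surfaces as
    # .head throwing NoSuchElementException.
    try:
        next(iter([]))
    except StopIteration:
        print("no element to sample; inference has nothing to go on")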