author    Kan Zhang <kzhang@apache.org>    2014-06-16 11:11:29 -0700
committer Reynold Xin <rxin@apache.org>    2014-06-16 11:11:29 -0700
commit    4fdb491775bb9c4afa40477dc0069ff6fcadfe25 (patch)
tree      c996de6ecbf6f913b3e7bc8a45f7801aa58266ac /sql
parent    716c88aa147762f7f617adf34a17edd681d9a4ff (diff)
[SPARK-2010] Support for nested data in PySpark SQL
JIRA issue: https://issues.apache.org/jira/browse/SPARK-2010

This PR adds support for nested collection types in PySpark SQL, including
array, dict, list, set, and tuple. Example:

```
>>> from array import array
>>> from pyspark.sql import SQLContext
>>> sqlCtx = SQLContext(sc)
>>> rdd = sc.parallelize([
...     {"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
...     {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}])
>>> srdd = sqlCtx.inferSchema(rdd)
>>> srdd.collect() == [{"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
...     {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}]
True
>>> rdd = sc.parallelize([
...     {"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
...     {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}])
>>> srdd = sqlCtx.inferSchema(rdd)
>>> srdd.collect() == \
...     [{"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
...     {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}]
True
```

Author: Kan Zhang <kzhang@apache.org>

Closes #1041 from kanzhang/SPARK-2010 and squashes the following commits:

1b2891d [Kan Zhang] [SPARK-2010] minor doc change and adding a TODO
504f27e [Kan Zhang] [SPARK-2010] Support for nested data in PySpark SQL
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 29
1 file changed, 19 insertions(+), 10 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 378ff54531..131c130bbb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -298,19 +298,28 @@ class SQLContext(@transient val sparkContext: SparkContext)
   /**
    * Peek at the first row of the RDD and infer its schema.
-   * TODO: We only support primitive types, add support for nested types.
+   * TODO: consolidate this with the type system developed in SPARK-2060.
    */
   private[sql] def inferSchema(rdd: RDD[Map[String, _]]): SchemaRDD = {
+    import scala.collection.JavaConversions._
+    def typeFor(obj: Any): DataType = obj match {
+      case c: java.lang.String => StringType
+      case c: java.lang.Integer => IntegerType
+      case c: java.lang.Long => LongType
+      case c: java.lang.Double => DoubleType
+      case c: java.lang.Boolean => BooleanType
+      case c: java.util.List[_] => ArrayType(typeFor(c.head))
+      case c: java.util.Set[_] => ArrayType(typeFor(c.head))
+      case c: java.util.Map[_, _] =>
+        val (key, value) = c.head
+        MapType(typeFor(key), typeFor(value))
+      case c if c.getClass.isArray =>
+        val elem = c.asInstanceOf[Array[_]].head
+        ArrayType(typeFor(elem))
+      case c => throw new Exception(s"Object of type $c cannot be used")
+    }
     val schema = rdd.first().map { case (fieldName, obj) =>
-      val dataType = obj.getClass match {
-        case c: Class[_] if c == classOf[java.lang.String] => StringType
-        case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
-        case c: Class[_] if c == classOf[java.lang.Long] => LongType
-        case c: Class[_] if c == classOf[java.lang.Double] => DoubleType
-        case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType
-        case c => throw new Exception(s"Object of type $c cannot be used")
-      }
-      AttributeReference(fieldName, dataType, true)()
+      AttributeReference(fieldName, typeFor(obj), true)()
     }.toSeq
     val rowRdd = rdd.mapPartitions { iter =>
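The heart of the change above is that schema inference now recurses into collection values instead of dispatching on the top-level class alone. The following is a minimal, self-contained Scala sketch of that recursion for readers who want to try the idea without a Spark build on hand; the SketchType ADT and InferSketch object are invented stand-ins for illustration, not Spark's actual DataType hierarchy, and, like the real typeFor helper, the sketch peeks at one element of each collection and assumes collections are non-empty and homogeneously typed.

```
// Stand-alone sketch of recursive type inference over JVM values.
// SketchType is a hypothetical stand-in for Spark SQL's DataType hierarchy.
sealed trait SketchType
case object SStr extends SketchType
case object SInt extends SketchType
case object SDouble extends SketchType
case class SArray(elem: SketchType) extends SketchType
case class SMap(key: SketchType, value: SketchType) extends SketchType

object InferSketch {
  import scala.collection.JavaConverters._

  def typeFor(obj: Any): SketchType = obj match {
    case _: java.lang.String  => SStr
    case _: java.lang.Integer => SInt
    case _: java.lang.Double  => SDouble
    // Collections: peek at one element and recurse. Like the real helper,
    // this assumes the collection is non-empty and homogeneously typed.
    case c: java.util.List[_] => SArray(typeFor(c.asScala.head))
    case c: java.util.Map[_, _] =>
      val (key, value) = c.asScala.head
      SMap(typeFor(key), typeFor(value))
    case c if c.getClass.isArray =>
      SArray(typeFor(c.asInstanceOf[Array[_]].head))
    case other =>
      throw new IllegalArgumentException(s"Unsupported value: $other")
  }

  def main(args: Array[String]): Unit = {
    val f2 = new java.util.HashMap[String, java.lang.Double]()
    f2.put("row1", 1.0)
    println(typeFor(f2))                                // SMap(SStr,SDouble)
    println(typeFor(java.util.Arrays.asList("a", "b"))) // SArray(SStr)
  }
}
```

The patched SQLContext reaches the same effect through scala.collection.JavaConversions, which lets java.util collections be used as Scala ones implicitly; the sketch spells out the asScala conversions so the Java/Scala boundary stays visible.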