 core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala |  4
 python/pyspark/sql.py                                           | 22
 sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala   | 40
 sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala    | 46
 4 files changed, 68 insertions(+), 44 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index d87783efd2..0d8453fb18 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -550,11 +550,11 @@ private[spark] object PythonRDD extends Logging {
def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
pyRDD.rdd.mapPartitions { iter =>
val unpickle = new Unpickler
- // TODO: Figure out why flatMap is necessay for pyspark
iter.flatMap { row =>
unpickle.loads(row) match {
+ // in case objects are pickled in batch mode
case objs: java.util.ArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap)
- // Incase the partition doesn't have a collection
+ // not in batch mode
case obj: JMap[String @unchecked, _] => Seq(obj.toMap)
}
}
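
The flatMap above is needed because a single pickled payload from PySpark can carry either one record or a whole batch of records. A minimal sketch in plain CPython (no Spark, no Pyrolite) of the two payload shapes the pattern match has to handle:

import pickle

# Batched serialization pickles a list of dicts per payload; Pyrolite
# unpickles that on the JVM as a java.util.ArrayList of java.util.Map.
batched = pickle.dumps([{"a": 1}, {"a": 2}])

# Unbatched serialization pickles a single dict, which unpickles as one Map.
single = pickle.dumps({"a": 1})

print(type(pickle.loads(batched)))  # list
print(type(pickle.loads(single)))   # dict
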
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index cb83e89176..a6b3277db3 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -47,12 +47,14 @@ class SQLContext:
...
ValueError:...
- >>> allTypes = sc.parallelize([{"int" : 1, "string" : "string", "double" : 1.0, "long": 1L,
- ... "boolean" : True}])
+ >>> from datetime import datetime
+ >>> allTypes = sc.parallelize([{"int": 1, "string": "string", "double": 1.0, "long": 1L,
+ ... "boolean": True, "time": datetime(2010, 1, 1, 1, 1, 1), "dict": {"a": 1},
+ ... "list": [1, 2, 3]}])
>>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string, x.double, x.long,
- ... x.boolean))
+ ... x.boolean, x.time, x.dict["a"], x.list))
>>> srdd.collect()[0]
- (1, u'string', 1.0, 1, True)
+ (1, u'string', 1.0, 1, True, datetime.datetime(2010, 1, 1, 1, 1, 1), 1, [1, 2, 3])
"""
self._sc = sparkContext
self._jsc = self._sc._jsc
@@ -88,13 +90,13 @@ class SQLContext:
>>> from array import array
>>> srdd = sqlCtx.inferSchema(nestedRdd1)
- >>> srdd.collect() == [{"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
- ... {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}]
+ >>> srdd.collect() == [{"f1" : [1, 2], "f2" : {"row1" : 1.0}},
+ ... {"f1" : [2, 3], "f2" : {"row2" : 2.0}}]
True
>>> srdd = sqlCtx.inferSchema(nestedRdd2)
- >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
- ... {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}]
+ >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : [1, 2]},
+ ... {"f1" : [[2, 3], [3, 4]], "f2" : [2, 3]}]
True
"""
if (rdd.__class__ is SchemaRDD):
@@ -509,8 +511,8 @@ def _test():
{"f1": array('i', [1, 2]), "f2": {"row1": 1.0}},
{"f1": array('i', [2, 3]), "f2": {"row2": 2.0}}])
globs['nestedRdd2'] = sc.parallelize([
- {"f1": [[1, 2], [2, 3]], "f2": set([1, 2]), "f3": (1, 2)},
- {"f1": [[2, 3], [3, 4]], "f2": set([2, 3]), "f3": (2, 3)}])
+ {"f1": [[1, 2], [2, 3]], "f2": [1, 2]},
+ {"f1": [[2, 3], [3, 4]], "f2": [2, 3]}])
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
if failure_count:
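
For context, a hedged end-to-end sketch of what the extended doctests enable, reusing the sc/sqlCtx fixtures set up in _test(); the table and field names below are illustrative, not taken from the patch:

from datetime import datetime

rdd = sc.parallelize([{"name": "a", "time": datetime(2010, 1, 1, 1, 1, 1),
                       "tags": [1, 2, 3], "attrs": {"k": 1}}])
srdd = sqlCtx.inferSchema(rdd)
srdd.registerAsTable("events")

# Nested values come back as plain Python lists and dicts, and datetimes
# survive the round trip as datetime.datetime objects.
row = sqlCtx.sql("SELECT * FROM events").collect()[0]
print(row.time, row.tags, row.attrs["k"])
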
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 4abd89955b..c178dad662 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -352,8 +352,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
case c: java.lang.Long => LongType
case c: java.lang.Double => DoubleType
case c: java.lang.Boolean => BooleanType
+ case c: java.math.BigDecimal => DecimalType
+ case c: java.sql.Timestamp => TimestampType
+ case c: java.util.Calendar => TimestampType
case c: java.util.List[_] => ArrayType(typeFor(c.head))
- case c: java.util.Set[_] => ArrayType(typeFor(c.head))
case c: java.util.Map[_, _] =>
val (key, value) = c.head
MapType(typeFor(key), typeFor(value))
@@ -362,11 +364,43 @@ class SQLContext(@transient val sparkContext: SparkContext)
ArrayType(typeFor(elem))
case c => throw new Exception(s"Object of type $c cannot be used")
}
- val schema = rdd.first().map { case (fieldName, obj) =>
+ val firstRow = rdd.first()
+ val schema = firstRow.map { case (fieldName, obj) =>
AttributeReference(fieldName, typeFor(obj), true)()
}.toSeq
- val rowRdd = rdd.mapPartitions { iter =>
+ def needTransform(obj: Any): Boolean = obj match {
+ case c: java.util.List[_] => true
+ case c: java.util.Map[_, _] => true
+ case c if c.getClass.isArray => true
+ case c: java.util.Calendar => true
+ case c => false
+ }
+
+ // convert java.util.List and Java arrays into Seq, java.util.Map into Map,
+ // and java.util.Calendar into java.sql.Timestamp
+ def transform(obj: Any): Any = obj match {
+ case c: java.util.List[_] => c.map(transform).toSeq
+ case c: java.util.Map[_, _] => c.map {
+ case (key, value) => (key, transform(value))
+ }.toMap
+ case c if c.getClass.isArray =>
+ c.asInstanceOf[Array[_]].map(transform).toSeq
+ case c: java.util.Calendar =>
+ new java.sql.Timestamp(c.getTime().getTime())
+ case c => c
+ }
+
+ val need = firstRow.exists {case (key, value) => needTransform(value)}
+ val transformed = if (need) {
+ rdd.mapPartitions { iter =>
+ iter.map {
+ m => m.map {case (key, value) => (key, transform(value))}
+ }
+ }
+ } else rdd
+
+ val rowRdd = transformed.mapPartitions { iter =>
iter.map { map =>
new GenericRow(map.values.toArray.asInstanceOf[Array[Any]]): Row
}
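
A note on the transform path above: Pyrolite hands the JVM a java.util.List for a Python list, a java.util.Map for a dict, and a java.util.Calendar for a datetime, so needTransform lets the per-record conversion be skipped entirely when the first row contains none of these. A hedged PySpark-side sketch of a record that exercises every branch (again assuming the doctest fixtures sc/sqlCtx):

from datetime import datetime

# Each value below reaches the JVM as a type that needTransform flags:
# list -> java.util.List, dict -> java.util.Map, datetime -> java.util.Calendar.
rdd = sc.parallelize([{"xs": [1, 2, 3],
                       "kv": {"a": 1.0},
                       "ts": datetime(2014, 7, 1, 12, 0, 0)}])
srdd = sqlCtx.inferSchema(rdd)

# transform() turns the List into a Seq, the Map into a Scala Map, and the
# Calendar into a java.sql.Timestamp before each GenericRow is built.
srdd.collect()[0]
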
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 31d27bb4f0..019ff9d300 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
-import org.apache.spark.sql.catalyst.types.{ArrayType, BooleanType, StructType}
+import org.apache.spark.sql.catalyst.types.{DataType, ArrayType, BooleanType, StructType, MapType}
import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan}
import org.apache.spark.api.java.JavaRDD
@@ -376,39 +376,27 @@ class SchemaRDD(
* Converts a JavaRDD to a PythonRDD. It is used by pyspark.
*/
private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
+ def toJava(obj: Any, dataType: DataType): Any = dataType match {
+ case struct: StructType => rowToMap(obj.asInstanceOf[Row], struct)
+ case array: ArrayType => obj match {
+ case seq: Seq[Any] => seq.map(x => toJava(x, array.elementType)).asJava
+ case list: JList[_] => list.map(x => toJava(x, array.elementType)).asJava
+ case arr if arr != null && arr.getClass.isArray =>
+ arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType))
+ case other => other
+ }
+ case mt: MapType => obj.asInstanceOf[Map[_, _]].map {
+ case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type
+ }.asJava
+ // Pyrolite can handle Timestamp
+ case other => obj
+ }
def rowToMap(row: Row, structType: StructType): JMap[String, Any] = {
val fields = structType.fields.map(field => (field.name, field.dataType))
val map: JMap[String, Any] = new java.util.HashMap
row.zip(fields).foreach {
- case (obj, (attrName, dataType)) =>
- dataType match {
- case struct: StructType => map.put(attrName, rowToMap(obj.asInstanceOf[Row], struct))
- case array @ ArrayType(struct: StructType) =>
- val arrayValues = obj match {
- case seq: Seq[Any] =>
- seq.map(element => rowToMap(element.asInstanceOf[Row], struct)).asJava
- case list: JList[_] =>
- list.map(element => rowToMap(element.asInstanceOf[Row], struct))
- case set: JSet[_] =>
- set.map(element => rowToMap(element.asInstanceOf[Row], struct))
- case arr if arr != null && arr.getClass.isArray =>
- arr.asInstanceOf[Array[Any]].map {
- element => rowToMap(element.asInstanceOf[Row], struct)
- }
- case other => other
- }
- map.put(attrName, arrayValues)
- case array: ArrayType => {
- val arrayValues = obj match {
- case seq: Seq[Any] => seq.asJava
- case other => other
- }
- map.put(attrName, arrayValues)
- }
- case other => map.put(attrName, obj)
- }
+ case (obj, (attrName, dataType)) => map.put(attrName, toJava(obj, dataType))
}
-
map
}
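
The refactored toJava above is what makes nested results round-trip cleanly: Rows become Java Maps, Seqs and arrays become Java Lists, and Scala Maps become Java Maps before Pyrolite pickles them back to Python. A hedged sketch of the observable effect on the Python side (same sc/sqlCtx fixtures; the values are chosen to hit the ArrayType-of-ArrayType and MapType branches):

srdd = sqlCtx.inferSchema(sc.parallelize([
    {"f1": [[1, 2], [2, 3]], "f2": {"k": 1.0}}]))

row = srdd.collect()[0]
# Nested arrays come back as Python lists of lists, maps as plain dicts.
assert isinstance(row["f1"], list) and isinstance(row["f1"][0], list)
assert isinstance(row["f2"], dict)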