Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala   40
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala    46
2 files changed, 54 insertions(+), 32 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 4abd89955b..c178dad662 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -352,8 +352,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
case c: java.lang.Long => LongType
case c: java.lang.Double => DoubleType
case c: java.lang.Boolean => BooleanType
+ case c: java.math.BigDecimal => DecimalType
+ case c: java.sql.Timestamp => TimestampType
+ case c: java.util.Calendar => TimestampType
case c: java.util.List[_] => ArrayType(typeFor(c.head))
- case c: java.util.Set[_] => ArrayType(typeFor(c.head))
case c: java.util.Map[_, _] =>
val (key, value) = c.head
MapType(typeFor(key), typeFor(value))
@@ -362,11 +364,43 @@ class SQLContext(@transient val sparkContext: SparkContext)
ArrayType(typeFor(elem))
case c => throw new Exception(s"Object of type $c cannot be used")
}
- val schema = rdd.first().map { case (fieldName, obj) =>
+ val firstRow = rdd.first()
+ val schema = firstRow.map { case (fieldName, obj) =>
AttributeReference(fieldName, typeFor(obj), true)()
}.toSeq
- val rowRdd = rdd.mapPartitions { iter =>
+ def needTransform(obj: Any): Boolean = obj match {
+ case c: java.util.List[_] => true
+ case c: java.util.Map[_, _] => true
+ case c if c.getClass.isArray => true
+ case c: java.util.Calendar => true
+ case c => false
+ }
+
+ // convert JList, JArray into Seq, convert JMap into Map
+ // convert Calendar into Timestamp
+ def transform(obj: Any): Any = obj match {
+ case c: java.util.List[_] => c.map(transform).toSeq
+ case c: java.util.Map[_, _] => c.map {
+ case (key, value) => (key, transform(value))
+ }.toMap
+ case c if c.getClass.isArray =>
+ c.asInstanceOf[Array[_]].map(transform).toSeq
+ case c: java.util.Calendar =>
+ new java.sql.Timestamp(c.getTime().getTime())
+ case c => c
+ }
+
+ val need = firstRow.exists {case (key, value) => needTransform(value)}
+ val transformed = if (need) {
+ rdd.mapPartitions { iter =>
+ iter.map {
+ m => m.map {case (key, value) => (key, transform(value))}
+ }
+ }
+ } else rdd
+
+ val rowRdd = transformed.mapPartitions { iter =>
iter.map { map =>
new GenericRow(map.values.toArray.asInstanceOf[Array[Any]]): Row
}
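
As an illustrative sketch (not part of the patch itself): the transform step added above recursively rewrites Java collections, Java arrays, and Calendar values into the Scala Seq/Map and java.sql.Timestamp forms that the row construction below expects. A minimal standalone version, written with explicit JavaConverters instead of whatever implicit conversions are in scope in SQLContext, could look like this (e.g. pasted into the Scala REPL):

    import scala.collection.JavaConverters._

    // Sketch of the per-field conversion applied before rows are built.
    def transform(obj: Any): Any = obj match {
      case c: java.util.List[_]   => c.asScala.map(transform).toSeq
      case c: java.util.Map[_, _] =>
        c.asScala.map { case (k, v) => (k, transform(v)) }.toMap
      case c if c != null && c.getClass.isArray =>
        c.asInstanceOf[Array[_]].map(transform).toSeq
      case c: java.util.Calendar  =>
        new java.sql.Timestamp(c.getTime.getTime)   // Calendar -> SQL Timestamp
      case other => other
    }

    // transform(java.util.Arrays.asList(1, 2))          == Seq(1, 2)
    // transform(java.util.Calendar.getInstance())       is a java.sql.Timestamp

Note the needTransform check against firstRow in the patch: the extra mapPartitions pass is only paid when the first row actually contains Java collections, arrays, or Calendar values; otherwise the RDD is used as-is.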
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 31d27bb4f0..019ff9d300 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
-import org.apache.spark.sql.catalyst.types.{ArrayType, BooleanType, StructType}
+import org.apache.spark.sql.catalyst.types.{DataType, ArrayType, BooleanType, StructType, MapType}
import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan}
import org.apache.spark.api.java.JavaRDD
@@ -376,39 +376,27 @@ class SchemaRDD(
* Converts a JavaRDD to a PythonRDD. It is used by pyspark.
*/
private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
+ def toJava(obj: Any, dataType: DataType): Any = dataType match {
+ case struct: StructType => rowToMap(obj.asInstanceOf[Row], struct)
+ case array: ArrayType => obj match {
+ case seq: Seq[Any] => seq.map(x => toJava(x, array.elementType)).asJava
+ case list: JList[_] => list.map(x => toJava(x, array.elementType)).asJava
+ case arr if arr != null && arr.getClass.isArray =>
+ arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType))
+ case other => other
+ }
+ case mt: MapType => obj.asInstanceOf[Map[_, _]].map {
+ case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type
+ }.asJava
+ // Pyrolite can handle Timestamp
+ case other => obj
+ }
def rowToMap(row: Row, structType: StructType): JMap[String, Any] = {
val fields = structType.fields.map(field => (field.name, field.dataType))
val map: JMap[String, Any] = new java.util.HashMap
row.zip(fields).foreach {
- case (obj, (attrName, dataType)) =>
- dataType match {
- case struct: StructType => map.put(attrName, rowToMap(obj.asInstanceOf[Row], struct))
- case array @ ArrayType(struct: StructType) =>
- val arrayValues = obj match {
- case seq: Seq[Any] =>
- seq.map(element => rowToMap(element.asInstanceOf[Row], struct)).asJava
- case list: JList[_] =>
- list.map(element => rowToMap(element.asInstanceOf[Row], struct))
- case set: JSet[_] =>
- set.map(element => rowToMap(element.asInstanceOf[Row], struct))
- case arr if arr != null && arr.getClass.isArray =>
- arr.asInstanceOf[Array[Any]].map {
- element => rowToMap(element.asInstanceOf[Row], struct)
- }
- case other => other
- }
- map.put(attrName, arrayValues)
- case array: ArrayType => {
- val arrayValues = obj match {
- case seq: Seq[Any] => seq.asJava
- case other => other
- }
- map.put(attrName, arrayValues)
- }
- case other => map.put(attrName, obj)
- }
+ case (obj, (attrName, dataType)) => map.put(attrName, toJava(obj, dataType))
}
-
map
}
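
For context, here is a simplified, self-contained model of the refactor (using hypothetical stand-in types, not the real Catalyst DataType/Row classes): the conversion is now driven entirely by the declared data type, recursing through structs, arrays, and maps and turning each level into Java collections that Pyrolite can pickle for pyspark.

    import scala.collection.JavaConverters._

    // Hypothetical stand-ins for Catalyst's DataType hierarchy, used only to
    // illustrate the type-directed recursion of toJava/rowToMap.
    sealed trait DT
    case object PrimDT extends DT                           // Int, String, Timestamp, ...
    case class ArrDT(elementType: DT) extends DT
    case class MapDT(valueType: DT) extends DT
    case class StructDT(fields: Seq[(String, DT)]) extends DT

    def toJava(obj: Any, dt: DT): Any = dt match {
      case StructDT(fields) =>
        // A struct value is modelled here as a Seq of field values.
        val map = new java.util.HashMap[String, Any]
        obj.asInstanceOf[Seq[Any]].zip(fields).foreach {
          case (v, (name, t)) => map.put(name, toJava(v, t))
        }
        map
      case ArrDT(et) => obj.asInstanceOf[Seq[Any]].map(toJava(_, et)).asJava
      case MapDT(vt) =>
        obj.asInstanceOf[Map[Any, Any]].map { case (k, v) => (k, toJava(v, vt)) }.asJava
      case PrimDT => obj  // primitives (and Timestamp) pass straight through to Pyrolite
    }

    // Example: an array-of-struct field becomes a java.util.List of java.util.Maps.

Keying the recursion on the data type rather than on the runtime class of each value is what lets the patch collapse the nested ArrayType/StructType special cases in rowToMap into a single toJava call per field.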