author     Davies Liu <davies.liu@gmail.com>    2014-08-06 11:08:12 -0700
committer  Michael Armbrust <michael@databricks.com>    2014-08-06 11:08:12 -0700
commit     48789117c2dd6d38e0bd8d21cdbcb989913205a6
tree       387fbb2c70cd2abe993e6009157dd76d6afd0f60
parent     09f7e4587bbdf74207d2629e8c1314f93d865999
[SPARK-2875] [PySpark] [SQL] handle null in schemaRDD()
Handle null values in a SchemaRDD when converting its rows into Python.

Author: Davies Liu <davies.liu@gmail.com>

Closes #1802 from davies/json and squashes the following commits:

88e6b1f [Davies Liu] handle null in schemaRDD()
-rw-r--r--  python/pyspark/sql.py                                          |  7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala  | 27
2 files changed, 23 insertions(+), 11 deletions(-)
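
For context, here is a minimal, self-contained sketch of the conversion pattern this patch moves to: matching on the (value, dataType) pair so that a single null case covers every type, instead of casting first inside each type-specific branch. The DataType/StructType/ArrayType/MapType definitions below, the use of Seq[Any] in place of Spark's Row, and the NullSafeToJava object are simplified stand-ins for illustration, not Spark's actual classes.

import scala.collection.JavaConverters._

sealed trait DataType
case object StringType extends DataType
case class ArrayType(elementType: DataType) extends DataType
case class MapType(keyType: DataType, valueType: DataType) extends DataType
case class StructType(fields: Seq[(String, DataType)]) extends DataType

object NullSafeToJava {
  // Matching on the (value, type) pair lets a single `case (null, _)` cover
  // nulls for every data type, rather than null-checking in each branch.
  def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
    case (null, _) => null
    // Struct values are modeled as Seq[Any] here (the real code uses Row).
    case (row: Seq[Any], struct: StructType) =>
      row.zip(struct.fields).map { case (v, (_, dt)) => toJava(v, dt) }.toArray
    case (seq: Seq[Any], array: ArrayType) =>
      seq.map(toJava(_, array.elementType)).asJava
    case (m: Map[_, _], mt: MapType) =>
      m.map { case (k, v) => (k, toJava(v, mt.valueType)) }.asJava
    case (other, _) => other // primitives (and Timestamps) pass through as-is
  }

  def main(args: Array[String]): Unit = {
    val schema = ArrayType(StringType)
    println(toJava(null, schema))                // null, instead of failing on a cast
    println(toJava(Seq("a", null, "b"), schema)) // [a, null, b] as a java.util.List
  }
}

Keeping the null check in one place is the point of the rewrite: in the previous code each branch cast the value first (e.g. obj.asInstanceOf[Row]), so a null nested struct or map could fail while being converted for PySpark.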
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index f1093701dd..adc56e7ec0 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -1231,6 +1231,13 @@ class SQLContext:
... "field3.field5[0] as f3 from table3")
>>> srdd6.collect()
[Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)]
+
+ >>> sqlCtx.jsonRDD(sc.parallelize(['{}',
+ ... '{"key0": {"key1": "value1"}}'])).collect()
+ [Row(key0=None), Row(key0=Row(key1=u'value1'))]
+ >>> sqlCtx.jsonRDD(sc.parallelize(['{"key0": null}',
+ ... '{"key0": {"key1": "value1"}}'])).collect()
+ [Row(key0=None), Row(key0=Row(key1=u'value1'))]
"""
def func(iterator):
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 57df79321b..33b2ed1b3a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -382,21 +382,26 @@ class SchemaRDD(
private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
import scala.collection.Map
- def toJava(obj: Any, dataType: DataType): Any = dataType match {
- case struct: StructType => rowToArray(obj.asInstanceOf[Row], struct)
- case array: ArrayType => obj match {
- case seq: Seq[Any] => seq.map(x => toJava(x, array.elementType)).asJava
- case list: JList[_] => list.map(x => toJava(x, array.elementType)).asJava
- case arr if arr != null && arr.getClass.isArray =>
- arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType))
- case other => other
- }
- case mt: MapType => obj.asInstanceOf[Map[_, _]].map {
+ def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
+ case (null, _) => null
+
+ case (obj: Row, struct: StructType) => rowToArray(obj, struct)
+
+ case (seq: Seq[Any], array: ArrayType) =>
+ seq.map(x => toJava(x, array.elementType)).asJava
+ case (list: JList[_], array: ArrayType) =>
+ list.map(x => toJava(x, array.elementType)).asJava
+ case (arr, array: ArrayType) if arr.getClass.isArray =>
+ arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType))
+
+ case (obj: Map[_, _], mt: MapType) => obj.map {
case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type
}.asJava
+
// Pyrolite can handle Timestamp
- case other => obj
+ case (other, _) => other
}
+
def rowToArray(row: Row, structType: StructType): Array[Any] = {
val fields = structType.fields.map(field => field.dataType)
row.zip(fields).map {