-rw-r--r-- | python/pyspark/sql.py                                                    | 17
-rw-r--r-- | python/pyspark/tests.py                                                  |  8
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala |  2
3 files changed, 19 insertions, 8 deletions
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index ae288471b0..1ee0b28a32 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -788,8 +788,9 @@ def _create_converter(dataType):
         return lambda row: map(conv, row)

     elif isinstance(dataType, MapType):
-        conv = _create_converter(dataType.valueType)
-        return lambda row: dict((k, conv(v)) for k, v in row.iteritems())
+        kconv = _create_converter(dataType.keyType)
+        vconv = _create_converter(dataType.valueType)
+        return lambda row: dict((kconv(k), vconv(v)) for k, v in row.iteritems())

     elif isinstance(dataType, NullType):
         return lambda x: None
@@ -944,7 +945,7 @@ def _infer_schema_type(obj, dataType):

     elif isinstance(dataType, MapType):
         k, v = obj.iteritems().next()
-        return MapType(_infer_type(k),
+        return MapType(_infer_schema_type(k, dataType.keyType),
                        _infer_schema_type(v, dataType.valueType))

     elif isinstance(dataType, StructType):
@@ -1085,7 +1086,7 @@ def _has_struct_or_date(dt):
     elif isinstance(dt, ArrayType):
         return _has_struct_or_date(dt.elementType)
     elif isinstance(dt, MapType):
-        return _has_struct_or_date(dt.valueType)
+        return _has_struct_or_date(dt.keyType) or _has_struct_or_date(dt.valueType)
     elif isinstance(dt, DateType):
         return True
     elif isinstance(dt, UserDefinedType):
@@ -1148,12 +1149,13 @@ def _create_cls(dataType):
         return List

     elif isinstance(dataType, MapType):
-        cls = _create_cls(dataType.valueType)
+        kcls = _create_cls(dataType.keyType)
+        vcls = _create_cls(dataType.valueType)

         def Dict(d):
             if d is None:
                 return
-            return dict((k, _create_object(cls, v)) for k, v in d.items())
+            return dict((_create_object(kcls, k), _create_object(vcls, v)) for k, v in d.items())

         return Dict

@@ -1164,7 +1166,8 @@ def _create_cls(dataType):
         return lambda datum: dataType.deserialize(datum)

     elif not isinstance(dataType, StructType):
-        raise Exception("unexpected data type: %s" % dataType)
+        # no wrapper for primitive types
+        return lambda x: x

     class Row(tuple):
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index bca52a7ce6..b474fcf5bf 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -923,6 +923,14 @@ class SQLTests(ReusedPySparkTestCase):
         result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
         self.assertEqual(1, result.first()[0])

+    def test_struct_in_map(self):
+        d = [Row(m={Row(i=1): Row(s="")})]
+        rdd = self.sc.parallelize(d)
+        srdd = self.sqlCtx.inferSchema(rdd)
+        k, v = srdd.first().m.items()[0]
+        self.assertEqual(1, k.i)
+        self.assertEqual("", v.s)
+
     def test_convert_row_to_dict(self):
         row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
         self.assertEqual(1, row.asDict()['l'][0].a)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 2b4a88d5e8..5a41399971 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -132,7 +132,7 @@ object EvaluatePython {
       arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType))

     case (obj: Map[_, _], mt: MapType) => obj.map {
-      case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type
+      case (k, v) => (toJava(k, mt.keyType), toJava(v, mt.valueType))
    }.asJava

    case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType)
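
Taken together, these changes let struct values be used as map keys end to end in PySpark SQL: the converter now strips field names from keys as well as values on the way in, schema inference recurses into the key type, the Scala side converts keys with toJava, and the Python wrapper rebuilds Row objects for keys on the way out. A minimal usage sketch, mirroring the added test_struct_in_map test and assuming the 1.2-era PySpark API (SQLContext.inferSchema returning a SchemaRDD) with an existing SparkContext named sc:

    from pyspark.sql import SQLContext, Row

    sqlCtx = SQLContext(sc)  # assumes `sc` is an existing SparkContext

    # A map column whose keys and values are both structs (Rows).
    d = [Row(m={Row(i=1): Row(s="")})]
    srdd = sqlCtx.inferSchema(sc.parallelize(d))

    # With the fix, the struct key round-trips: it is converted on the way in
    # and wrapped back into a Row on the way out, so its fields are accessible.
    k, v = srdd.first().m.items()[0]
    assert k.i == 1
    assert v.s == ""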
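
The core of the Python-side change is the symmetric treatment of keys and values when converting a map. A standalone sketch of that pattern, where make_converter stands in for the recursive _create_converter call and is illustrative only, not the actual pyspark helper:

    # Build a map converter that applies converters to keys as well as values;
    # previously only values were converted, which left struct-typed keys untouched.
    def make_map_converter(key_type, value_type, make_converter):
        kconv = make_converter(key_type)
        vconv = make_converter(value_type)
        return lambda row: dict((kconv(k), vconv(v)) for k, v in row.items())

    # Example with trivial pass-through converters (types are placeholders).
    identity = lambda dt: (lambda x: x)
    conv = make_map_converter("key_struct", "value_struct", identity)
    print(conv({(1,): ("a",)}))   # {(1,): ('a',)}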