author    Davies Liu <davies@databricks.com>        2014-12-16 21:23:28 -0800
committer Michael Armbrust <michael@databricks.com> 2014-12-16 21:23:28 -0800
commit    ec5c4279edabd5ea2b187aff6662ac07ed825b08 (patch)
tree      fdb5b450f4519b2ac88e2713128e3e8a675702a5 /python/pyspark
parent    770d8153a5fe400147cc597c8b4b703f0aa00c22 (diff)
[SPARK-4866] support StructType as key in MapType
This PR adds support for using StructType (and other hashable types) as the key type in MapType.

Author: Davies Liu <davies@databricks.com>

Closes #3714 from davies/fix_struct_in_map and squashes the following commits:

68585d7 [Davies Liu] fix primitive types in MapType
9601534 [Davies Liu] support StructType as key in MapType
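For illustration, a minimal usage sketch of the fixed behavior, mirroring the test_struct_in_map test added below; it assumes a running SparkContext `sc` and SQLContext `sqlCtx` (Spark 1.x API):

    from pyspark.sql import Row

    # A dict keyed by a Row (a StructType key) now survives schema
    # inference and conversion back to Python objects.
    d = [Row(m={Row(i=1): Row(s="")})]
    rdd = sc.parallelize(d)
    srdd = sqlCtx.inferSchema(rdd)
    k, v = srdd.first().m.items()[0]   # Python 2: items() returns a list
    assert k.i == 1
    assert v.s == ""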
Diffstat (limited to 'python/pyspark')
-rw-r--r--  python/pyspark/sql.py   | 17
-rw-r--r--  python/pyspark/tests.py |  8
2 files changed, 18 insertions(+), 7 deletions(-)
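Before the diff itself, a self-contained sketch of the converter pattern the sql.py hunks introduce: one converter per key type and one per value type, applied to every map entry. The helper name and the identity converters here are hypothetical; only the returned lambda matches the patch (Python 2, hence iteritems):

    # Hypothetical standalone version of the patched MapType branch in
    # _create_converter: convert keys and values independently.
    def _map_converter(kconv, vconv):
        return lambda row: dict((kconv(k), vconv(v)) for k, v in row.iteritems())

    identity = lambda x: x            # stand-in for a primitive-type converter
    conv = _map_converter(identity, identity)
    print conv({"a": 1})              # {'a': 1}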
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index ae288471b0..1ee0b28a32 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -788,8 +788,9 @@ def _create_converter(dataType):
return lambda row: map(conv, row)
elif isinstance(dataType, MapType):
- conv = _create_converter(dataType.valueType)
- return lambda row: dict((k, conv(v)) for k, v in row.iteritems())
+ kconv = _create_converter(dataType.keyType)
+ vconv = _create_converter(dataType.valueType)
+ return lambda row: dict((kconv(k), vconv(v)) for k, v in row.iteritems())
elif isinstance(dataType, NullType):
return lambda x: None
@@ -944,7 +945,7 @@ def _infer_schema_type(obj, dataType):
elif isinstance(dataType, MapType):
k, v = obj.iteritems().next()
- return MapType(_infer_type(k),
+ return MapType(_infer_schema_type(k, dataType.keyType),
_infer_schema_type(v, dataType.valueType))
elif isinstance(dataType, StructType):
@@ -1085,7 +1086,7 @@ def _has_struct_or_date(dt):
elif isinstance(dt, ArrayType):
return _has_struct_or_date(dt.elementType)
elif isinstance(dt, MapType):
- return _has_struct_or_date(dt.valueType)
+ return _has_struct_or_date(dt.keyType) or _has_struct_or_date(dt.valueType)
elif isinstance(dt, DateType):
return True
elif isinstance(dt, UserDefinedType):
@@ -1148,12 +1149,13 @@ def _create_cls(dataType):
return List
elif isinstance(dataType, MapType):
- cls = _create_cls(dataType.valueType)
+ kcls = _create_cls(dataType.keyType)
+ vcls = _create_cls(dataType.valueType)
def Dict(d):
if d is None:
return
- return dict((k, _create_object(cls, v)) for k, v in d.items())
+ return dict((_create_object(kcls, k), _create_object(vcls, v)) for k, v in d.items())
return Dict
@@ -1164,7 +1166,8 @@ def _create_cls(dataType):
return lambda datum: dataType.deserialize(datum)
elif not isinstance(dataType, StructType):
- raise Exception("unexpected data type: %s" % dataType)
+ # no wrapper for primitive types
+ return lambda x: x
class Row(tuple):
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index bca52a7ce6..b474fcf5bf 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -923,6 +923,14 @@ class SQLTests(ReusedPySparkTestCase):
result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
self.assertEqual(1, result.first()[0])
+ def test_struct_in_map(self):
+ d = [Row(m={Row(i=1): Row(s="")})]
+ rdd = self.sc.parallelize(d)
+ srdd = self.sqlCtx.inferSchema(rdd)
+ k, v = srdd.first().m.items()[0]
+ self.assertEqual(1, k.i)
+ self.assertEqual("", v.s)
+
def test_convert_row_to_dict(self):
row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
self.assertEqual(1, row.asDict()['l'][0].a)
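A closing note on why this works at all: the Row classes generated in sql.py subclass tuple (see "class Row(tuple):" above), so they are hashable and therefore legal dict keys, which is what the commit message means by "other hashable types". A minimal sketch:

    from pyspark.sql import Row

    # Row is a tuple subclass, so it hashes like a tuple and can key a dict.
    key = Row(i=1)
    print hash(key)        # fine: inherits tuple.__hash__
    print {key: "value"}   # hence Row values can serve as MapType keys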