diff options
author | Davies Liu <davies@databricks.com> | 2014-12-16 21:23:28 -0800 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2014-12-16 21:23:28 -0800 |
commit | ec5c4279edabd5ea2b187aff6662ac07ed825b08 (patch) | |
tree | fdb5b450f4519b2ac88e2713128e3e8a675702a5 /python/pyspark/sql.py | |
parent | 770d8153a5fe400147cc597c8b4b703f0aa00c22 (diff) | |
download | spark-ec5c4279edabd5ea2b187aff6662ac07ed825b08.tar.gz spark-ec5c4279edabd5ea2b187aff6662ac07ed825b08.tar.bz2 spark-ec5c4279edabd5ea2b187aff6662ac07ed825b08.zip |
[SPARK-4866] support StructType as key in MapType
This PR adds support for using StructType (and other hashable types) as the key type in MapType.
Author: Davies Liu <davies@databricks.com>
Closes #3714 from davies/fix_struct_in_map and squashes the following commits:
68585d7 [Davies Liu] fix primitive types in MapType
9601534 [Davies Liu] support StructType as key in MapType
Diffstat (limited to 'python/pyspark/sql.py')
-rw-r--r-- | python/pyspark/sql.py | 17 |
1 file changed, 10 insertions, 7 deletions
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index ae288471b0..1ee0b28a32 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -788,8 +788,9 @@ def _create_converter(dataType): return lambda row: map(conv, row) elif isinstance(dataType, MapType): - conv = _create_converter(dataType.valueType) - return lambda row: dict((k, conv(v)) for k, v in row.iteritems()) + kconv = _create_converter(dataType.keyType) + vconv = _create_converter(dataType.valueType) + return lambda row: dict((kconv(k), vconv(v)) for k, v in row.iteritems()) elif isinstance(dataType, NullType): return lambda x: None @@ -944,7 +945,7 @@ def _infer_schema_type(obj, dataType): elif isinstance(dataType, MapType): k, v = obj.iteritems().next() - return MapType(_infer_type(k), + return MapType(_infer_schema_type(k, dataType.keyType), _infer_schema_type(v, dataType.valueType)) elif isinstance(dataType, StructType): @@ -1085,7 +1086,7 @@ def _has_struct_or_date(dt): elif isinstance(dt, ArrayType): return _has_struct_or_date(dt.elementType) elif isinstance(dt, MapType): - return _has_struct_or_date(dt.valueType) + return _has_struct_or_date(dt.keyType) or _has_struct_or_date(dt.valueType) elif isinstance(dt, DateType): return True elif isinstance(dt, UserDefinedType): @@ -1148,12 +1149,13 @@ def _create_cls(dataType): return List elif isinstance(dataType, MapType): - cls = _create_cls(dataType.valueType) + kcls = _create_cls(dataType.keyType) + vcls = _create_cls(dataType.valueType) def Dict(d): if d is None: return - return dict((k, _create_object(cls, v)) for k, v in d.items()) + return dict((_create_object(kcls, k), _create_object(vcls, v)) for k, v in d.items()) return Dict @@ -1164,7 +1166,8 @@ def _create_cls(dataType): return lambda datum: dataType.deserialize(datum) elif not isinstance(dataType, StructType): - raise Exception("unexpected data type: %s" % dataType) + # no wrapper for primitive types + return lambda x: x class Row(tuple): |