Diffstat (limited to 'python/pyspark/sql/types.py')
-rw-r--r--  python/pyspark/sql/types.py | 48 ++++++++++++++++++++++++++++--------------------
 1 file changed, 28 insertions(+), 20 deletions(-)
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 23d9adb0da..932686e5e4 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -635,7 +635,7 @@ def _need_python_to_sql_conversion(dataType):
>>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False),
... StructField("values", ArrayType(DoubleType(), False), False)])
>>> _need_python_to_sql_conversion(schema0)
- False
+ True
>>> _need_python_to_sql_conversion(ExamplePointUDT())
True
>>> schema1 = ArrayType(ExamplePointUDT(), False)
@@ -647,7 +647,8 @@ def _need_python_to_sql_conversion(dataType):
True
"""
if isinstance(dataType, StructType):
- return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields])
+ # convert namedtuple or Row into tuple
+ return True
elif isinstance(dataType, ArrayType):
return _need_python_to_sql_conversion(dataType.elementType)
elif isinstance(dataType, MapType):
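
The hunk above changes the rule for structs: _need_python_to_sql_conversion now always returns True for a StructType, so that namedtuple and Row values get flattened into plain tuples even when no field is a UDT (which is why the schema0 doctest flips from False to True). A minimal standalone sketch of the new rule follows; the stand-in classes are hypothetical and only mirror the shape of the real ones in types.py.

# Hypothetical stand-ins for the Spark SQL type classes, just enough
# to run this sketch outside of PySpark.
class DataType(object):
    pass

class IntegerType(DataType):
    pass

class StructField(object):
    def __init__(self, name, dataType):
        self.name, self.dataType = name, dataType

class StructType(DataType):
    def __init__(self, fields):
        self.fields = fields

def _need_python_to_sql_conversion(dataType):
    # Patched behavior: a struct always needs conversion, because a
    # namedtuple or Row must be turned into a plain tuple before it is
    # handed to the JVM, even when every field is a primitive.
    if isinstance(dataType, StructType):
        return True
    return False  # primitive leaf types need no conversion in this sketch

schema0 = StructType([StructField("indices", IntegerType()),
                      StructField("values", IntegerType())])
assert _need_python_to_sql_conversion(schema0)  # True, matching the new doctest
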
@@ -688,21 +689,25 @@ def _python_to_sql_converter(dataType):
if isinstance(dataType, StructType):
names, types = zip(*[(f.name, f.dataType) for f in dataType.fields])
- converters = [_python_to_sql_converter(t) for t in types]
-
- def converter(obj):
- if isinstance(obj, dict):
- return tuple(c(obj.get(n)) for n, c in zip(names, converters))
- elif isinstance(obj, tuple):
- if hasattr(obj, "__fields__") or hasattr(obj, "_fields"):
- return tuple(c(v) for c, v in zip(converters, obj))
- elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs
- d = dict(obj)
- return tuple(c(d.get(n)) for n, c in zip(names, converters))
+ if any(_need_python_to_sql_conversion(t) for t in types):
+ converters = [_python_to_sql_converter(t) for t in types]
+
+ def converter(obj):
+ if isinstance(obj, dict):
+ return tuple(c(obj.get(n)) for n, c in zip(names, converters))
+ elif isinstance(obj, tuple):
+ if hasattr(obj, "__fields__") or hasattr(obj, "_fields"):
+ return tuple(c(v) for c, v in zip(converters, obj))
+ else:
+ return tuple(c(v) for c, v in zip(converters, obj))
+ elif obj is not None:
+ raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType))
+ else:
+ def converter(obj):
+ if isinstance(obj, dict):
+ return tuple(obj.get(n) for n in names)
else:
- return tuple(c(v) for c, v in zip(converters, obj))
- elif obj is not None:
- raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType))
+ return tuple(obj)
return converter
elif isinstance(dataType, ArrayType):
element_converter = _python_to_sql_converter(dataType.elementType)
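
The refactored converter now takes one of two paths: when any field type itself needs conversion, per-field converters are built as before; otherwise a cheaper converter only reshapes the container into a plain tuple. Note that the two tuple branches kept in the slow path are identical, since Row, namedtuple and plain tuples all flatten the same way. A self-contained sketch of the two-path idea, using illustrative names rather than the real PySpark internals:

from collections import namedtuple

def make_struct_converter(names, field_converters=None):
    # Illustrative version of the two-path converter in the patch.
    if field_converters:
        # Slow path: some field (e.g. a UDT) needs its own conversion,
        # so convert each value while flattening the container.
        def converter(obj):
            if isinstance(obj, dict):
                return tuple(c(obj.get(n)) for n, c in zip(names, field_converters))
            elif isinstance(obj, tuple):
                # Row, namedtuple and plain tuples all flatten the same way
                return tuple(c(v) for c, v in zip(field_converters, obj))
            elif obj is not None:
                raise ValueError("Unexpected obj %r" % (obj,))
    else:
        # Fast path: no field needs conversion, only reshape the container.
        def converter(obj):
            if isinstance(obj, dict):
                return tuple(obj.get(n) for n in names)
            else:
                return tuple(obj)
    return converter

Point = namedtuple("Point", ["x", "y"])
conv = make_struct_converter(["x", "y"])
assert conv(Point(1, 2)) == (1, 2)       # namedtuple -> plain tuple
assert conv({"x": 1, "y": 2}) == (1, 2)  # dict -> tuple in field order

conv2 = make_struct_converter(["x", "y"], [str, str])  # pretend str() is a UDT converter
assert conv2(Point(1, 2)) == ("1", "2")
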
@@ -1027,10 +1032,13 @@ def _verify_type(obj, dataType):
_type = type(dataType)
assert _type in _acceptable_types, "unknown datatype: %s" % dataType
- # subclass of them can not be deserialized in JVM
- if type(obj) not in _acceptable_types[_type]:
- raise TypeError("%s can not accept object in type %s"
- % (dataType, type(obj)))
+ if _type is StructType:
+ if not isinstance(obj, (tuple, list)):
+ raise TypeError("StructType can not accept object in type %s" % type(obj))
+ else:
+ # subclass of them can not be deserialized in JVM
+ if type(obj) not in _acceptable_types[_type]:
+ raise TypeError("%s can not accept object in type %s" % (dataType, type(obj)))
if isinstance(dataType, ArrayType):
for i in obj:
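
The last hunk loosens type verification for structs only: any tuple or list, including tuple subclasses such as Row and namedtuple, now passes the isinstance check, while every other type keeps the exact type() check because subclass instances can not be deserialized in the JVM. A standalone sketch of that split, using a hypothetical two-entry version of the _acceptable_types table:

from collections import namedtuple

# Hypothetical two-entry stand-in for the real _acceptable_types table.
_acceptable_types = {
    "int": (int,),
    "struct": (tuple, list),
}

def _verify_type(obj, type_name):
    if type_name == "struct":
        # Patched branch: isinstance, so tuple subclasses (Row, namedtuple)
        # are accepted and flattened later by the converter.
        if not isinstance(obj, (tuple, list)):
            raise TypeError("StructType can not accept object in type %s" % type(obj))
    else:
        # Exact-type check kept for everything else: a subclass such as
        # bool (a subclass of int) can not be deserialized in the JVM.
        if type(obj) not in _acceptable_types[type_name]:
            raise TypeError("%s can not accept object in type %s"
                            % (type_name, type(obj)))

Point = namedtuple("Point", ["x", "y"])
_verify_type(Point(1, 2), "struct")  # accepted now: namedtuple is a tuple subclass
_verify_type([1, 2], "struct")       # accepted: list
_verify_type(1, "int")               # accepted: exact int
try:
    _verify_type(True, "int")        # rejected: bool is a subclass of int
except TypeError as e:
    print(e)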