aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/sql.py14
1 file changed, 7 insertions, 7 deletions
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 27f1d2ddf9..46540ca3f1 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -498,10 +498,7 @@ def _infer_schema(row):
def _create_converter(obj, dataType):
"""Create a converter to drop the names of fields in obj """
- if not _has_struct(dataType):
- return lambda x: x
-
- elif isinstance(dataType, ArrayType):
+ if isinstance(dataType, ArrayType):
conv = _create_converter(obj[0], dataType.elementType)
return lambda row: map(conv, row)
@@ -510,6 +507,9 @@ def _create_converter(obj, dataType):
conv = _create_converter(value, dataType.valueType)
return lambda row: dict((k, conv(v)) for k, v in row.iteritems())
+ elif not isinstance(dataType, StructType):
+ return lambda x: x
+
# dataType must be StructType
names = [f.name for f in dataType.fields]
@@ -529,8 +529,7 @@ def _create_converter(obj, dataType):
elif hasattr(obj, "__dict__"): # object
conv = lambda o: [o.__dict__.get(n, None) for n in names]
- nested = any(_has_struct(f.dataType) for f in dataType.fields)
- if not nested:
+ if all(isinstance(f.dataType, PrimitiveType) for f in dataType.fields):
return conv
row = conv(obj)
@@ -1037,7 +1036,8 @@ class SQLContext:
raise ValueError("The first row in RDD is empty, "
"can not infer schema")
if type(first) is dict:
- warnings.warn("Using RDD of dict to inferSchema is deprecated")
+ warnings.warn("Using RDD of dict to inferSchema is deprecated, "
+ "please use pyspark.Row instead")
schema = _infer_schema(first)
rdd = rdd.mapPartitions(lambda rows: _drop_schema(rows, schema))