aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/sql/types.py
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2015-02-20 15:35:05 -0800
committerMichael Armbrust <michael@databricks.com>2015-02-20 15:35:05 -0800
commit5b0a42cb17b840c82d3f8a5ad061d99e261ceadf (patch)
treedbdc285db33b30e2797400373b43568673d4741c /python/pyspark/sql/types.py
parent4a17eedb16343413e5b6f8bb58c6da8952ee7ab6 (diff)
downloadspark-5b0a42cb17b840c82d3f8a5ad061d99e261ceadf.tar.gz
spark-5b0a42cb17b840c82d3f8a5ad061d99e261ceadf.tar.bz2
spark-5b0a42cb17b840c82d3f8a5ad061d99e261ceadf.zip
[SPARK-5898] [SPARK-5896] [SQL] [PySpark] create DataFrame from pandas and tuple/list
Fix createDataFrame() from pandas DataFrame (not tested by jenkins, depends on SPARK-5693). It also support to create DataFrame from plain tuple/list without column names, `_1`, `_2` will be used as column names. Author: Davies Liu <davies@databricks.com> Closes #4679 from davies/pandas and squashes the following commits: c0cbe0b [Davies Liu] fix tests 8466d1d [Davies Liu] fix create DataFrame from pandas
Diffstat (limited to 'python/pyspark/sql/types.py')
-rw-r--r--python/pyspark/sql/types.py26
1 files changed, 9 insertions, 17 deletions
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 9409c6f9f6..b6e41cf0b2 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -604,7 +604,7 @@ def _infer_type(obj):
ExamplePointUDT
"""
if obj is None:
- raise ValueError("Can not infer type for None")
+ return NullType()
if hasattr(obj, '__UDT__'):
return obj.__UDT__
@@ -637,15 +637,14 @@ def _infer_schema(row):
if isinstance(row, dict):
items = sorted(row.items())
- elif isinstance(row, tuple):
+ elif isinstance(row, (tuple, list)):
if hasattr(row, "_fields"): # namedtuple
items = zip(row._fields, tuple(row))
elif hasattr(row, "__FIELDS__"): # Row
items = zip(row.__FIELDS__, tuple(row))
- elif all(isinstance(x, tuple) and len(x) == 2 for x in row):
- items = row
else:
- raise ValueError("Can't infer schema from tuple")
+ names = ['_%d' % i for i in range(1, len(row) + 1)]
+ items = zip(names, row)
elif hasattr(row, "__dict__"): # object
items = sorted(row.__dict__.items())
@@ -812,17 +811,10 @@ def _create_converter(dataType):
if obj is None:
return
- if isinstance(obj, tuple):
- if hasattr(obj, "_fields"):
- d = dict(zip(obj._fields, obj))
- elif hasattr(obj, "__FIELDS__"):
- d = dict(zip(obj.__FIELDS__, obj))
- elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):
- d = dict(obj)
- else:
- raise ValueError("unexpected tuple: %s" % str(obj))
+ if isinstance(obj, (tuple, list)):
+ return tuple(conv(v) for v, conv in zip(obj, converters))
- elif isinstance(obj, dict):
+ if isinstance(obj, dict):
d = obj
elif hasattr(obj, "__dict__"): # object
d = obj.__dict__
@@ -1022,7 +1014,7 @@ def _verify_type(obj, dataType):
return
_type = type(dataType)
- assert _type in _acceptable_types, "unkown datatype: %s" % dataType
+ assert _type in _acceptable_types, "unknown datatype: %s" % dataType
# subclass of them can not be deserialized in JVM
if type(obj) not in _acceptable_types[_type]:
@@ -1040,7 +1032,7 @@ def _verify_type(obj, dataType):
elif isinstance(dataType, StructType):
if len(obj) != len(dataType.fields):
- raise ValueError("Length of object (%d) does not match with"
+ raise ValueError("Length of object (%d) does not match with "
"length of fields (%d)" % (len(obj), len(dataType.fields)))
for v, f in zip(obj, dataType.fields):
_verify_type(v, f.dataType)