From 5b0a42cb17b840c82d3f8a5ad061d99e261ceadf Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 20 Feb 2015 15:35:05 -0800 Subject: [SPARK-5898] [SPARK-5896] [SQL] [PySpark] create DataFrame from pandas and tuple/list Fix createDataFrame() from pandas DataFrame (not tested by jenkins, depends on SPARK-5693). It also support to create DataFrame from plain tuple/list without column names, `_1`, `_2` will be used as column names. Author: Davies Liu Closes #4679 from davies/pandas and squashes the following commits: c0cbe0b [Davies Liu] fix tests 8466d1d [Davies Liu] fix create DataFrame from pandas --- python/pyspark/sql/context.py | 12 ++++++++++-- python/pyspark/sql/tests.py | 2 +- python/pyspark/sql/types.py | 26 +++++++++----------------- 3 files changed, 20 insertions(+), 20 deletions(-) (limited to 'python/pyspark') diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 3f168f718b..313f15e6d9 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -351,6 +351,8 @@ class SQLContext(object): :return: a DataFrame >>> l = [('Alice', 1)] + >>> sqlCtx.createDataFrame(l).collect() + [Row(_1=u'Alice', _2=1)] >>> sqlCtx.createDataFrame(l, ['name', 'age']).collect() [Row(name=u'Alice', age=1)] @@ -359,6 +361,8 @@ class SQLContext(object): [Row(age=1, name=u'Alice')] >>> rdd = sc.parallelize(l) + >>> sqlCtx.createDataFrame(rdd).collect() + [Row(_1=u'Alice', _2=1)] >>> df = sqlCtx.createDataFrame(rdd, ['name', 'age']) >>> df.collect() [Row(name=u'Alice', age=1)] @@ -377,14 +381,17 @@ class SQLContext(object): >>> df3 = sqlCtx.createDataFrame(rdd, schema) >>> df3.collect() [Row(name=u'Alice', age=1)] + + >>> sqlCtx.createDataFrame(df.toPandas()).collect() # doctest: +SKIP + [Row(name=u'Alice', age=1)] """ if isinstance(data, DataFrame): raise TypeError("data is already a DataFrame") if has_pandas and isinstance(data, pandas.DataFrame): - data = self._sc.parallelize(data.to_records(index=False)) if schema is None: schema = list(data.columns) + data = [r.tolist() for r in data.to_records(index=False)] if not isinstance(data, RDD): try: @@ -399,7 +406,8 @@ class SQLContext(object): if isinstance(schema, (list, tuple)): first = data.first() if not isinstance(first, (list, tuple)): - raise ValueError("each row in `rdd` should be list or tuple") + raise ValueError("each row in `rdd` should be list or tuple, " + "but got %r" % type(first)) row_cls = Row(*schema) schema = self._inferSchema(data.map(lambda r: row_cls(*r)), samplingRatio) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 8e1bb36598..39071e7e35 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -186,7 +186,7 @@ class SQLTests(ReusedPySparkTestCase): self.assertEqual("2", row.d) def test_infer_schema(self): - d = [Row(l=[], d={}), + d = [Row(l=[], d={}, s=None), Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] rdd = self.sc.parallelize(d) df = self.sqlCtx.createDataFrame(rdd) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 9409c6f9f6..b6e41cf0b2 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -604,7 +604,7 @@ def _infer_type(obj): ExamplePointUDT """ if obj is None: - raise ValueError("Can not infer type for None") + return NullType() if hasattr(obj, '__UDT__'): return obj.__UDT__ @@ -637,15 +637,14 @@ def _infer_schema(row): if isinstance(row, dict): items = sorted(row.items()) - elif isinstance(row, tuple): + elif isinstance(row, (tuple, list)): if hasattr(row, "_fields"): # namedtuple items = zip(row._fields, tuple(row)) elif hasattr(row, "__FIELDS__"): # Row items = zip(row.__FIELDS__, tuple(row)) - elif all(isinstance(x, tuple) and len(x) == 2 for x in row): - items = row else: - raise ValueError("Can't infer schema from tuple") + names = ['_%d' % i for i in range(1, len(row) + 1)] + items = zip(names, row) elif hasattr(row, "__dict__"): # object items = sorted(row.__dict__.items()) @@ -812,17 +811,10 @@ def _create_converter(dataType): if obj is None: return - if isinstance(obj, tuple): - if hasattr(obj, "_fields"): - d = dict(zip(obj._fields, obj)) - elif hasattr(obj, "__FIELDS__"): - d = dict(zip(obj.__FIELDS__, obj)) - elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): - d = dict(obj) - else: - raise ValueError("unexpected tuple: %s" % str(obj)) + if isinstance(obj, (tuple, list)): + return tuple(conv(v) for v, conv in zip(obj, converters)) - elif isinstance(obj, dict): + if isinstance(obj, dict): d = obj elif hasattr(obj, "__dict__"): # object d = obj.__dict__ @@ -1022,7 +1014,7 @@ def _verify_type(obj, dataType): return _type = type(dataType) - assert _type in _acceptable_types, "unkown datatype: %s" % dataType + assert _type in _acceptable_types, "unknown datatype: %s" % dataType # subclass of them can not be deserialized in JVM if type(obj) not in _acceptable_types[_type]: @@ -1040,7 +1032,7 @@ def _verify_type(obj, dataType): elif isinstance(dataType, StructType): if len(obj) != len(dataType.fields): - raise ValueError("Length of object (%d) does not match with" + raise ValueError("Length of object (%d) does not match with " "length of fields (%d)" % (len(obj), len(dataType.fields))) for v, f in zip(obj, dataType.fields): _verify_type(v, f.dataType) -- cgit v1.2.3