aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2015-08-26 16:04:44 -0700
committerDavies Liu <davies.liu@gmail.com>2015-08-26 16:04:53 -0700
commit0bdb800575ae2872e2655983a1be94dcf2e8c36b (patch)
treef61326b6f968bb047ff2c9352d75fb6b0282c724
parentefbd7af44e855efcbb1fa224e80db24947e2b153 (diff)
downloadspark-0bdb800575ae2872e2655983a1be94dcf2e8c36b.tar.gz
spark-0bdb800575ae2872e2655983a1be94dcf2e8c36b.tar.bz2
spark-0bdb800575ae2872e2655983a1be94dcf2e8c36b.zip
[SPARK-10305] [SQL] fix create DataFrame from Python class
cc jkbradley Author: Davies Liu <davies@databricks.com> Closes #8470 from davies/fix_create_df. (cherry picked from commit d41d6c48207159490c1e1d9cc54015725cfa41b2) Signed-off-by: Davies Liu <davies.liu@gmail.com>
-rw-r--r--python/pyspark/sql/tests.py12
-rw-r--r--python/pyspark/sql/types.py6
2 files changed, 18 insertions, 0 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index aacfb34c77..cd32e26c64 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -145,6 +145,12 @@ class PythonOnlyPoint(ExamplePoint):
__UDT__ = PythonOnlyUDT()
+class MyObject(object):
+ def __init__(self, key, value):
+ self.key = key
+ self.value = value
+
+
class DataTypeTests(unittest.TestCase):
# regression test for SPARK-6055
def test_data_type_eq(self):
@@ -383,6 +389,12 @@ class SQLTests(ReusedPySparkTestCase):
df = self.sqlCtx.inferSchema(rdd)
self.assertEquals(Row(field1=1, field2=u'row1'), df.first())
+ def test_create_dataframe_from_objects(self):
+ data = [MyObject(1, "1"), MyObject(2, "2")]
+ df = self.sqlCtx.createDataFrame(data)
+ self.assertEqual(df.dtypes, [("key", "bigint"), ("value", "string")])
+ self.assertEqual(df.first(), Row(key=1, value="1"))
+
def test_select_null_literal(self):
df = self.sqlCtx.sql("select null as col")
self.assertEquals(Row(col=None), df.first())
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index ed4e5b594b..94e581a783 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -537,6 +537,9 @@ class StructType(DataType):
return tuple(f.toInternal(obj.get(n)) for n, f in zip(self.names, self.fields))
elif isinstance(obj, (tuple, list)):
return tuple(f.toInternal(v) for f, v in zip(self.fields, obj))
+ elif hasattr(obj, "__dict__"):
+ d = obj.__dict__
+ return tuple(f.toInternal(d.get(n)) for n, f in zip(self.names, self.fields))
else:
raise ValueError("Unexpected tuple %r with StructType" % obj)
else:
@@ -544,6 +547,9 @@ class StructType(DataType):
return tuple(obj.get(n) for n in self.names)
elif isinstance(obj, (list, tuple)):
return tuple(obj)
+ elif hasattr(obj, "__dict__"):
+ d = obj.__dict__
+ return tuple(d.get(n) for n in self.names)
else:
raise ValueError("Unexpected tuple %r with StructType" % obj)