about summary refs log tree commit diff
path: root/python/pyspark/sql
diff options
context:
space:
mode:
author	Davies Liu <davies@databricks.com>	2015-08-26 16:04:44 -0700
committer	Davies Liu <davies.liu@gmail.com>	2015-08-26 16:04:44 -0700
commitd41d6c48207159490c1e1d9cc54015725cfa41b2 (patch)
tree3ae2758f38cb0079b5ef1c145a66bdd98653acda /python/pyspark/sql
parent086d4681df3ebfccfc04188262c10482f44553b0 (diff)
downloadspark-d41d6c48207159490c1e1d9cc54015725cfa41b2.tar.gz
spark-d41d6c48207159490c1e1d9cc54015725cfa41b2.tar.bz2
spark-d41d6c48207159490c1e1d9cc54015725cfa41b2.zip
[SPARK-10305] [SQL] fix create DataFrame from Python class
cc jkbradley

Author: Davies Liu <davies@databricks.com>

Closes #8470 from davies/fix_create_df.
Diffstat (limited to 'python/pyspark/sql')
-rw-r--r--	python/pyspark/sql/tests.py	12
-rw-r--r--	python/pyspark/sql/types.py	6
2 files changed, 18 insertions, 0 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index aacfb34c77..cd32e26c64 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -145,6 +145,12 @@ class PythonOnlyPoint(ExamplePoint):
__UDT__ = PythonOnlyUDT()
+class MyObject(object):
+ def __init__(self, key, value):
+ self.key = key
+ self.value = value
+
+
class DataTypeTests(unittest.TestCase):
# regression test for SPARK-6055
def test_data_type_eq(self):
@@ -383,6 +389,12 @@ class SQLTests(ReusedPySparkTestCase):
df = self.sqlCtx.inferSchema(rdd)
self.assertEquals(Row(field1=1, field2=u'row1'), df.first())
+ def test_create_dataframe_from_objects(self):
+ data = [MyObject(1, "1"), MyObject(2, "2")]
+ df = self.sqlCtx.createDataFrame(data)
+ self.assertEqual(df.dtypes, [("key", "bigint"), ("value", "string")])
+ self.assertEqual(df.first(), Row(key=1, value="1"))
+
def test_select_null_literal(self):
df = self.sqlCtx.sql("select null as col")
self.assertEquals(Row(col=None), df.first())
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index ed4e5b594b..94e581a783 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -537,6 +537,9 @@ class StructType(DataType):
return tuple(f.toInternal(obj.get(n)) for n, f in zip(self.names, self.fields))
elif isinstance(obj, (tuple, list)):
return tuple(f.toInternal(v) for f, v in zip(self.fields, obj))
+ elif hasattr(obj, "__dict__"):
+ d = obj.__dict__
+ return tuple(f.toInternal(d.get(n)) for n, f in zip(self.names, self.fields))
else:
raise ValueError("Unexpected tuple %r with StructType" % obj)
else:
@@ -544,6 +547,9 @@ class StructType(DataType):
return tuple(obj.get(n) for n in self.names)
elif isinstance(obj, (list, tuple)):
return tuple(obj)
+ elif hasattr(obj, "__dict__"):
+ d = obj.__dict__
+ return tuple(d.get(n) for n in self.names)
else:
raise ValueError("Unexpected tuple %r with StructType" % obj)