diff options
author | Davies Liu <davies@databricks.com> | 2014-10-24 10:48:03 -0700 |
---|---|---|
committer | Josh Rosen <joshrosen@databricks.com> | 2014-10-24 10:48:03 -0700 |
commit | d60a9d440b00beb107c1f1d7f42886c94f04a092 (patch) | |
tree | 1e65c3be4794876bf4c357410f0264921d67d27c /python | |
parent | d2987e8f7a2cb3bf971f381399d8efdccb51d3d2 (diff) | |
download | spark-d60a9d440b00beb107c1f1d7f42886c94f04a092.tar.gz spark-d60a9d440b00beb107c1f1d7f42886c94f04a092.tar.bz2 spark-d60a9d440b00beb107c1f1d7f42886c94f04a092.zip |
[SPARK-4051] [SQL] [PySpark] Convert Row into dictionary
Added a method to Row to turn row into dict:
```
>>> row = Row(a=1)
>>> row.asDict()
{'a': 1}
```
Author: Davies Liu <davies@databricks.com>
Closes #2896 from davies/dict and squashes the following commits:
8d97366 [Davies Liu] convert Row into dict
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/sql.py | 12 | ||||
-rw-r--r-- | python/pyspark/tests.py | 9 |
2 files changed, 21 insertions, 0 deletions
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index b31a82f9b1..7daf306f68 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -883,6 +883,10 @@ def _create_cls(dataType): # create property for fast access locals().update(_create_properties(dataType.fields)) + def asDict(self): + """ Return as a dict """ + return dict(zip(self.__FIELDS__, self)) + def __repr__(self): # call collect __repr__ for nested objects return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n)) @@ -1466,6 +1470,14 @@ class Row(tuple): else: raise ValueError("No args or kwargs") + def asDict(self): + """ + Return as a dict + """ + if not hasattr(self, "__FIELDS__"): + raise TypeError("Cannot convert a Row class into dict") + return dict(zip(self.__FIELDS__, self)) + # let object act like class def __call__(self, *args): """create new Row object""" diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 7a2107ec32..047d857830 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -771,6 +771,15 @@ class SQLTests(ReusedPySparkTestCase): self.assertEqual(1.0, row.c) self.assertEqual("2", row.d) + def test_convert_row_to_dict(self): + row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) + self.assertEqual(1, row.asDict()['l'][0].a) + rdd = self.sc.parallelize([row]) + srdd = self.sqlCtx.inferSchema(rdd) + srdd.registerTempTable("test") + row = self.sqlCtx.sql("select l[0].a AS la from test").first() + self.assertEqual(1, row.asDict()["la"]) + class InputFormatTests(ReusedPySparkTestCase): |