aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2014-11-24 16:41:23 -0800
committerPatrick Wendell <pwendell@gmail.com>2014-11-24 16:41:32 -0800
commit8371bc20821c39ee6d8116a867577e5c0fcd08ab (patch)
tree9e3a7f05cc1a8ceef18a49b9da07a5456fdf752b /python
parent2acbd2884f73c4503d753bb96e0acf75cd237536 (diff)
downloadspark-8371bc20821c39ee6d8116a867577e5c0fcd08ab.tar.gz
spark-8371bc20821c39ee6d8116a867577e5c0fcd08ab.tar.bz2
spark-8371bc20821c39ee6d8116a867577e5c0fcd08ab.zip
[SPARK-4578] fix asDict() with nested Row()
The Row object is created on the fly once the field is accessed, so we should access them by getattr() in asDict() Author: Davies Liu <davies@databricks.com> Closes #3434 from davies/fix_asDict and squashes the following commits: b20f1e7 [Davies Liu] fix asDict() with nested Row() (cherry picked from commit 050616b408c60eae02256913ceb645912dbff62e) Signed-off-by: Patrick Wendell <pwendell@gmail.com>
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/sql.py2
-rw-r--r--python/pyspark/tests.py7
2 files changed, 5 insertions, 4 deletions
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index abb284d1e3..ae288471b0 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -1178,7 +1178,7 @@ def _create_cls(dataType):
def asDict(self):
""" Return as a dict """
- return dict(zip(self.__FIELDS__, self))
+ return dict((n, getattr(self, n)) for n in self.__FIELDS__)
def __repr__(self):
# call collect __repr__ for nested objects
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index a01bd8d415..29bcd38908 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -803,7 +803,7 @@ class SQLTests(ReusedPySparkTestCase):
@classmethod
def tearDownClass(cls):
ReusedPySparkTestCase.tearDownClass()
- shutil.rmtree(cls.tempdir.name)
+ shutil.rmtree(cls.tempdir.name, ignore_errors=True)
def setUp(self):
self.sqlCtx = SQLContext(self.sc)
@@ -930,8 +930,9 @@ class SQLTests(ReusedPySparkTestCase):
rdd = self.sc.parallelize([row])
srdd = self.sqlCtx.inferSchema(rdd)
srdd.registerTempTable("test")
- row = self.sqlCtx.sql("select l[0].a AS la from test").first()
- self.assertEqual(1, row.asDict()["la"])
+ row = self.sqlCtx.sql("select l, d from test").first()
+ self.assertEqual(1, row.asDict()["l"][0].a)
+ self.assertEqual(1.0, row.asDict()['d']['key'].c)
def test_infer_schema_with_udt(self):
from pyspark.tests import ExamplePoint, ExamplePointUDT