diff options
author | Davies Liu <davies@databricks.com> | 2015-04-16 17:33:57 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2015-04-16 17:33:57 -0700 |
commit | 6183b5e2caedd074073d0f6cb6609a634e2f5194 (patch) | |
tree | 073a82a2ff33eea0a5d8faae03e313ee749198b6 /python | |
parent | 5fe43433529346788e8c343d338a5b7dc169cf58 (diff) | |
download | spark-6183b5e2caedd074073d0f6cb6609a634e2f5194.tar.gz spark-6183b5e2caedd074073d0f6cb6609a634e2f5194.tar.bz2 spark-6183b5e2caedd074073d0f6cb6609a634e2f5194.zip |
[SPARK-6911] [SQL] improve accessor for nested types
Support access columns by index in Python:
```
>>> df[df[0] > 3].collect()
[Row(age=5, name=u'Bob')]
```
Access items in ArrayType or MapType
```
>>> df.select(df.l.getItem(0), df.d.getItem("key")).show()
>>> df.select(df.l[0], df.d["key"]).show()
```
Access field in StructType
```
>>> df.select(df.r.getField("b")).show()
>>> df.select(df.r.a).show()
```
Author: Davies Liu <davies@databricks.com>
Closes #5513 from davies/access and squashes the following commits:
e04d5a0 [Davies Liu] Update run-tests-jenkins
7ada9eb [Davies Liu] update timeout
d125ac4 [Davies Liu] check column name, improve scala tests
6b62540 [Davies Liu] fix test
db15b42 [Davies Liu] Merge branch 'master' of github.com:apache/spark into access
6c32e79 [Davies Liu] add scala tests
11f1df3 [Davies Liu] improve accessor for nested types
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/sql/dataframe.py | 49 | ||||
-rw-r--r-- | python/pyspark/sql/tests.py | 18 |
2 files changed, 62 insertions, 5 deletions
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index d76504f986..b9a3e6cfe7 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -563,16 +563,23 @@ class DataFrame(object): [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] >>> df[ df.age > 3 ].collect() [Row(age=5, name=u'Bob')] + >>> df[df[0] > 3].collect() + [Row(age=5, name=u'Bob')] """ if isinstance(item, basestring): + if item not in self.columns: + raise IndexError("no such column: %s" % item) jc = self._jdf.apply(item) return Column(jc) elif isinstance(item, Column): return self.filter(item) - elif isinstance(item, list): + elif isinstance(item, (list, tuple)): return self.select(*item) + elif isinstance(item, int): + jc = self._jdf.apply(self.columns[item]) + return Column(jc) else: - raise IndexError("unexpected index: %s" % item) + raise TypeError("unexpected type: %s" % type(item)) def __getattr__(self, name): """Returns the :class:`Column` denoted by ``name``. @@ -580,8 +587,8 @@ class DataFrame(object): >>> df.select(df.age).collect() [Row(age=2), Row(age=5)] """ - if name.startswith("__"): - raise AttributeError(name) + if name not in self.columns: + raise AttributeError("No such column: %s" % name) jc = self._jdf.apply(name) return Column(jc) @@ -1093,7 +1100,39 @@ class Column(object): # container operators __contains__ = _bin_op("contains") __getitem__ = _bin_op("getItem") - getField = _bin_op("getField", "An expression that gets a field by name in a StructField.") + + def getItem(self, key): + """An expression that gets an item at position `ordinal` out of a list, + or gets an item by key out of a dict. + + >>> df = sc.parallelize([([1, 2], {"key": "value"})]).toDF(["l", "d"]) + >>> df.select(df.l.getItem(0), df.d.getItem("key")).show() + l[0] d[key] + 1 value + >>> df.select(df.l[0], df.d["key"]).show() + l[0] d[key] + 1 value + """ + return self[key] + + def getField(self, name): + """An expression that gets a field by name in a StructField. + + >>> from pyspark.sql import Row + >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF() + >>> df.select(df.r.getField("b")).show() + r.b + b + >>> df.select(df.r.a).show() + r.a + 1 + """ + return Column(self._jc.getField(name)) + + def __getattr__(self, item): + if item.startswith("__"): + raise AttributeError(item) + return self.getField(item) # string methods rlike = _bin_op("rlike") diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 7c09a0cfe3..6691e8c8dc 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -426,6 +426,24 @@ class SQLTests(ReusedPySparkTestCase): pydoc.render_doc(df.foo) pydoc.render_doc(df.take(1)) + def test_access_column(self): + df = self.df + self.assertTrue(isinstance(df.key, Column)) + self.assertTrue(isinstance(df['key'], Column)) + self.assertTrue(isinstance(df[0], Column)) + self.assertRaises(IndexError, lambda: df[2]) + self.assertRaises(IndexError, lambda: df["bad_key"]) + self.assertRaises(TypeError, lambda: df[{}]) + + def test_access_nested_types(self): + df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF() + self.assertEqual(1, df.select(df.l[0]).first()[0]) + self.assertEqual(1, df.select(df.l.getItem(0)).first()[0]) + self.assertEqual(1, df.select(df.r.a).first()[0]) + self.assertEqual("b", df.select(df.r.getField("b")).first()[0]) + self.assertEqual("v", df.select(df.d["k"]).first()[0]) + self.assertEqual("v", df.select(df.d.getItem("k")).first()[0]) + def test_infer_long_type(self): longrow = [Row(f1='a', f2=100000000000000)] df = self.sc.parallelize(longrow).toDF() |