about summary refs log tree commit diff
path: root/python
diff options
context:
space:
mode:
author    Davies Liu <davies@databricks.com>        2015-04-16 17:33:57 -0700
committer Michael Armbrust <michael@databricks.com> 2015-04-16 17:33:57 -0700
commit    6183b5e2caedd074073d0f6cb6609a634e2f5194 (patch)
tree      073a82a2ff33eea0a5d8faae03e313ee749198b6 /python
parent    5fe43433529346788e8c343d338a5b7dc169cf58 (diff)
download  spark-6183b5e2caedd074073d0f6cb6609a634e2f5194.tar.gz
spark-6183b5e2caedd074073d0f6cb6609a634e2f5194.tar.bz2
spark-6183b5e2caedd074073d0f6cb6609a634e2f5194.zip
[SPARK-6911] [SQL] improve accessor for nested types
Support access columns by index in Python:
```
>>> df[df[0] > 3].collect()
[Row(age=5, name=u'Bob')]
```

Access items in ArrayType or MapType
```
>>> df.select(df.l.getItem(0), df.d.getItem("key")).show()
>>> df.select(df.l[0], df.d["key"]).show()
```

Access field in StructType
```
>>> df.select(df.r.getField("b")).show()
>>> df.select(df.r.a).show()
```

Author: Davies Liu <davies@databricks.com>

Closes #5513 from davies/access and squashes the following commits:

e04d5a0 [Davies Liu] Update run-tests-jenkins
7ada9eb [Davies Liu] update timeout
d125ac4 [Davies Liu] check column name, improve scala tests
6b62540 [Davies Liu] fix test
db15b42 [Davies Liu] Merge branch 'master' of github.com:apache/spark into access
6c32e79 [Davies Liu] add scala tests
11f1df3 [Davies Liu] improve accessor for nested types
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/sql/dataframe.py  49
-rw-r--r--  python/pyspark/sql/tests.py      18
2 files changed, 62 insertions, 5 deletions
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index d76504f986..b9a3e6cfe7 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -563,16 +563,23 @@ class DataFrame(object):
[Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
>>> df[ df.age > 3 ].collect()
[Row(age=5, name=u'Bob')]
+ >>> df[df[0] > 3].collect()
+ [Row(age=5, name=u'Bob')]
"""
if isinstance(item, basestring):
+ if item not in self.columns:
+ raise IndexError("no such column: %s" % item)
jc = self._jdf.apply(item)
return Column(jc)
elif isinstance(item, Column):
return self.filter(item)
- elif isinstance(item, list):
+ elif isinstance(item, (list, tuple)):
return self.select(*item)
+ elif isinstance(item, int):
+ jc = self._jdf.apply(self.columns[item])
+ return Column(jc)
else:
- raise IndexError("unexpected index: %s" % item)
+ raise TypeError("unexpected type: %s" % type(item))
def __getattr__(self, name):
"""Returns the :class:`Column` denoted by ``name``.
@@ -580,8 +587,8 @@ class DataFrame(object):
>>> df.select(df.age).collect()
[Row(age=2), Row(age=5)]
"""
- if name.startswith("__"):
- raise AttributeError(name)
+ if name not in self.columns:
+ raise AttributeError("No such column: %s" % name)
jc = self._jdf.apply(name)
return Column(jc)
@@ -1093,7 +1100,39 @@ class Column(object):
# container operators
__contains__ = _bin_op("contains")
__getitem__ = _bin_op("getItem")
- getField = _bin_op("getField", "An expression that gets a field by name in a StructField.")
+
+ def getItem(self, key):
+ """An expression that gets an item at position `ordinal` out of a list,
+ or gets an item by key out of a dict.
+
+ >>> df = sc.parallelize([([1, 2], {"key": "value"})]).toDF(["l", "d"])
+ >>> df.select(df.l.getItem(0), df.d.getItem("key")).show()
+ l[0] d[key]
+ 1 value
+ >>> df.select(df.l[0], df.d["key"]).show()
+ l[0] d[key]
+ 1 value
+ """
+ return self[key]
+
+ def getField(self, name):
+ """An expression that gets a field by name in a StructField.
+
+ >>> from pyspark.sql import Row
+ >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF()
+ >>> df.select(df.r.getField("b")).show()
+ r.b
+ b
+ >>> df.select(df.r.a).show()
+ r.a
+ 1
+ """
+ return Column(self._jc.getField(name))
+
+ def __getattr__(self, item):
+ if item.startswith("__"):
+ raise AttributeError(item)
+ return self.getField(item)
# string methods
rlike = _bin_op("rlike")
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 7c09a0cfe3..6691e8c8dc 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -426,6 +426,24 @@ class SQLTests(ReusedPySparkTestCase):
pydoc.render_doc(df.foo)
pydoc.render_doc(df.take(1))
+ def test_access_column(self):
+ df = self.df
+ self.assertTrue(isinstance(df.key, Column))
+ self.assertTrue(isinstance(df['key'], Column))
+ self.assertTrue(isinstance(df[0], Column))
+ self.assertRaises(IndexError, lambda: df[2])
+ self.assertRaises(IndexError, lambda: df["bad_key"])
+ self.assertRaises(TypeError, lambda: df[{}])
+
+ def test_access_nested_types(self):
+ df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()
+ self.assertEqual(1, df.select(df.l[0]).first()[0])
+ self.assertEqual(1, df.select(df.l.getItem(0)).first()[0])
+ self.assertEqual(1, df.select(df.r.a).first()[0])
+ self.assertEqual("b", df.select(df.r.getField("b")).first()[0])
+ self.assertEqual("v", df.select(df.d["k"]).first()[0])
+ self.assertEqual("v", df.select(df.d.getItem("k")).first()[0])
+
def test_infer_long_type(self):
longrow = [Row(f1='a', f2=100000000000000)]
df = self.sc.parallelize(longrow).toDF()