aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorWenchen Fan <cloud0fan@outlook.com>2015-08-14 14:09:46 -0700
committerMichael Armbrust <michael@databricks.com>2015-08-14 14:09:46 -0700
commit1150a19b188a075166899fdb1e107b2ba1e505d8 (patch)
treeb5b45d3285002e3b276d47ac5d5b40c0b11f4ff8 /python
parent2a6590e510aba3bfc6603d280023128b3f5ac702 (diff)
downloadspark-1150a19b188a075166899fdb1e107b2ba1e505d8.tar.gz
spark-1150a19b188a075166899fdb1e107b2ba1e505d8.tar.bz2
spark-1150a19b188a075166899fdb1e107b2ba1e505d8.zip
[SPARK-8670] [SQL] Nested columns can't be referenced in pyspark
This bug is caused by a wrong column-exist-check in `__getitem__` of pyspark dataframe. `DataFrame.apply` accepts not only top level column names, but also nested column name like `a.b`, so we should remove that check from `__getitem__`. Author: Wenchen Fan <cloud0fan@outlook.com> Closes #8202 from cloud-fan/nested.
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/sql/dataframe.py2
-rw-r--r--python/pyspark/sql/tests.py4
2 files changed, 3 insertions, 3 deletions
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 09647ff6d0..da742d7ce7 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -722,8 +722,6 @@ class DataFrame(object):
[Row(age=5, name=u'Bob')]
"""
if isinstance(item, basestring):
- if item not in self.columns:
- raise IndexError("no such column: %s" % item)
jc = self._jdf.apply(item)
return Column(jc)
elif isinstance(item, Column):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 9b748101b5..13cf647b66 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -770,7 +770,7 @@ class SQLTests(ReusedPySparkTestCase):
self.assertTrue(isinstance(df['key'], Column))
self.assertTrue(isinstance(df[0], Column))
self.assertRaises(IndexError, lambda: df[2])
- self.assertRaises(IndexError, lambda: df["bad_key"])
+ self.assertRaises(AnalysisException, lambda: df["bad_key"])
self.assertRaises(TypeError, lambda: df[{}])
def test_column_name_with_non_ascii(self):
@@ -794,7 +794,9 @@ class SQLTests(ReusedPySparkTestCase):
df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()
self.assertEqual(1, df.select(df.l[0]).first()[0])
self.assertEqual(1, df.select(df.r["a"]).first()[0])
+ self.assertEqual(1, df.select(df["r.a"]).first()[0])
self.assertEqual("b", df.select(df.r["b"]).first()[0])
+ self.assertEqual("b", df.select(df["r.b"]).first()[0])
self.assertEqual("v", df.select(df.d["k"]).first()[0])
def test_infer_long_type(self):