about | summary | refs | log | tree | commit | diff
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r-- python/pyspark/sql/dataframe.py 24
-rw-r--r-- python/pyspark/sql/tests.py 7
2 files changed, 19 insertions, 12 deletions
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index cee804f5cc..a9697999e8 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1275,7 +1275,7 @@ class Column(object):
# container operators
__contains__ = _bin_op("contains")
- __getitem__ = _bin_op("getItem")
+ __getitem__ = _bin_op("apply")
# bitwise operators
bitwiseOR = _bin_op("bitwiseOR")
@@ -1308,19 +1308,19 @@ class Column(object):
>>> from pyspark.sql import Row
>>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF()
>>> df.select(df.r.getField("b")).show()
- +---+
- |r.b|
- +---+
- |  b|
- +---+
+ +----+
+ |r[b]|
+ +----+
+ |   b|
+ +----+
>>> df.select(df.r.a).show()
- +---+
- |r.a|
- +---+
- |  1|
- +---+
+ +----+
+ |r[a]|
+ +----+
+ |   1|
+ +----+
"""
- return Column(self._jc.getField(name))
+ return self[name]
def __getattr__(self, item):
if item.startswith("__"):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 45dfedce22..7e63f4d646 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -519,6 +519,13 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEqual("v", df.select(df.d["k"]).first()[0])
self.assertEqual("v", df.select(df.d.getItem("k")).first()[0])
+ def test_field_accessor(self):
+ df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()
+ self.assertEqual(1, df.select(df.l[0]).first()[0])
+ self.assertEqual(1, df.select(df.r["a"]).first()[0])
+ self.assertEqual("b", df.select(df.r["b"]).first()[0])
+ self.assertEqual("v", df.select(df.d["k"]).first()[0])
+
def test_infer_long_type(self):
longrow = [Row(f1='a', f2=100000000000000)]
df = self.sc.parallelize(longrow).toDF()