diff options
Diffstat (limited to 'python/pyspark/sql')
-rw-r--r-- | python/pyspark/sql/dataframe.py | 24 | ||||
-rw-r--r-- | python/pyspark/sql/tests.py | 7 |
2 files changed, 19 insertions, 12 deletions
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index cee804f5cc..a9697999e8 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1275,7 +1275,7 @@ class Column(object): # container operators __contains__ = _bin_op("contains") - __getitem__ = _bin_op("getItem") + __getitem__ = _bin_op("apply") # bitwise operators bitwiseOR = _bin_op("bitwiseOR") @@ -1308,19 +1308,19 @@ class Column(object): >>> from pyspark.sql import Row >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF() >>> df.select(df.r.getField("b")).show() - +---+ - |r.b| - +---+ - | b| - +---+ + +----+ + |r[b]| + +----+ + | b| + +----+ >>> df.select(df.r.a).show() - +---+ - |r.a| - +---+ - | 1| - +---+ + +----+ + |r[a]| + +----+ + | 1| + +----+ """ - return Column(self._jc.getField(name)) + return self[name] def __getattr__(self, item): if item.startswith("__"): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 45dfedce22..7e63f4d646 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -519,6 +519,13 @@ class SQLTests(ReusedPySparkTestCase): self.assertEqual("v", df.select(df.d["k"]).first()[0]) self.assertEqual("v", df.select(df.d.getItem("k")).first()[0]) + def test_field_accessor(self): + df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF() + self.assertEqual(1, df.select(df.l[0]).first()[0]) + self.assertEqual(1, df.select(df.r["a"]).first()[0]) + self.assertEqual("b", df.select(df.r["b"]).first()[0]) + self.assertEqual("v", df.select(df.d["k"]).first()[0]) + def test_infer_long_type(self): longrow = [Row(f1='a', f2=100000000000000)] df = self.sc.parallelize(longrow).toDF() |