aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/sql/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/sql/tests.py')
-rw-r--r--python/pyspark/sql/tests.py18
1 files changed, 18 insertions, 0 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 7c09a0cfe3..6691e8c8dc 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -426,6 +426,24 @@ class SQLTests(ReusedPySparkTestCase):
pydoc.render_doc(df.foo)
pydoc.render_doc(df.take(1))
+ def test_access_column(self):
+ df = self.df
+ self.assertTrue(isinstance(df.key, Column))
+ self.assertTrue(isinstance(df['key'], Column))
+ self.assertTrue(isinstance(df[0], Column))
+ self.assertRaises(IndexError, lambda: df[2])
+ self.assertRaises(IndexError, lambda: df["bad_key"])
+ self.assertRaises(TypeError, lambda: df[{}])
+
+ def test_access_nested_types(self):
+ df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()
+ self.assertEqual(1, df.select(df.l[0]).first()[0])
+ self.assertEqual(1, df.select(df.l.getItem(0)).first()[0])
+ self.assertEqual(1, df.select(df.r.a).first()[0])
+ self.assertEqual("b", df.select(df.r.getField("b")).first()[0])
+ self.assertEqual("v", df.select(df.d["k"]).first()[0])
+ self.assertEqual("v", df.select(df.d.getItem("k")).first()[0])
+
def test_infer_long_type(self):
longrow = [Row(f1='a', f2=100000000000000)]
df = self.sc.parallelize(longrow).toDF()