aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/sql/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/sql/tests.py')
-rw-r--r--python/pyspark/sql/tests.py15
1 files changed, 15 insertions, 0 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 1922d03af6..d37c5dbed7 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -117,6 +117,21 @@ class SQLTests(ReusedPySparkTestCase):
ReusedPySparkTestCase.tearDownClass()
shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+ def test_explode(self):
+ from pyspark.sql.functions import explode
+ d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})]
+ rdd = self.sc.parallelize(d)
+ data = self.sqlCtx.createDataFrame(rdd)
+
+ result = data.select(explode(data.intlist).alias("a")).select("a").collect()
+ self.assertEqual(result[0][0], 1)
+ self.assertEqual(result[1][0], 2)
+ self.assertEqual(result[2][0], 3)
+
+ result = data.select(explode(data.mapfield).alias("a", "b")).select("a", "b").collect()
+ self.assertEqual(result[0][0], "a")
+ self.assertEqual(result[0][1], "b")
+
def test_udf_with_callable(self):
d = [Row(number=i, squared=i**2) for i in range(10)]
rdd = self.sc.parallelize(d)