diff options
Diffstat (limited to 'python/pyspark/sql/tests.py')
-rw-r--r-- | python/pyspark/sql/tests.py | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1922d03af6..d37c5dbed7 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -117,6 +117,21 @@ class SQLTests(ReusedPySparkTestCase): ReusedPySparkTestCase.tearDownClass() shutil.rmtree(cls.tempdir.name, ignore_errors=True) + def test_explode(self): + from pyspark.sql.functions import explode + d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})] + rdd = self.sc.parallelize(d) + data = self.sqlCtx.createDataFrame(rdd) + + result = data.select(explode(data.intlist).alias("a")).select("a").collect() + self.assertEqual(result[0][0], 1) + self.assertEqual(result[1][0], 2) + self.assertEqual(result[2][0], 3) + + result = data.select(explode(data.mapfield).alias("a", "b")).select("a", "b").collect() + self.assertEqual(result[0][0], "a") + self.assertEqual(result[0][1], "b") + def test_udf_with_callable(self): d = [Row(number=i, squared=i**2) for i in range(10)] rdd = self.sc.parallelize(d) |