aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/tests.py16
1 files changed, 13 insertions, 3 deletions
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 047d857830..37a128907b 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -49,7 +49,7 @@ from pyspark.files import SparkFiles
from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
CloudPickleSerializer
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
-from pyspark.sql import SQLContext, IntegerType, Row
+from pyspark.sql import SQLContext, IntegerType, Row, ArrayType
from pyspark import shuffle
_have_scipy = False
@@ -690,10 +690,20 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEqual(row[0], 5)
def test_udf2(self):
- self.sqlCtx.registerFunction("strlen", lambda string: len(string))
+ self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType())
self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test")
[res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
- self.assertEqual(u"4", res[0])
+ self.assertEqual(4, res[0])
+
+ def test_udf_with_array_type(self):
+ d = [Row(l=range(3), d={"key": range(5)})]
+ rdd = self.sc.parallelize(d)
+ srdd = self.sqlCtx.inferSchema(rdd).registerTempTable("test")
+ self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType()))
+ self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType())
+ [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect()
+ self.assertEqual(range(3), l1)
+ self.assertEqual(1, l2)
def test_broadcast_in_udf(self):
bar = {"a": "aa", "b": "bb", "c": "abc"}