diff options
author | Davies Liu <davies@databricks.com> | 2014-10-28 19:38:16 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2014-10-28 19:38:16 -0700 |
commit | 8c0bfd08fc19fa5de7d77bf8306d19834f907ec0 (patch) | |
tree | 96d0424f06e2c20d9ee34cc482f792fdbff473a6 /python | |
parent | b5e79bf889700159d490cdac1f6322dff424b1d9 (diff) | |
download | spark-8c0bfd08fc19fa5de7d77bf8306d19834f907ec0.tar.gz spark-8c0bfd08fc19fa5de7d77bf8306d19834f907ec0.tar.bz2 spark-8c0bfd08fc19fa5de7d77bf8306d19834f907ec0.zip |
[SPARK-4133] [SQL] [PySpark] type conversionfor python udf
Call Python UDF on ArrayType/MapType/PrimitiveType, the returnType can also be ArrayType/MapType/PrimitiveType.
For StructType, it will act as tuple (without attributes). If returnType is StructType, it also should be tuple.
Author: Davies Liu <davies@databricks.com>
Closes #2973 from davies/udf_array and squashes the following commits:
306956e [Davies Liu] Merge branch 'master' of github.com:apache/spark into udf_array
2c00e43 [Davies Liu] fix merge
11395fa [Davies Liu] Merge branch 'master' of github.com:apache/spark into udf_array
9df50a2 [Davies Liu] address comments
79afb4e [Davies Liu] type conversionfor python udf
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/tests.py | 16 |
1 files changed, 13 insertions, 3 deletions
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 047d857830..37a128907b 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -49,7 +49,7 @@ from pyspark.files import SparkFiles from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ CloudPickleSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter -from pyspark.sql import SQLContext, IntegerType, Row +from pyspark.sql import SQLContext, IntegerType, Row, ArrayType from pyspark import shuffle _have_scipy = False @@ -690,10 +690,20 @@ class SQLTests(ReusedPySparkTestCase): self.assertEqual(row[0], 5) def test_udf2(self): - self.sqlCtx.registerFunction("strlen", lambda string: len(string)) + self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType()) self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test") [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() - self.assertEqual(u"4", res[0]) + self.assertEqual(4, res[0]) + + def test_udf_with_array_type(self): + d = [Row(l=range(3), d={"key": range(5)})] + rdd = self.sc.parallelize(d) + srdd = self.sqlCtx.inferSchema(rdd).registerTempTable("test") + self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType())) + self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType()) + [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect() + self.assertEqual(range(3), l1) + self.assertEqual(1, l2) def test_broadcast_in_udf(self): bar = {"a": "aa", "b": "bb", "c": "abc"} |