aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
author Davies Liu <davies@databricks.com> 2014-10-28 19:38:16 -0700
committer Michael Armbrust <michael@databricks.com> 2014-10-28 19:38:16 -0700
commit 8c0bfd08fc19fa5de7d77bf8306d19834f907ec0 (patch)
tree 96d0424f06e2c20d9ee34cc482f792fdbff473a6 /python
parent b5e79bf889700159d490cdac1f6322dff424b1d9 (diff)
download spark-8c0bfd08fc19fa5de7d77bf8306d19834f907ec0.tar.gz
spark-8c0bfd08fc19fa5de7d77bf8306d19834f907ec0.tar.bz2
spark-8c0bfd08fc19fa5de7d77bf8306d19834f907ec0.zip
[SPARK-4133] [SQL] [PySpark] type conversion for python udf
Python UDFs can be called on ArrayType/MapType/PrimitiveType; the returnType can also be ArrayType/MapType/PrimitiveType. For StructType, the value will act as a tuple (without attributes). If returnType is StructType, the returned value should also be a tuple. Author: Davies Liu <davies@databricks.com> Closes #2973 from davies/udf_array and squashes the following commits: 306956e [Davies Liu] Merge branch 'master' of github.com:apache/spark into udf_array 2c00e43 [Davies Liu] fix merge 11395fa [Davies Liu] Merge branch 'master' of github.com:apache/spark into udf_array 9df50a2 [Davies Liu] address comments 79afb4e [Davies Liu] type conversion for python udf
Diffstat (limited to 'python')
-rw-r--r-- python/pyspark/tests.py 16
1 files changed, 13 insertions, 3 deletions
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 047d857830..37a128907b 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -49,7 +49,7 @@ from pyspark.files import SparkFiles
from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
CloudPickleSerializer
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
-from pyspark.sql import SQLContext, IntegerType, Row
+from pyspark.sql import SQLContext, IntegerType, Row, ArrayType
from pyspark import shuffle
_have_scipy = False
@@ -690,10 +690,20 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEqual(row[0], 5)
def test_udf2(self):
- self.sqlCtx.registerFunction("strlen", lambda string: len(string))
+ self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType())
self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test")
[res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
- self.assertEqual(u"4", res[0])
+ self.assertEqual(4, res[0])
+
+ def test_udf_with_array_type(self):
+ d = [Row(l=range(3), d={"key": range(5)})]
+ rdd = self.sc.parallelize(d)
+ srdd = self.sqlCtx.inferSchema(rdd).registerTempTable("test")
+ self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType()))
+ self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType())
+ [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect()
+ self.assertEqual(range(3), l1)
+ self.assertEqual(1, l2)
def test_broadcast_in_udf(self):
bar = {"a": "aa", "b": "bb", "c": "abc"}