about | summary | refs | log | tree | commit | diff
path: root/python/pyspark/sql/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/sql/tests.py')
-rw-r--r-- python/pyspark/sql/tests.py | 31
1 files changed, 31 insertions, 0 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 258464b7f2..b3a6a2c6a9 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -25,6 +25,7 @@ import pydoc
import shutil
import tempfile
import pickle
+import functools
import py4j
@@ -41,6 +42,7 @@ from pyspark.sql import SQLContext, HiveContext, Column, Row
from pyspark.sql.types import *
from pyspark.sql.types import UserDefinedType, _infer_type
from pyspark.tests import ReusedPySparkTestCase
+from pyspark.sql.functions import UserDefinedFunction
class ExamplePointUDT(UserDefinedType):
@@ -114,6 +116,35 @@ class SQLTests(ReusedPySparkTestCase):
ReusedPySparkTestCase.tearDownClass()
shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+    def test_udf_with_callable(self):
+        """A UDF may wrap an object that is callable through ``__call__``.
+
+        The ``number`` column holds 0..9 (sum 45); adding 4 to each of the
+        ten values must therefore give a total of 85.
+        """
+        class PlusFour:
+            # A callable instance, not a plain function, is what this test
+            # hands to UserDefinedFunction.
+            def __call__(self, col):
+                return None if col is None else col + 4
+
+        rows = [Row(number=n, squared=n**2) for n in range(10)]
+        df = self.sqlCtx.createDataFrame(self.sc.parallelize(rows))
+
+        plus_four = UserDefinedFunction(PlusFour(), LongType())
+        shifted = df.select(plus_four(df['number']).alias('plus_four'))
+        total = shifted.agg({'plus_four': 'sum'}).collect()[0][0]
+        self.assertEqual(total, 85)
+
+    def test_udf_with_partial_function(self):
+        """A UDF may wrap a ``functools.partial`` object.
+
+        ``param`` is pre-bound to 4, so the ten ``number`` values 0..9
+        (sum 45) must add up to 85 once the UDF has been applied.
+        """
+        def add_param(col, param):
+            return None if col is None else col + param
+
+        rows = [Row(number=n, squared=n**2) for n in range(10)]
+        df = self.sqlCtx.createDataFrame(self.sc.parallelize(rows))
+
+        # Bind the second argument up front so the UDF receives a
+        # one-argument partial object instead of a plain function.
+        bound = functools.partial(add_param, param=4)
+        add_four = UserDefinedFunction(bound, LongType())
+        shifted = df.select(add_four(df['number']).alias('plus_four'))
+        total = shifted.agg({'plus_four': 'sum'}).collect()[0][0]
+        self.assertEqual(total, 85)
+
def test_udf(self):
self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
[row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect()