aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/sql/tests.py
diff options
context:
space:
mode:
authorksonj <kson@siberie.de>2015-04-01 17:23:57 -0700
committerJosh Rosen <joshrosen@databricks.com>2015-04-01 17:24:21 -0700
commit757b2e91756ba49d7d1ab89abf19b00c7f5fd721 (patch)
tree1d5f80905278c7f2e3de5fc5c789fd9080588cec /python/pyspark/sql/tests.py
parent86b43993517104e6d5ad0785704ceec6db8acc20 (diff)
downloadspark-757b2e91756ba49d7d1ab89abf19b00c7f5fd721.tar.gz
spark-757b2e91756ba49d7d1ab89abf19b00c7f5fd721.tar.bz2
spark-757b2e91756ba49d7d1ab89abf19b00c7f5fd721.zip
[SPARK-6553] [pyspark] Support functools.partial as UDF
Use `f.__repr__()` instead of `f.__name__` when instantiating `UserDefinedFunction`s, so `functools.partial`s may be used. Author: ksonj <kson@siberie.de> Closes #5206 from ksonj/partials and squashes the following commits: ea66f3d [ksonj] Inserted blank lines for PEP8 compliance d81b02b [ksonj] added tests for udf with partial function and callable object 2c76100 [ksonj] Makes UDFs work with all types of callables b814a12 [ksonj] support functools.partial as udf (cherry picked from commit 98f72dfc17853b570d05c20e97c78919682b6df6) Signed-off-by: Josh Rosen <joshrosen@databricks.com>
Diffstat (limited to 'python/pyspark/sql/tests.py')
-rw-r--r--python/pyspark/sql/tests.py31
1 files changed, 31 insertions, 0 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 258464b7f2..b3a6a2c6a9 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -25,6 +25,7 @@ import pydoc
import shutil
import tempfile
import pickle
+import functools
import py4j
@@ -41,6 +42,7 @@ from pyspark.sql import SQLContext, HiveContext, Column, Row
from pyspark.sql.types import *
from pyspark.sql.types import UserDefinedType, _infer_type
from pyspark.tests import ReusedPySparkTestCase
+from pyspark.sql.functions import UserDefinedFunction
class ExamplePointUDT(UserDefinedType):
@@ -114,6 +116,35 @@ class SQLTests(ReusedPySparkTestCase):
ReusedPySparkTestCase.tearDownClass()
shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+ def test_udf_with_callable(self):
+ d = [Row(number=i, squared=i**2) for i in range(10)]
+ rdd = self.sc.parallelize(d)
+ data = self.sqlCtx.createDataFrame(rdd)
+
+ class PlusFour:
+ def __call__(self, col):
+ if col is not None:
+ return col + 4
+
+ call = PlusFour()
+ pudf = UserDefinedFunction(call, LongType())
+ res = data.select(pudf(data['number']).alias('plus_four'))
+ self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)
+
+ def test_udf_with_partial_function(self):
+ d = [Row(number=i, squared=i**2) for i in range(10)]
+ rdd = self.sc.parallelize(d)
+ data = self.sqlCtx.createDataFrame(rdd)
+
+ def some_func(col, param):
+ if col is not None:
+ return col + param
+
+ pfunc = functools.partial(some_func, param=4)
+ pudf = UserDefinedFunction(pfunc, LongType())
+ res = data.select(pudf(data['number']).alias('plus_four'))
+ self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)
+
def test_udf(self):
self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
[row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect()