author     ksonj <kson@siberie.de>                 2015-04-01 17:23:57 -0700
committer  Josh Rosen <joshrosen@databricks.com>   2015-04-01 17:23:57 -0700
commit     98f72dfc17853b570d05c20e97c78919682b6df6 (patch)
tree       98ab912e4d28f574bcd489fadb9b3a10e7f1dd01
parent     bc04fa2e2ade3343f0fdc20cd9702260c717dea7 (diff)
[SPARK-6553] [pyspark] Support functools.partial as UDF
Use `f.__name__` when it exists and fall back to `f.__class__.__name__` when instantiating `UserDefinedFunction`s, so that `functools.partial`s and other callables may be used as UDFs.

Author: ksonj <kson@siberie.de>

Closes #5206 from ksonj/partials and squashes the following commits:

ea66f3d [ksonj] Inserted blank lines for PEP8 compliance
d81b02b [ksonj] added tests for udf with partial function and callable object
2c76100 [ksonj] Makes UDFs work with all types of callables
b814a12 [ksonj] support functools.partial as udf
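For context on the failure mode this patch fixes: `functools.partial` objects carry no `__name__` attribute, so the old constructor call raised AttributeError. A minimal sketch of the check, runnable as plain Python with no Spark required (`add` and `plus_four` are illustrative names):

    import functools

    def add(col, amount):
        return col + amount

    plus_four = functools.partial(add, amount=4)

    hasattr(add, '__name__')        # True  -> 'add'
    hasattr(plus_four, '__name__')  # False -> old code raised AttributeError here
    # The patched fallback resolves a usable name instead:
    fname = plus_four.__name__ if hasattr(plus_four, '__name__') else plus_four.__class__.__name__
    print(fname)  # partial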
-rw-r--r--  python/pyspark/sql/functions.py   3
-rw-r--r--  python/pyspark/sql/tests.py      31
2 files changed, 33 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 8a478fddf0..146ba6f3e0 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -123,7 +123,8 @@ class UserDefinedFunction(object):
         pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
         ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
         jdt = ssql_ctx.parseDataType(self.returnType.json())
-        judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env,
+        fname = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
+        judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env,
                                                  includes, sc.pythonExec, broadcast_vars,
                                                  sc._javaAccumulator, jdt)
         return judf
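To illustrate what the new `fname` line resolves for each kind of callable the PR covers, here is a minimal standalone sketch (plain Python, no Spark; `resolve_udf_name`, `plain`, and `Adder` are illustrative names, not part of the patch):

    import functools

    def resolve_udf_name(f):
        # Mirrors the patched line: prefer __name__, fall back to the class name.
        return f.__name__ if hasattr(f, '__name__') else f.__class__.__name__

    def plain(x):
        return x

    class Adder(object):
        def __call__(self, x):
            return x + 1

    print(resolve_udf_name(plain))                     # plain
    print(resolve_udf_name(functools.partial(plain)))  # partial
    print(resolve_udf_name(Adder()))                   # Adder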
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 258464b7f2..b3a6a2c6a9 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -25,6 +25,7 @@ import pydoc
 import shutil
 import tempfile
 import pickle
+import functools

 import py4j
@@ -41,6 +42,7 @@ from pyspark.sql import SQLContext, HiveContext, Column, Row
 from pyspark.sql.types import *
 from pyspark.sql.types import UserDefinedType, _infer_type
 from pyspark.tests import ReusedPySparkTestCase
+from pyspark.sql.functions import UserDefinedFunction


 class ExamplePointUDT(UserDefinedType):
@@ -114,6 +116,35 @@ class SQLTests(ReusedPySparkTestCase):
         ReusedPySparkTestCase.tearDownClass()
         shutil.rmtree(cls.tempdir.name, ignore_errors=True)

+    def test_udf_with_callable(self):
+        d = [Row(number=i, squared=i**2) for i in range(10)]
+        rdd = self.sc.parallelize(d)
+        data = self.sqlCtx.createDataFrame(rdd)
+
+        class PlusFour:
+            def __call__(self, col):
+                if col is not None:
+                    return col + 4
+
+        call = PlusFour()
+        pudf = UserDefinedFunction(call, LongType())
+        res = data.select(pudf(data['number']).alias('plus_four'))
+        self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)
+
+    def test_udf_with_partial_function(self):
+        d = [Row(number=i, squared=i**2) for i in range(10)]
+        rdd = self.sc.parallelize(d)
+        data = self.sqlCtx.createDataFrame(rdd)
+
+        def some_func(col, param):
+            if col is not None:
+                return col + param
+
+        pfunc = functools.partial(some_func, param=4)
+        pudf = UserDefinedFunction(pfunc, LongType())
+        res = data.select(pudf(data['number']).alias('plus_four'))
+        self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)
+
     def test_udf(self):
         self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
         [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
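End to end, the change lets a partial be passed straight to a UDF. A hedged usage sketch against the 1.x API the tests above use (assumes a running SparkContext named `sc`; `plus_n` and the column names are illustrative, not from the patch):

    import functools

    from pyspark.sql import Row, SQLContext
    from pyspark.sql.functions import UserDefinedFunction
    from pyspark.sql.types import LongType

    sqlCtx = SQLContext(sc)  # assumes an existing SparkContext `sc`
    df = sqlCtx.createDataFrame([Row(number=i) for i in range(10)])

    def plus_n(col, n):
        if col is not None:
            return col + n

    # Before this patch, passing a partial here raised AttributeError,
    # since partial objects have no __name__.
    plus_four = UserDefinedFunction(functools.partial(plus_n, n=4), LongType())
    df.select(plus_four(df['number']).alias('plus_four')).show()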