diff options
author | ksonj <kson@siberie.de> | 2015-04-01 17:23:57 -0700 |
---|---|---|
committer | Josh Rosen <joshrosen@databricks.com> | 2015-04-01 17:24:21 -0700 |
commit | 757b2e91756ba49d7d1ab89abf19b00c7f5fd721 (patch) | |
tree | 1d5f80905278c7f2e3de5fc5c789fd9080588cec | |
parent | 86b43993517104e6d5ad0785704ceec6db8acc20 (diff) | |
download | spark-757b2e91756ba49d7d1ab89abf19b00c7f5fd721.tar.gz spark-757b2e91756ba49d7d1ab89abf19b00c7f5fd721.tar.bz2 spark-757b2e91756ba49d7d1ab89abf19b00c7f5fd721.zip |
[SPARK-6553] [pyspark] Support functools.partial as UDF
Use `f.__repr__()` instead of `f.__name__` when instantiating `UserDefinedFunction`s, so `functools.partial`s may be used.
Author: ksonj <kson@siberie.de>
Closes #5206 from ksonj/partials and squashes the following commits:
ea66f3d [ksonj] Inserted blank lines for PEP8 compliance
d81b02b [ksonj] added tests for udf with partial function and callable object
2c76100 [ksonj] Makes UDFs work with all types of callables
b814a12 [ksonj] support functools.partial as udf
(cherry picked from commit 98f72dfc17853b570d05c20e97c78919682b6df6)
Signed-off-by: Josh Rosen <joshrosen@databricks.com>
-rw-r--r-- | python/pyspark/sql/functions.py | 3 | ||||
-rw-r--r-- | python/pyspark/sql/tests.py | 31 |
2 files changed, 33 insertions, 1 deletions
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 8a478fddf0..146ba6f3e0 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -123,7 +123,8 @@ class UserDefinedFunction(object): pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) jdt = ssql_ctx.parseDataType(self.returnType.json()) - judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env, + fname = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__ + judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env, includes, sc.pythonExec, broadcast_vars, sc._javaAccumulator, jdt) return judf diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 258464b7f2..b3a6a2c6a9 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -25,6 +25,7 @@ import pydoc import shutil import tempfile import pickle +import functools import py4j @@ -41,6 +42,7 @@ from pyspark.sql import SQLContext, HiveContext, Column, Row from pyspark.sql.types import * from pyspark.sql.types import UserDefinedType, _infer_type from pyspark.tests import ReusedPySparkTestCase +from pyspark.sql.functions import UserDefinedFunction class ExamplePointUDT(UserDefinedType): @@ -114,6 +116,35 @@ class SQLTests(ReusedPySparkTestCase): ReusedPySparkTestCase.tearDownClass() shutil.rmtree(cls.tempdir.name, ignore_errors=True) + def test_udf_with_callable(self): + d = [Row(number=i, squared=i**2) for i in range(10)] + rdd = self.sc.parallelize(d) + data = self.sqlCtx.createDataFrame(rdd) + + class PlusFour: + def __call__(self, col): + if col is not None: + return col + 4 + + call = PlusFour() + pudf = UserDefinedFunction(call, LongType()) + res = data.select(pudf(data['number']).alias('plus_four')) + self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85) + + def test_udf_with_partial_function(self): + d = [Row(number=i, squared=i**2) for i in range(10)] + rdd = self.sc.parallelize(d) + data = self.sqlCtx.createDataFrame(rdd) + + def some_func(col, param): + if col is not None: + return col + param + + pfunc = functools.partial(some_func, param=4) + pudf = UserDefinedFunction(pfunc, LongType()) + res = data.select(pudf(data['number']).alias('plus_four')) + self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85) + def test_udf(self): self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() |