aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/sql/functions.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/sql/functions.py')
-rw-r--r--python/pyspark/sql/functions.py15
1 files changed, 8 insertions, 7 deletions
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index fd5a3ba8ad..031745a1c4 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -801,23 +801,24 @@ class UserDefinedFunction(object):
.. versionadded:: 1.3
"""
- def __init__(self, func, returnType):
+ def __init__(self, func, returnType, name=None):
self.func = func
self.returnType = returnType
self._broadcast = None
- self._judf = self._create_judf()
+ self._judf = self._create_judf(name)
- def _create_judf(self):
- f = self.func # put it in closure `func`
- func = lambda _, it: map(lambda x: f(*x), it)
+ def _create_judf(self, name):
+ f, returnType = self.func, self.returnType # put them in closure `func`
+ func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it)
ser = AutoBatchedSerializer(PickleSerializer())
command = (func, None, ser, ser)
sc = SparkContext._active_spark_context
pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
jdt = ssql_ctx.parseDataType(self.returnType.json())
- fname = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
- judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env, includes,
+ if name is None:
+ name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
+ judf = sc._jvm.UserDefinedPythonFunction(name, bytearray(pickled_command), env, includes,
sc.pythonExec, sc.pythonVer, broadcast_vars,
sc._javaAccumulator, jdt)
return judf