| author    | Davies Liu <davies@databricks.com> | 2015-07-20 12:14:47 -0700 |
|-----------|------------------------------------|---------------------------|
| committer | Davies Liu <davies.liu@gmail.com>  | 2015-07-20 12:14:47 -0700 |
| commit    | 9f913c4fd6f0f223fd378e453d5b9a87beda1ac4 (patch) | |
| tree      | b0e89b1b7f4c0617c7bab6b314b161fc5240c95b /python | |
| parent    | 02181fb6d14833448fb5c501045655213d3cf340 (diff) | |
[SPARK-9114] [SQL] [PySpark] convert returned object from UDF into internal type
This PR also removes the duplicated code between registerFunction and UserDefinedFunction.
cc JoshRosen
Author: Davies Liu <davies@databricks.com>
Closes #7450 from davies/fix_return_type and squashes the following commits:
e80bf9f [Davies Liu] remove debugging code
f94b1f6 [Davies Liu] fix mima
8f9c58b [Davies Liu] convert returned object from UDF into internal type
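
For context, here is a minimal sketch of the behavior this patch enables, mirroring the new test in tests.py below. It assumes a live PySpark 1.x shell (`sc` already created); `ExamplePoint` and `ExamplePointUDT` are the sample user-defined type shipped in `pyspark.sql.tests`.

```python
# Sketch only: assumes a running PySpark 1.x shell with `sc` available.
from pyspark.sql import Row
from pyspark.sql.functions import UserDefinedFunction
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT

df = sc.parallelize([Row(label=1.0, point=ExamplePoint(1.0, 2.0))]).toDF()

# The UDF returns a UDT instance. With this patch, the result is passed
# through ExamplePointUDT().toInternal() before crossing into the JVM,
# so the returned object round-trips correctly.
shift = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
df.select(shift(df.point)).first()[0]  # ExamplePoint(2.0, 3.0)
```

Before this change, the pickled closure handed the UDF's raw return value straight back to the JVM, which could not interpret a UDT instance; the patch converts it into the internal SQL representation first.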
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/sql/context.py   | 16
-rw-r--r-- | python/pyspark/sql/functions.py | 15
-rw-r--r-- | python/pyspark/sql/tests.py     |  4
3 files changed, 14 insertions, 21 deletions
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index c93a15bada..abb6522dde 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -34,6 +34,7 @@ from pyspark.sql.types import Row, StringType, StructType, _verify_type, \
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.readwriter import DataFrameReader
 from pyspark.sql.utils import install_exception_handler
+from pyspark.sql.functions import UserDefinedFunction
 
 try:
     import pandas
@@ -191,19 +192,8 @@ class SQLContext(object):
         >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
         [Row(_c0=4)]
         """
-        func = lambda _, it: map(lambda x: f(*x), it)
-        ser = AutoBatchedSerializer(PickleSerializer())
-        command = (func, None, ser, ser)
-        pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self)
-        self._ssql_ctx.udf().registerPython(name,
-                                            bytearray(pickled_cmd),
-                                            env,
-                                            includes,
-                                            self._sc.pythonExec,
-                                            self._sc.pythonVer,
-                                            bvars,
-                                            self._sc._javaAccumulator,
-                                            returnType.json())
+        udf = UserDefinedFunction(f, returnType, name)
+        self._ssql_ctx.udf().registerPython(name, udf._judf)
 
     def _inferSchemaFromList(self, data):
         """
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index fd5a3ba8ad..031745a1c4 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -801,23 +801,24 @@ class UserDefinedFunction(object):
 
     .. versionadded:: 1.3
     """
-    def __init__(self, func, returnType):
+    def __init__(self, func, returnType, name=None):
         self.func = func
         self.returnType = returnType
         self._broadcast = None
-        self._judf = self._create_judf()
+        self._judf = self._create_judf(name)
 
-    def _create_judf(self):
-        f = self.func  # put it in closure `func`
-        func = lambda _, it: map(lambda x: f(*x), it)
+    def _create_judf(self, name):
+        f, returnType = self.func, self.returnType  # put them in closure `func`
+        func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it)
         ser = AutoBatchedSerializer(PickleSerializer())
         command = (func, None, ser, ser)
         sc = SparkContext._active_spark_context
         pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
         ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
         jdt = ssql_ctx.parseDataType(self.returnType.json())
-        fname = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
-        judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env, includes,
+        if name is None:
+            name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
+        judf = sc._jvm.UserDefinedPythonFunction(name, bytearray(pickled_command), env, includes,
                                                  sc.pythonExec, sc.pythonVer, broadcast_vars,
                                                  sc._javaAccumulator, jdt)
         return judf
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 7a55d801e4..ea821f486f 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -417,12 +417,14 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEquals(point, ExamplePoint(1.0, 2.0))
 
     def test_udf_with_udt(self):
-        from pyspark.sql.tests import ExamplePoint
+        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
         row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
         df = self.sc.parallelize([row]).toDF()
         self.assertEqual(1.0, df.map(lambda r: r.point.x).first())
         udf = UserDefinedFunction(lambda p: p.y, DoubleType())
         self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
+        udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
+        self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
 
     def test_parquet_with_udt(self):
         from pyspark.sql.tests import ExamplePoint
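
The heart of the fix is the one-line change to the closure built in `_create_judf`: each UDF result now goes through `returnType.toInternal` before being serialized back to the JVM. Below is a self-contained sketch of that wrapping; `FakeUDT` and `make_func` are hypothetical stand-ins for illustration, not Spark APIs.

```python
# Illustration of the closure change in _create_judf (sketch, not Spark code).
class FakeUDT(object):
    """Hypothetical stand-in for a pyspark.sql.types user-defined type."""
    def toInternal(self, obj):
        # A real UDT would serialize obj into its SQL representation;
        # tagging the value here just makes the conversion visible.
        return ("internal", obj)

def make_func(f, returnType):
    # Before the patch: lambda _, it: map(lambda x: f(*x), it)
    # After the patch, every result is converted to the internal type:
    return lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it)

func = make_func(lambda x: x + 1, FakeUDT())
print(list(func(None, [(1,), (2,)])))  # [('internal', 2), ('internal', 3)]
```

With the conversion living inside `UserDefinedFunction`, `registerFunction` can simply build a `UserDefinedFunction` and hand its `_judf` to the JVM, which is why the pickling boilerplate could be deleted from context.py.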