From a60f91284ceee64de13f04559ec19c13a820a133 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Wed, 24 Feb 2016 12:44:54 -0800
Subject: [SPARK-13467] [PYSPARK] abstract python function to simplify pyspark
 code

## What changes were proposed in this pull request?

When we pass a Python function to the JVM side, we also need to send its context, e.g. `envVars`, `pythonIncludes`, `pythonExec`, etc. However, it's annoying to pass around so many parameters in so many places. This PR abstracts the Python function along with its context, to simplify some PySpark code and make the logic clearer.

## How was this patch tested?

By existing unit tests.

Author: Wenchen Fan

Closes #11342 from cloud-fan/python-clean.
---
 python/pyspark/sql/context.py   | 2 +-
 python/pyspark/sql/functions.py | 8 +++-----
 2 files changed, 4 insertions(+), 6 deletions(-)

(limited to 'python/pyspark/sql')

diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 89bf1443a6..87e32c04ea 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -29,7 +29,7 @@ else:
     from py4j.protocol import Py4JError
 
 from pyspark import since
-from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix
+from pyspark.rdd import RDD, ignore_unicode_prefix
 from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
 from pyspark.sql.types import Row, StringType, StructType, _verify_type, \
     _infer_schema, _has_nulltype, _merge_type, _create_converter
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 6894c27338..b30cc6799e 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -25,7 +25,7 @@ if sys.version < "3":
     from itertools import imap as map
 
 from pyspark import since, SparkContext
-from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
+from pyspark.rdd import _wrap_function, ignore_unicode_prefix
 from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
 from pyspark.sql.types import StringType
 from pyspark.sql.column import Column, _to_java_column, _to_seq
@@ -1645,16 +1645,14 @@ class UserDefinedFunction(object):
         f, returnType = self.func, self.returnType  # put them in closure `func`
         func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it)
         ser = AutoBatchedSerializer(PickleSerializer())
-        command = (func, None, ser, ser)
         sc = SparkContext.getOrCreate()
-        pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
+        wrapped_func = _wrap_function(sc, func, ser, ser)
         ctx = SQLContext.getOrCreate(sc)
         jdt = ctx._ssql_ctx.parseDataType(self.returnType.json())
         if name is None:
             name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
         judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
-            name, bytearray(pickled_command), env, includes, sc.pythonExec, sc.pythonVer,
-            broadcast_vars, sc._javaAccumulator, jdt)
+            name, wrapped_func, jdt)
         return judf
 
     def __del__(self):
-- 
cgit v1.2.3
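
For context, here is a minimal sketch of what the new `_wrap_function` helper in `python/pyspark/rdd.py` could look like. That file is outside the view shown above (the diff is limited to `python/pyspark/sql`), so the signature, the reuse of `_prepare_for_python_RDD`, and the JVM-side `PythonFunction` holder are inferences from the call site `_wrap_function(sc, func, ser, ser)`, not part of the patch itself:

```python
# Hypothetical sketch of the new helper inside python/pyspark/rdd.py, where
# _prepare_for_python_RDD is already defined; names are inferred from the
# call sites in this diff, not taken from the (unshown) rdd.py change.

def _wrap_function(sc, func, deserializer, serializer, profiler=None):
    """Bundle a Python function with its execution context (env vars,
    includes, python exec/version, broadcast vars, accumulator) into a
    single JVM-side object, so callers no longer pass them separately."""
    assert deserializer, "deserializer should not be empty"
    assert serializer, "serializer should not be empty"
    # The same (func, profiler, deserializer, serializer) command tuple
    # that callers previously built by hand.
    command = (func, profiler, deserializer, serializer)
    pickled_command, broadcast_vars, env, includes = \
        _prepare_for_python_RDD(sc, command)
    # PythonFunction is assumed to be a JVM-side holder class that carries
    # the pickled function together with its full context.
    return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes,
                                  sc.pythonExec, sc.pythonVer, broadcast_vars,
                                  sc._javaAccumulator)
```

With the context bundled into one object, call sites shrink accordingly: the `UserDefinedPythonFunction` constructor call in `functions.py` goes from nine arguments down to three (`name`, `wrapped_func`, `jdt`).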