From a60f91284ceee64de13f04559ec19c13a820a133 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 24 Feb 2016 12:44:54 -0800 Subject: [SPARK-13467] [PYSPARK] abstract python function to simplify pyspark code ## What changes were proposed in this pull request? When we pass a Python function to JVM side, we also need to send its context, e.g. `envVars`, `pythonIncludes`, `pythonExec`, etc. However, it's annoying to pass around so many parameters at many places. This PR abstract python function along with its context, to simplify some pyspark code and make the logic more clear. ## How was the this patch tested? by existing unit tests. Author: Wenchen Fan Closes #11342 from cloud-fan/python-clean. --- python/pyspark/rdd.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) (limited to 'python/pyspark/rdd.py') diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 4eaf589ad5..37574cea0b 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2309,7 +2309,7 @@ class RDD(object): yield row -def _prepare_for_python_RDD(sc, command, obj=None): +def _prepare_for_python_RDD(sc, command): # the serialized command will be compressed by broadcast ser = CloudPickleSerializer() pickled_command = ser.dumps(command) @@ -2329,6 +2329,15 @@ def _prepare_for_python_RDD(sc, command, obj=None): return pickled_command, broadcast_vars, env, includes +def _wrap_function(sc, func, deserializer, serializer, profiler=None): + assert deserializer, "deserializer should not be empty" + assert serializer, "serializer should not be empty" + command = (func, profiler, deserializer, serializer) + pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command) + return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec, + sc.pythonVer, broadcast_vars, sc._javaAccumulator) + + class PipelinedRDD(RDD): """ @@ -2390,14 +2399,10 @@ class PipelinedRDD(RDD): else: profiler = None - command = (self.func, profiler, self._prev_jrdd_deserializer, - self._jrdd_deserializer) - pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self.ctx, command, self) - python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), - bytearray(pickled_cmd), - env, includes, self.preservesPartitioning, - self.ctx.pythonExec, self.ctx.pythonVer, - bvars, self.ctx._javaAccumulator) + wrapped_func = _wrap_function(self.ctx, self.func, self._prev_jrdd_deserializer, + self._jrdd_deserializer, profiler) + python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), wrapped_func, + self.preservesPartitioning) self._jrdd_val = python_rdd.asJavaRDD() if profiler: -- cgit v1.2.3