path: root/python/pyspark/rdd.py
author     Wenchen Fan <wenchen@databricks.com>  2016-02-24 12:44:54 -0800
committer  Davies Liu <davies.liu@gmail.com>     2016-02-24 12:44:54 -0800
commit     a60f91284ceee64de13f04559ec19c13a820a133 (patch)
tree       68d7d84620835d5e66cc3f94771a11655c4cbe2b /python/pyspark/rdd.py
parent     f92f53faeea020d80638a06752d69ca7a949cdeb (diff)
[SPARK-13467] [PYSPARK] abstract python function to simplify pyspark code
## What changes were proposed in this pull request?

When we pass a Python function to the JVM side, we also need to send its context, e.g. `envVars`, `pythonIncludes`, `pythonExec`, etc. It is tedious to pass so many parameters around at so many call sites. This PR abstracts the Python function together with its context into a single object, which simplifies some PySpark code and makes the logic clearer.

## How was this patch tested?

By existing unit tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #11342 from cloud-fan/python-clean.
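To make the refactoring concrete, here is a minimal Python sketch of the pattern, under the assumption that a single record-like object carries the function plus its context. `PythonFunctionSketch` and `wrap_function_sketch` are hypothetical stand-ins for the JVM-side `PythonFunction` and the `_wrap_function` helper added in the diff below:

    # Hypothetical Python-side analogue of the JVM PythonFunction object;
    # the real one is built via sc._jvm.PythonFunction(...) in _wrap_function.
    import collections

    PythonFunctionSketch = collections.namedtuple(
        "PythonFunctionSketch",
        ["command", "env_vars", "python_includes", "python_exec",
         "python_ver", "broadcast_vars", "accumulator"])

    def wrap_function_sketch(sc, pickled_command, env, includes, broadcast_vars):
        # Bundle the serialized function with everything a Python worker
        # needs, so call sites pass one object instead of seven parameters.
        return PythonFunctionSketch(
            command=bytearray(pickled_command),
            env_vars=env,
            python_includes=includes,
            python_exec=sc.pythonExec,
            python_ver=sc.pythonVer,
            broadcast_vars=broadcast_vars,
            accumulator=sc._javaAccumulator)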
Diffstat (limited to 'python/pyspark/rdd.py')
-rw-r--r--  python/pyspark/rdd.py  23
1 file changed, 14 insertions(+), 9 deletions(-)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 4eaf589ad5..37574cea0b 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2309,7 +2309,7 @@ class RDD(object):
yield row
-def _prepare_for_python_RDD(sc, command, obj=None):
+def _prepare_for_python_RDD(sc, command):
# the serialized command will be compressed by broadcast
ser = CloudPickleSerializer()
pickled_command = ser.dumps(command)
@@ -2329,6 +2329,15 @@ def _prepare_for_python_RDD(sc, command, obj=None):
return pickled_command, broadcast_vars, env, includes
+def _wrap_function(sc, func, deserializer, serializer, profiler=None):
+ assert deserializer, "deserializer should not be empty"
+ assert serializer, "serializer should not be empty"
+ command = (func, profiler, deserializer, serializer)
+ pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
+ return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
+ sc.pythonVer, broadcast_vars, sc._javaAccumulator)
+
+
class PipelinedRDD(RDD):
"""
@@ -2390,14 +2399,10 @@ class PipelinedRDD(RDD):
else:
profiler = None
- command = (self.func, profiler, self._prev_jrdd_deserializer,
- self._jrdd_deserializer)
- pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self.ctx, command, self)
- python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
- bytearray(pickled_cmd),
- env, includes, self.preservesPartitioning,
- self.ctx.pythonExec, self.ctx.pythonVer,
- bvars, self.ctx._javaAccumulator)
+ wrapped_func = _wrap_function(self.ctx, self.func, self._prev_jrdd_deserializer,
+ self._jrdd_deserializer, profiler)
+ python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), wrapped_func,
+ self.preservesPartitioning)
self._jrdd_val = python_rdd.asJavaRDD()
if profiler:
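As a hedged usage sketch: once `_wrap_function` exists, a caller outside `PipelinedRDD` can reuse it to ship any Python function to the JVM. The helper below is hypothetical and the serializer choices are illustrative, though the classes named do come from `pyspark.serializers`:

    # Hypothetical caller reusing _wrap_function; serializers are illustrative.
    from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer

    def make_python_rdd(sc, prev_jrdd, func, preserves_partitioning=False):
        # One wrapped object now stands in for the seven context arguments
        # that the old PythonRDD constructor required.
        wrapped = _wrap_function(sc, func,
                                 UTF8Deserializer(),  # assumed input format
                                 BatchedSerializer(PickleSerializer()))
        return sc._jvm.PythonRDD(prev_jrdd.rdd(), wrapped, preserves_partitioning)

The design payoff is visible in the call-site diff above: the nine-argument `PythonRDD` constructor shrinks to three arguments, so a future addition to the function's context only touches `_wrap_function` and the JVM `PythonFunction`, not every place a `PythonRDD` is created.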