aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/rdd.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/rdd.py')
-rw-r--r--python/pyspark/rdd.py38
1 files changed, 22 insertions, 16 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 2f8a0edfe9..6e029bf7f1 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2162,6 +2162,25 @@ class RDD(object):
yield row
+def _prepare_for_python_RDD(sc, command, obj=None):
+ # the serialized command will be compressed by broadcast
+ ser = CloudPickleSerializer()
+ pickled_command = ser.dumps(command)
+ if len(pickled_command) > (1 << 20): # 1M
+ broadcast = sc.broadcast(pickled_command)
+ pickled_command = ser.dumps(broadcast)
+ # tracking the life cycle by obj
+ if obj is not None:
+ obj._broadcast = broadcast
+ broadcast_vars = ListConverter().convert(
+ [x._jbroadcast for x in sc._pickled_broadcast_vars],
+ sc._gateway._gateway_client)
+ sc._pickled_broadcast_vars.clear()
+ env = MapConverter().convert(sc.environment, sc._gateway._gateway_client)
+ includes = ListConverter().convert(sc._python_includes, sc._gateway._gateway_client)
+ return pickled_command, broadcast_vars, env, includes
+
+
class PipelinedRDD(RDD):
"""
@@ -2228,25 +2247,12 @@ class PipelinedRDD(RDD):
command = (self.func, profiler, self._prev_jrdd_deserializer,
self._jrdd_deserializer)
- # the serialized command will be compressed by broadcast
- ser = CloudPickleSerializer()
- pickled_command = ser.dumps(command)
- if len(pickled_command) > (1 << 20): # 1M
- self._broadcast = self.ctx.broadcast(pickled_command)
- pickled_command = ser.dumps(self._broadcast)
- broadcast_vars = ListConverter().convert(
- [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
- self.ctx._gateway._gateway_client)
- self.ctx._pickled_broadcast_vars.clear()
- env = MapConverter().convert(self.ctx.environment,
- self.ctx._gateway._gateway_client)
- includes = ListConverter().convert(self.ctx._python_includes,
- self.ctx._gateway._gateway_client)
+ pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self.ctx, command, self)
python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
- bytearray(pickled_command),
+ bytearray(pickled_cmd),
env, includes, self.preservesPartitioning,
self.ctx.pythonExec,
- broadcast_vars, self.ctx._javaAccumulator)
+ bvars, self.ctx._javaAccumulator)
self._jrdd_val = python_rdd.asJavaRDD()
if profiler: