author     Davies Liu <davies@databricks.com>    2015-02-04 15:55:09 -0800
committer  Reynold Xin <rxin@databricks.com>     2015-02-04 15:55:09 -0800
commit     dc101b0e4e23dffddbc2f70d14a19fae5d87a328 (patch)
tree       e436271c351a64caa4727661cd6143ba6e415fa6 /python/pyspark/rdd.py
parent     e0490e271d078aa55d7c7583e2ba80337ed1b0c4 (diff)
[SPARK-5577] Python udf for DataFrame
Author: Davies Liu <davies@databricks.com>

Closes #4351 from davies/python_udf and squashes the following commits:

d250692 [Davies Liu] fix conflict
34234d4 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python_udf
440f769 [Davies Liu] address comments
f0a3121 [Davies Liu] track life cycle of broadcast
f99b2e1 [Davies Liu] address comments
462b334 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python_udf
7bccc3b [Davies Liu] python udf
58dee20 [Davies Liu] clean up
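The core of the rdd.py change is factoring the "pickle the command; if it is large, ship it as a broadcast" logic out of PipelinedRDD into a reusable _prepare_for_python_RDD helper, so the new Python UDF code path can share it. Below is a minimal standalone sketch of that pattern: prepare_command is a hypothetical name, plain pickle stands in for CloudPickleSerializer, and sc is only assumed to expose a broadcast() method like pyspark.SparkContext does (the real helper also converts broadcast variables, environment, and includes for Py4J, as the diff shows).

import pickle

ONE_MB = 1 << 20  # same threshold used in the diff below

def prepare_command(sc, command, obj=None):
    # Serialize the task closure; small payloads travel inline with each
    # task, large ones go through the broadcast machinery so every
    # executor fetches them only once.
    payload = pickle.dumps(command)
    if len(payload) > ONE_MB:
        bc = sc.broadcast(payload)
        # Send a tiny pickled handle instead of the full payload.
        payload = pickle.dumps(bc)
        if obj is not None:
            # Pin the broadcast on the owning object so it is not
            # released while the RDD/UDF still needs it.
            obj._broadcast = bc
    return payload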
Diffstat (limited to 'python/pyspark/rdd.py')
-rw-r--r--  python/pyspark/rdd.py | 38
1 file changed, 22 insertions(+), 16 deletions(-)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 2f8a0edfe9..6e029bf7f1 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2162,6 +2162,25 @@ class RDD(object):
                 yield row
 
 
+def _prepare_for_python_RDD(sc, command, obj=None):
+    # the serialized command will be compressed by broadcast
+    ser = CloudPickleSerializer()
+    pickled_command = ser.dumps(command)
+    if len(pickled_command) > (1 << 20):  # 1M
+        broadcast = sc.broadcast(pickled_command)
+        pickled_command = ser.dumps(broadcast)
+        # tracking the life cycle by obj
+        if obj is not None:
+            obj._broadcast = broadcast
+    broadcast_vars = ListConverter().convert(
+        [x._jbroadcast for x in sc._pickled_broadcast_vars],
+        sc._gateway._gateway_client)
+    sc._pickled_broadcast_vars.clear()
+    env = MapConverter().convert(sc.environment, sc._gateway._gateway_client)
+    includes = ListConverter().convert(sc._python_includes, sc._gateway._gateway_client)
+    return pickled_command, broadcast_vars, env, includes
+
+
 class PipelinedRDD(RDD):
     """
@@ -2228,25 +2247,12 @@ class PipelinedRDD(RDD):
         command = (self.func, profiler, self._prev_jrdd_deserializer,
                    self._jrdd_deserializer)
-        # the serialized command will be compressed by broadcast
-        ser = CloudPickleSerializer()
-        pickled_command = ser.dumps(command)
-        if len(pickled_command) > (1 << 20):  # 1M
-            self._broadcast = self.ctx.broadcast(pickled_command)
-            pickled_command = ser.dumps(self._broadcast)
-        broadcast_vars = ListConverter().convert(
-            [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
-            self.ctx._gateway._gateway_client)
-        self.ctx._pickled_broadcast_vars.clear()
-        env = MapConverter().convert(self.ctx.environment,
-                                     self.ctx._gateway._gateway_client)
-        includes = ListConverter().convert(self.ctx._python_includes,
-                                           self.ctx._gateway._gateway_client)
+        pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self.ctx, command, self)
         python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
-                                             bytearray(pickled_command),
+                                             bytearray(pickled_cmd),
                                              env, includes, self.preservesPartitioning,
                                              self.ctx.pythonExec,
-                                             broadcast_vars, self.ctx._javaAccumulator)
+                                             bvars, self.ctx._javaAccumulator)
         self._jrdd_val = python_rdd.asJavaRDD()
         if profiler:
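For a sense of when the broadcast path triggers: any job whose pickled closure exceeds 1 MB. A hedged end-to-end example using only the public RDD API (assuming a local PySpark installation of this vintage); the ~4 MB string captured by the lambda pushes the serialized command over the threshold, so _prepare_for_python_RDD ships it as a broadcast variable rather than inlining it into every task.

from pyspark import SparkContext

sc = SparkContext("local", "large-closure-demo")

# The lambda captures ~4 MB of data, so the pickled command exceeds
# the 1 MB limit and travels to executors as a broadcast variable.
big = "x" * (4 << 20)
rdd = sc.parallelize(range(4)).map(lambda i: len(big) + i)
print(rdd.collect())  # [4194304, 4194305, 4194306, 4194307]
sc.stop()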