diff options
author:    Davies Liu <davies@databricks.com>  2015-02-04 15:55:09 -0800
committer: Reynold Xin <rxin@databricks.com>   2015-02-04 15:55:09 -0800
commit:    dc101b0e4e23dffddbc2f70d14a19fae5d87a328 (patch)
tree:      e436271c351a64caa4727661cd6143ba6e415fa6 /python/pyspark/rdd.py
parent:    e0490e271d078aa55d7c7583e2ba80337ed1b0c4 (diff)
download:  spark-dc101b0e4e23dffddbc2f70d14a19fae5d87a328.tar.gz
           spark-dc101b0e4e23dffddbc2f70d14a19fae5d87a328.tar.bz2
           spark-dc101b0e4e23dffddbc2f70d14a19fae5d87a328.zip
[SPARK-5577] Python udf for DataFrame
Author: Davies Liu <davies@databricks.com>
Closes #4351 from davies/python_udf and squashes the following commits:
d250692 [Davies Liu] fix conflict
34234d4 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python_udf
440f769 [Davies Liu] address comments
f0a3121 [Davies Liu] track life cycle of broadcast
f99b2e1 [Davies Liu] address comments
462b334 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python_udf
7bccc3b [Davies Liu] python udf
58dee20 [Davies Liu] clean up
Diffstat (limited to 'python/pyspark/rdd.py')
 python/pyspark/rdd.py | 38
 1 file changed, 22 insertions(+), 16 deletions(-)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 2f8a0edfe9..6e029bf7f1 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2162,6 +2162,25 @@ class RDD(object): yield row +def _prepare_for_python_RDD(sc, command, obj=None): + # the serialized command will be compressed by broadcast + ser = CloudPickleSerializer() + pickled_command = ser.dumps(command) + if len(pickled_command) > (1 << 20): # 1M + broadcast = sc.broadcast(pickled_command) + pickled_command = ser.dumps(broadcast) + # tracking the life cycle by obj + if obj is not None: + obj._broadcast = broadcast + broadcast_vars = ListConverter().convert( + [x._jbroadcast for x in sc._pickled_broadcast_vars], + sc._gateway._gateway_client) + sc._pickled_broadcast_vars.clear() + env = MapConverter().convert(sc.environment, sc._gateway._gateway_client) + includes = ListConverter().convert(sc._python_includes, sc._gateway._gateway_client) + return pickled_command, broadcast_vars, env, includes + + class PipelinedRDD(RDD): """ @@ -2228,25 +2247,12 @@ class PipelinedRDD(RDD): command = (self.func, profiler, self._prev_jrdd_deserializer, self._jrdd_deserializer) - # the serialized command will be compressed by broadcast - ser = CloudPickleSerializer() - pickled_command = ser.dumps(command) - if len(pickled_command) > (1 << 20): # 1M - self._broadcast = self.ctx.broadcast(pickled_command) - pickled_command = ser.dumps(self._broadcast) - broadcast_vars = ListConverter().convert( - [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], - self.ctx._gateway._gateway_client) - self.ctx._pickled_broadcast_vars.clear() - env = MapConverter().convert(self.ctx.environment, - self.ctx._gateway._gateway_client) - includes = ListConverter().convert(self.ctx._python_includes, - self.ctx._gateway._gateway_client) + pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self.ctx, command, self) python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), - bytearray(pickled_command), + 
bytearray(pickled_cmd), env, includes, self.preservesPartitioning, self.ctx.pythonExec, - broadcast_vars, self.ctx._javaAccumulator) + bvars, self.ctx._javaAccumulator) self._jrdd_val = python_rdd.asJavaRDD() if profiler: |