diff options
author | Bouke van der Bijl <boukevanderbijl@gmail.com> | 2014-05-10 13:02:13 -0700 |
---|---|---|
committer | Patrick Wendell <pwendell@gmail.com> | 2014-05-10 13:02:13 -0700 |
commit | 3776f2f283842543ff766398292532c6e94221cc (patch) | |
tree | c42e92390922359f8b3fec88ad5b371014900e40 | |
parent | c05d11bb307eaba40c5669da2d374c28debaa55a (diff) | |
download | spark-3776f2f283842543ff766398292532c6e94221cc.tar.gz spark-3776f2f283842543ff766398292532c6e94221cc.tar.bz2 spark-3776f2f283842543ff766398292532c6e94221cc.zip |
Add Python includes to path before depickling broadcast values
This fixes https://issues.apache.org/jira/browse/SPARK-1731 by adding the Python includes to the PYTHONPATH before depickling the broadcast values
@airhorns
Author: Bouke van der Bijl <boukevanderbijl@gmail.com>
Closes #656 from bouk/python-includes-before-broadcast and squashes the following commits:
7b0dfe4 [Bouke van der Bijl] Add Python includes to path before depickling broadcast values
-rw-r--r-- | core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala | 10 | ||||
-rw-r--r-- | python/pyspark/worker.py | 14 |
2 files changed, 12 insertions, 12 deletions
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index fecd9762f3..388b838d78 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -179,6 +179,11 @@ private[spark] class PythonRDD[T: ClassTag]( dataOut.writeInt(split.index) // sparkFilesDir PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut) + // Python includes (*.zip and *.egg files) + dataOut.writeInt(pythonIncludes.length) + for (include <- pythonIncludes) { + PythonRDD.writeUTF(include, dataOut) + } // Broadcast variables dataOut.writeInt(broadcastVars.length) for (broadcast <- broadcastVars) { @@ -186,11 +191,6 @@ private[spark] class PythonRDD[T: ClassTag]( dataOut.writeInt(broadcast.value.length) dataOut.write(broadcast.value) } - // Python includes (*.zip and *.egg files) - dataOut.writeInt(pythonIncludes.length) - for (include <- pythonIncludes) { - PythonRDD.writeUTF(include, dataOut) - } dataOut.flush() // Serialized command: dataOut.writeInt(command.length) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 4c214ef359..f43210c6c0 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -56,13 +56,6 @@ def main(infile, outfile): SparkFiles._root_directory = spark_files_dir SparkFiles._is_running_on_worker = True - # fetch names and values of broadcast variables - num_broadcast_variables = read_int(infile) - for _ in range(num_broadcast_variables): - bid = read_long(infile) - value = pickleSer._read_with_length(infile) - _broadcastRegistry[bid] = Broadcast(bid, value) - # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH sys.path.append(spark_files_dir) # *.py files that were added will be copied here num_python_includes = read_int(infile) @@ -70,6 +63,13 @@ def main(infile, outfile): filename = utf8_deserializer.loads(infile) sys.path.append(os.path.join(spark_files_dir, filename)) + # fetch names and values of broadcast variables + num_broadcast_variables = read_int(infile) + for _ in range(num_broadcast_variables): + bid = read_long(infile) + value = pickleSer._read_with_length(infile) + _broadcastRegistry[bid] = Broadcast(bid, value) + command = pickleSer._read_with_length(infile) (func, deserializer, serializer) = command init_time = time.time() |