about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala10
-rw-r--r--python/pyspark/worker.py14
2 files changed, 12 insertions, 12 deletions
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index fecd9762f3..388b838d78 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -179,6 +179,11 @@ private[spark] class PythonRDD[T: ClassTag](
dataOut.writeInt(split.index)
// sparkFilesDir
PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
+ // Python includes (*.zip and *.egg files)
+ dataOut.writeInt(pythonIncludes.length)
+ for (include <- pythonIncludes) {
+ PythonRDD.writeUTF(include, dataOut)
+ }
// Broadcast variables
dataOut.writeInt(broadcastVars.length)
for (broadcast <- broadcastVars) {
@@ -186,11 +191,6 @@ private[spark] class PythonRDD[T: ClassTag](
dataOut.writeInt(broadcast.value.length)
dataOut.write(broadcast.value)
}
- // Python includes (*.zip and *.egg files)
- dataOut.writeInt(pythonIncludes.length)
- for (include <- pythonIncludes) {
- PythonRDD.writeUTF(include, dataOut)
- }
dataOut.flush()
// Serialized command:
dataOut.writeInt(command.length)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 4c214ef359..f43210c6c0 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -56,13 +56,6 @@ def main(infile, outfile):
SparkFiles._root_directory = spark_files_dir
SparkFiles._is_running_on_worker = True
- # fetch names and values of broadcast variables
- num_broadcast_variables = read_int(infile)
- for _ in range(num_broadcast_variables):
- bid = read_long(infile)
- value = pickleSer._read_with_length(infile)
- _broadcastRegistry[bid] = Broadcast(bid, value)
-
# fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
sys.path.append(spark_files_dir) # *.py files that were added will be copied here
num_python_includes = read_int(infile)
@@ -70,6 +63,13 @@ def main(infile, outfile):
filename = utf8_deserializer.loads(infile)
sys.path.append(os.path.join(spark_files_dir, filename))
+ # fetch names and values of broadcast variables
+ num_broadcast_variables = read_int(infile)
+ for _ in range(num_broadcast_variables):
+ bid = read_long(infile)
+ value = pickleSer._read_with_length(infile)
+ _broadcastRegistry[bid] = Broadcast(bid, value)
+
command = pickleSer._read_with_length(infile)
(func, deserializer, serializer) = command
init_time = time.time()