diff options
Diffstat (limited to 'python/pyspark/worker.py')
-rw-r--r-- | python/pyspark/worker.py | 41 |
1 files changed, 36 insertions, 5 deletions
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 379bbfd4c2..d63c2aaef7 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -1,9 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + """ Worker that receives input from Piped RDD. """ import os import sys import time +import socket import traceback from base64 import standard_b64decode # CloudPickler needs to be imported so that depicklers are registered using the @@ -32,15 +50,26 @@ def main(infile, outfile): split_index = read_int(infile) if split_index == -1: # for unit tests return + + # fetch name of workdir spark_files_dir = load_pickle(read_with_length(infile)) SparkFiles._root_directory = spark_files_dir SparkFiles._is_running_on_worker = True - sys.path.append(spark_files_dir) + + # fetch names and values of broadcast variables num_broadcast_variables = read_int(infile) for _ in range(num_broadcast_variables): bid = read_long(infile) value = read_with_length(infile) _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value)) + + # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH + sys.path.append(spark_files_dir) # *.py files that were added will be copied here + num_python_includes = read_int(infile) + for _ in range(num_python_includes): + sys.path.append(os.path.join(spark_files_dir, load_pickle(read_with_length(infile)))) + + # now load function func = load_obj(infile) bypassSerializer = load_obj(infile) if bypassSerializer: @@ -66,7 +95,9 @@ def main(infile, outfile): if __name__ == '__main__': - # Redirect stdout to stderr so that users must return values from functions. - old_stdout = os.fdopen(os.dup(1), 'w') - os.dup2(2, 1) - main(sys.stdin, old_stdout) + # Read a local port to connect to from stdin + java_port = int(sys.stdin.readline()) + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.connect(("127.0.0.1", java_port)) + sock_file = sock.makefile("a+", 65536) + main(sock_file, sock_file) |