diff options
Diffstat (limited to 'python/pyspark')
-rw-r--r-- | python/pyspark/__init__.py | 19 | ||||
-rw-r--r-- | python/pyspark/java_gateway.py | 17 | ||||
-rw-r--r-- | python/pyspark/worker.py | 11 |
3 files changed, 37 insertions, 10 deletions
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 3e8bca62f0..fd5972d381 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -1,5 +1,22 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + """ -PySpark is a Python API for Spark. +PySpark is the Python API for Spark. Public classes: diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 26fbe0f080..e615c1e9b6 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -18,6 +18,7 @@ import os import sys import signal +import platform from subprocess import Popen, PIPE from threading import Thread from py4j.java_gateway import java_import, JavaGateway, GatewayClient @@ -29,12 +30,18 @@ SPARK_HOME = os.environ["SPARK_HOME"] def launch_gateway(): # Launch the Py4j gateway using Spark's run command so that we pick up the # proper classpath and SPARK_MEM settings from spark-env.sh - command = [os.path.join(SPARK_HOME, "spark-class"), "py4j.GatewayServer", + on_windows = platform.system() == "Windows" + script = "spark-class.cmd" if on_windows else "spark-class" + command = [os.path.join(SPARK_HOME, script), "py4j.GatewayServer", "--die-on-broken-pipe", "0"] - # Don't send ctrl-c / SIGINT to the Java gateway: - def preexec_function(): - signal.signal(signal.SIGINT, signal.SIG_IGN) - proc = Popen(command, stdout=PIPE, stdin=PIPE, preexec_fn=preexec_function) + if not on_windows: + # Don't send ctrl-c / SIGINT to the Java gateway: + def preexec_func(): + signal.signal(signal.SIGINT, signal.SIG_IGN) + proc = Popen(command, stdout=PIPE, stdin=PIPE, preexec_fn=preexec_func) + else: + # preexec_fn not supported on Windows + proc = Popen(command, stdout=PIPE, stdin=PIPE) # Determine which ephemeral port the server started on: port = int(proc.stdout.readline()) # Create a thread to echo output from the GatewayServer, which is required diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 695f6dfb84..d63c2aaef7 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -21,6 +21,7 @@ Worker that receives input from Piped RDD. import os import sys import time +import socket import traceback from base64 import standard_b64decode # CloudPickler needs to be imported so that depicklers are registered using the @@ -94,7 +95,9 @@ def main(infile, outfile): if __name__ == '__main__': - # Redirect stdout to stderr so that users must return values from functions. - old_stdout = os.fdopen(os.dup(1), 'w') - os.dup2(2, 1) - main(sys.stdin, old_stdout) + # Read a local port to connect to from stdin + java_port = int(sys.stdin.readline()) + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.connect(("127.0.0.1", java_port)) + sock_file = sock.makefile("a+", 65536) + main(sock_file, sock_file) |