 core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala | 107 ++++++++++++++-
 python/pyspark/worker.py                                                  |  11 +++----
 2 files changed, 106 insertions(+), 12 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 08e3f670f5..67d45723ba 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -17,8 +17,8 @@
package org.apache.spark.api.python
-import java.io.{File, DataInputStream, IOException}
-import java.net.{Socket, SocketException, InetAddress}
+import java.io.{OutputStreamWriter, File, DataInputStream, IOException}
+import java.net.{ServerSocket, Socket, SocketException, InetAddress}
import scala.collection.JavaConversions._
@@ -26,11 +26,30 @@ import org.apache.spark._
private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
extends Logging {
+
+ // Because forking processes from Java is expensive, we prefer to launch a single Python daemon
+ // (pyspark/daemon.py) and tell it to fork new workers for our tasks. This daemon currently
+ // only works on UNIX-based systems because it uses signals for child management, so on
+ // other platforms we fall back to launching workers (pyspark/worker.py) directly.
+ val useDaemon = !System.getProperty("os.name").startsWith("Windows")
+
var daemon: Process = null
val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
var daemonPort: Int = 0
def create(): Socket = {
+ if (useDaemon) {
+ createThroughDaemon()
+ } else {
+ createSimpleWorker()
+ }
+ }
+
+ /**
+ * Connect to a worker launched through pyspark/daemon.py, which forks python processes itself
+ * to avoid the high cost of forking from Java. This currently only works on UNIX-based systems.
+ */
+ private def createThroughDaemon(): Socket = {
synchronized {
// Start the daemon if it hasn't been started
startDaemon()
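
For context, the daemon path that this patch keeps as the default works roughly like this: the factory launches pyspark/daemon.py once, learns which loopback port the daemon is listening on, and then opens one socket per requested worker, each of which the daemon services with a forked child. A minimal sketch of that client side, assuming a daemonPort already obtained from the daemon at startup (illustrative names, not the exact Spark code):

    import java.net.{InetAddress, Socket}

    // Illustrative sketch: obtain one worker connection from a running daemon.
    // Assumes `daemonPort` was read from the daemon's stdout when it started.
    def connectToDaemonWorker(daemonPort: Int): Socket = {
      val loopback = InetAddress.getByAddress(Array(127, 0, 0, 1))
      // The daemon hands each accepted connection to a forked Python worker,
      // so the returned socket talks directly to a fresh worker process.
      new Socket(loopback, daemonPort)
    }
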
@@ -50,6 +69,78 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
}
}
+ /**
+ * Launch a worker by executing worker.py directly and telling it to connect to us.
+ */
+ private def createSimpleWorker(): Socket = {
+ var serverSocket: ServerSocket = null
+ try {
+ serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
+
+ // Create and start the worker
+ val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME")
+ val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/worker.py"))
+ val workerEnv = pb.environment()
+ workerEnv.putAll(envVars)
+ val pythonPath = sparkHome + "/python/" + File.pathSeparator + workerEnv.get("PYTHONPATH")
+ workerEnv.put("PYTHONPATH", pythonPath)
+ val worker = pb.start()
+
+ // Redirect the worker's stderr to ours
+ new Thread("stderr reader for " + pythonExec) {
+ setDaemon(true)
+ override def run() {
+ scala.util.control.Exception.ignoring(classOf[IOException]) {
+ // FIXME: We copy the stream on the level of bytes to avoid encoding problems.
+ val in = worker.getErrorStream
+ val buf = new Array[Byte](1024)
+ var len = in.read(buf)
+ while (len != -1) {
+ System.err.write(buf, 0, len)
+ len = in.read(buf)
+ }
+ }
+ }
+ }.start()
+
+ // Redirect worker's stdout to our stderr
+ new Thread("stdout reader for " + pythonExec) {
+ setDaemon(true)
+ override def run() {
+ scala.util.control.Exception.ignoring(classOf[IOException]) {
+ // FIXME: We copy the stream on the level of bytes to avoid encoding problems.
+ val in = worker.getInputStream
+ val buf = new Array[Byte](1024)
+ var len = in.read(buf)
+ while (len != -1) {
+ System.err.write(buf, 0, len)
+ len = in.read(buf)
+ }
+ }
+ }
+ }.start()
+
+ // Tell the worker our port
+ val out = new OutputStreamWriter(worker.getOutputStream)
+ out.write(serverSocket.getLocalPort + "\n")
+ out.flush()
+
+ // Wait for it to connect to our socket
+ serverSocket.setSoTimeout(10000)
+ try {
+ return serverSocket.accept()
+ } catch {
+ case e: Exception =>
+ throw new SparkException("Python worker did not connect back in time", e)
+ }
+ } finally {
+ if (serverSocket != null) {
+ serverSocket.close()
+ }
+ }
+ null
+ }
+
def stop() {
stopDaemon()
}
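
The connect-back handshake in createSimpleWorker above is the heart of the new fallback: bind an ephemeral loopback port, hand that port number to the child over stdin, then block in accept() with a timeout so a worker that crashes before dialing in fails fast instead of hanging the factory. The same pattern in isolation, with a hypothetical `command` standing in for the real pythonExec/worker.py invocation and environment setup:

    import java.io.OutputStreamWriter
    import java.net.{InetAddress, ServerSocket, Socket}

    // Sketch of the connect-back handshake. `command` is a placeholder for
    // the actual worker launch command; error handling is elided.
    def launchAndAccept(command: Seq[String]): Socket = {
      // Port 0 asks the OS for a free ephemeral port; backlog 1 suffices
      // because exactly one worker will ever connect.
      val server = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
      try {
        val worker = new ProcessBuilder(command: _*).start()
        // Tell the child which port to dial; it reads a single line from stdin.
        val out = new OutputStreamWriter(worker.getOutputStream)
        out.write(server.getLocalPort + "\n")
        out.flush()
        // Bound the wait so a dead worker surfaces as an error, not a hang.
        server.setSoTimeout(10000)
        server.accept()
      } finally {
        // Closing the listening socket does not close the accepted socket.
        server.close()
      }
    }
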
@@ -73,12 +164,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
// Redirect the stderr to ours
new Thread("stderr reader for " + pythonExec) {
+ setDaemon(true)
override def run() {
scala.util.control.Exception.ignoring(classOf[IOException]) {
- // FIXME HACK: We copy the stream on the level of bytes to
- // attempt to dodge encoding problems.
+ // FIXME: We copy the stream on the level of bytes to avoid encoding problems.
val in = daemon.getErrorStream
- var buf = new Array[Byte](1024)
+ val buf = new Array[Byte](1024)
var len = in.read(buf)
while (len != -1) {
System.err.write(buf, 0, len)
@@ -93,11 +184,11 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
// Redirect further stdout output to our stderr
new Thread("stdout reader for " + pythonExec) {
+ setDaemon(true)
override def run() {
scala.util.control.Exception.ignoring(classOf[IOException]) {
- // FIXME HACK: We copy the stream on the level of bytes to
- // attempt to dodge encoding problems.
- var buf = new Array[Byte](1024)
+ // FIXME: We copy the stream on the level of bytes to avoid encoding problems.
+ val buf = new Array[Byte](1024)
var len = in.read(buf)
while (len != -1) {
System.err.write(buf, 0, len)
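
All four redirection threads in this patch (worker stderr/stdout earlier, daemon stderr/stdout here) follow one pattern: a daemon thread copies the child's output to our stderr byte-for-byte, deliberately avoiding Reader/Writer so that no charset decoding can mangle the bytes, and swallowing the IOException that arrives when the child exits. A hypothetical helper that factors the pattern out (not part of the patch):

    import java.io.{IOException, InputStream}

    // Hypothetical helper: copy `in` to our stderr on a daemon thread,
    // byte-for-byte, ignoring the IOException raised when the stream closes.
    def redirectToStderr(name: String, in: InputStream) {
      new Thread("stderr writer for " + name) {
        setDaemon(true) // don't keep the JVM alive just for log plumbing
        override def run() {
          scala.util.control.Exception.ignoring(classOf[IOException]) {
            val buf = new Array[Byte](1024)
            var len = in.read(buf)
            while (len != -1) {
              System.err.write(buf, 0, len)
              len = in.read(buf)
            }
          }
        }
      }.start()
    }
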
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 695f6dfb84..d63c2aaef7 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -21,6 +21,7 @@ Worker that receives input from Piped RDD.
import os
import sys
import time
+import socket
import traceback
from base64 import standard_b64decode
# CloudPickler needs to be imported so that depicklers are registered using the
@@ -94,7 +95,9 @@ def main(infile, outfile):
if __name__ == '__main__':
- # Redirect stdout to stderr so that users must return values from functions.
- old_stdout = os.fdopen(os.dup(1), 'w')
- os.dup2(2, 1)
- main(sys.stdin, old_stdout)
+ # Read the port to connect back to from stdin
+ java_port = int(sys.stdin.readline())
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.connect(("127.0.0.1", java_port))
+ sock_file = sock.makefile("a+", 65536)
+ main(sock_file, sock_file)
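
On the Python side, the old stdin/stdout pipe protocol is replaced by a socket: the worker reads one line (the port) from its real stdin, dials 127.0.0.1, and then passes the same buffered socket file to main() as both infile and outfile, so task input and result output share the connection. The handshake is simple enough to emulate from the JVM for testing; a hypothetical Scala stand-in for the worker (illustrative only, not part of the patch):

    import java.io.{BufferedReader, InputStreamReader}
    import java.net.Socket

    // Hypothetical stand-in for worker.py's startup: read the port from
    // stdin, connect back to the JVM, and use one socket for both
    // directions, just as worker.py does with sock.makefile("a+", 65536).
    object FakeWorker {
      def main(args: Array[String]) {
        val stdin = new BufferedReader(new InputStreamReader(System.in))
        val javaPort = stdin.readLine().trim.toInt
        val sock = new Socket("127.0.0.1", javaPort)
        val in = sock.getInputStream   // tasks would arrive here
        val out = sock.getOutputStream // results would be written here
        // ... a real worker runs its serve loop over (in, out) here ...
        sock.close()
      }
    }
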