 core/src/main/scala/org/apache/spark/SparkEnv.scala                       |  5
 core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala           |  9
 core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala | 64
 python/pyspark/daemon.py                                                  | 24
 python/pyspark/tests.py                                                   | 51
 5 files changed, 125 insertions(+), 28 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 92c809d854..0bce531aab 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -18,6 +18,7 @@
package org.apache.spark
import java.io.File
+import java.net.Socket
import scala.collection.JavaConversions._
import scala.collection.mutable
@@ -102,10 +103,10 @@ class SparkEnv (
}
private[spark]
- def destroyPythonWorker(pythonExec: String, envVars: Map[String, String]) {
+ def destroyPythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) {
synchronized {
val key = (pythonExec, envVars)
- pythonWorkers(key).stop()
+ pythonWorkers.get(key).foreach(_.stopWorker(worker))
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index fe9a9e50ef..0b5322c6fb 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -62,8 +62,8 @@ private[spark] class PythonRDD(
val env = SparkEnv.get
val localdir = env.blockManager.diskBlockManager.localDirs.map(
f => f.getPath()).mkString(",")
- val worker: Socket = env.createPythonWorker(pythonExec,
- envVars.toMap + ("SPARK_LOCAL_DIR" -> localdir))
+ envVars += ("SPARK_LOCAL_DIR" -> localdir) // it's also used in monitor thread
+ val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap)
// Start a thread to feed the process input from our parent's iterator
val writerThread = new WriterThread(env, worker, split, context)
@@ -241,7 +241,7 @@ private[spark] class PythonRDD(
if (!context.completed) {
try {
logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
- env.destroyPythonWorker(pythonExec, envVars.toMap)
+ env.destroyPythonWorker(pythonExec, envVars.toMap, worker)
} catch {
case e: Exception =>
logError("Exception when trying to kill worker", e)
@@ -685,9 +685,8 @@ private[spark] object PythonRDD extends Logging {
/**
* Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions).
- * This function is outdated, PySpark does not use it anymore
*/
- @deprecated
+ @deprecated("PySpark does not use it anymore", "1.1")
def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
pyRDD.rdd.mapPartitions { iter =>
val unpickle = new Unpickler
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 15fe8a9be6..7af260d0b7 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -17,9 +17,11 @@
package org.apache.spark.api.python
-import java.io.{DataInputStream, InputStream, OutputStreamWriter}
+import java.lang.Runtime
+import java.io.{DataOutputStream, DataInputStream, InputStream, OutputStreamWriter}
import java.net.{InetAddress, ServerSocket, Socket, SocketException}
+import scala.collection.mutable
import scala.collection.JavaConversions._
import org.apache.spark._
@@ -39,6 +41,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
var daemon: Process = null
val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
var daemonPort: Int = 0
+ var daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
+
+ var simpleWorkers = new mutable.WeakHashMap[Socket, Process]()
val pythonPath = PythonUtils.mergePythonPaths(
PythonUtils.sparkPythonPath,
@@ -58,25 +63,31 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
* to avoid the high cost of forking from Java. This currently only works on UNIX-based systems.
*/
private def createThroughDaemon(): Socket = {
+
+ def createSocket(): Socket = {
+ val socket = new Socket(daemonHost, daemonPort)
+ val pid = new DataInputStream(socket.getInputStream).readInt()
+ if (pid < 0) {
+ throw new IllegalStateException("Python daemon failed to launch worker")
+ }
+ daemonWorkers.put(socket, pid)
+ socket
+ }
+
synchronized {
// Start the daemon if it hasn't been started
startDaemon()
// Attempt to connect, restart and retry once if it fails
try {
- val socket = new Socket(daemonHost, daemonPort)
- val launchStatus = new DataInputStream(socket.getInputStream).readInt()
- if (launchStatus != 0) {
- throw new IllegalStateException("Python daemon failed to launch worker")
- }
- socket
+ createSocket()
} catch {
case exc: SocketException =>
logWarning("Failed to open socket to Python daemon:", exc)
logWarning("Assuming that daemon unexpectedly quit, attempting to restart")
stopDaemon()
startDaemon()
- new Socket(daemonHost, daemonPort)
+ createSocket()
}
}
}
@@ -107,7 +118,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
// Wait for it to connect to our socket
serverSocket.setSoTimeout(10000)
try {
- return serverSocket.accept()
+ val socket = serverSocket.accept()
+ simpleWorkers.put(socket, worker)
+ return socket
} catch {
case e: Exception =>
throw new SparkException("Python worker did not connect back in time", e)
@@ -189,19 +202,40 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
private def stopDaemon() {
synchronized {
- // Request shutdown of existing daemon by sending SIGTERM
- if (daemon != null) {
- daemon.destroy()
- }
+ if (useDaemon) {
+ // Request shutdown of existing daemon by sending SIGTERM
+ if (daemon != null) {
+ daemon.destroy()
+ }
- daemon = null
- daemonPort = 0
+ daemon = null
+ daemonPort = 0
+ } else {
+ simpleWorkers.mapValues(_.destroy())
+ }
}
}
def stop() {
stopDaemon()
}
+
+ def stopWorker(worker: Socket) {
+ if (useDaemon) {
+ if (daemon != null) {
+ daemonWorkers.get(worker).foreach { pid =>
+ // tell daemon to kill worker by pid
+ val output = new DataOutputStream(daemon.getOutputStream)
+ output.writeInt(pid)
+ output.flush()
+ daemon.getOutputStream.flush()
+ }
+ }
+ } else {
+ simpleWorkers.get(worker).foreach(_.destroy())
+ }
+ worker.close()
+ }
}
private object PythonWorkerFactory {
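
The kill path above hinges on a tiny wire format: stopWorker() writes the worker's pid to the daemon's stdin with DataOutputStream.writeInt, i.e. as a 4-byte big-endian signed integer, and the pid handshake in createSocket() reads the same framing back with DataInputStream.readInt. A minimal Python sketch of that framing follows; the helpers mirror pyspark.serializers but are re-implemented here purely for illustration.

import io
import struct


def write_int(value, stream):
    # 4-byte big-endian signed int -- byte-for-byte what
    # java.io.DataOutputStream.writeInt emits on the JVM side.
    stream.write(struct.pack("!i", value))


def read_int(stream):
    data = stream.read(4)
    if not data:
        raise EOFError
    return struct.unpack("!i", data)[0]


if __name__ == "__main__":
    buf = io.BytesIO()
    write_int(12345, buf)          # e.g. a worker pid written by stopWorker()
    buf.seek(0)
    assert read_int(buf) == 12345
    print("pid round-trips intact")

Because DataOutputStream.writeInt and struct's "!i" format agree byte for byte, no extra serialization layer is needed between the JVM and the daemon.
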
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index 9fde0dde0f..b00da833d0 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -26,7 +26,7 @@ from errno import EINTR, ECHILD
from socket import AF_INET, SOCK_STREAM, SOMAXCONN
from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
from pyspark.worker import main as worker_main
-from pyspark.serializers import write_int
+from pyspark.serializers import read_int, write_int
def compute_real_exit_code(exit_code):
@@ -67,7 +67,8 @@ def worker(sock):
outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
exit_code = 0
try:
- write_int(0, outfile) # Acknowledge that the fork was successful
+ # Acknowledge that the fork was successful
+ write_int(os.getpid(), outfile)
outfile.flush()
worker_main(infile, outfile)
except SystemExit as exc:
@@ -125,14 +126,23 @@ def manager():
else:
raise
if 0 in ready_fds:
- # Spark told us to exit by closing stdin
- shutdown(0)
+ try:
+ worker_pid = read_int(sys.stdin)
+ except EOFError:
+ # Spark told us to exit by closing stdin
+ shutdown(0)
+ try:
+ os.kill(worker_pid, signal.SIGKILL)
+ except OSError:
+ pass # process already died
+
+
if listen_sock in ready_fds:
sock, addr = listen_sock.accept()
# Launch a worker process
try:
- fork_return_code = os.fork()
- if fork_return_code == 0:
+ pid = os.fork()
+ if pid == 0:
listen_sock.close()
try:
worker(sock)
@@ -143,11 +153,13 @@ def manager():
os._exit(0)
else:
sock.close()
+
except OSError as e:
print >> sys.stderr, "Daemon failed to fork PySpark worker: %s" % e
outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
write_int(-1, outfile) # Signal that the fork failed
outfile.flush()
+ outfile.close()
sock.close()
finally:
shutdown(1)
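
With that framing in place, the manager's stdin doubles as a control channel: EOF still means "Spark closed our stdin, shut down", while a readable integer names a worker pid to SIGKILL. Below is a self-contained sketch of just that branch, assuming the same framing as above; read_int is re-implemented inline, whereas the real daemon imports it from pyspark.serializers.

import os
import signal
import struct


def _read_int(stream):
    # Same 4-byte big-endian framing as pyspark.serializers.read_int.
    data = stream.read(4)
    if not data:
        raise EOFError
    return struct.unpack("!i", data)[0]


def handle_stdin(stdin):
    # Sketch of the manager's new stdin branch: EOF -> ask for shutdown,
    # otherwise SIGKILL the requested worker pid and keep serving.
    try:
        worker_pid = _read_int(stdin)
    except EOFError:
        return "shutdown"          # Spark told us to exit by closing stdin
    try:
        os.kill(worker_pid, signal.SIGKILL)
    except OSError:
        pass                       # worker already died
    return "killed %d" % worker_pid


if __name__ == "__main__":
    import io
    print(handle_stdin(io.BytesIO(b"")))   # EOF on stdin -> "shutdown"
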
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 16fb5a9256..acc3c30371 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -790,6 +790,57 @@ class TestDaemon(unittest.TestCase):
self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM))
+class TestWorker(PySparkTestCase):
+ def test_cancel_task(self):
+ temp = tempfile.NamedTemporaryFile(delete=True)
+ temp.close()
+ path = temp.name
+ def sleep(x):
+ import os, time
+ with open(path, 'w') as f:
+ f.write("%d %d" % (os.getppid(), os.getpid()))
+ time.sleep(100)
+
+ # start job in background thread
+ def run():
+ self.sc.parallelize(range(1)).foreach(sleep)
+ import threading
+ t = threading.Thread(target=run)
+ t.daemon = True
+ t.start()
+
+ daemon_pid, worker_pid = 0, 0
+ while True:
+ if os.path.exists(path):
+ data = open(path).read().split(' ')
+ daemon_pid, worker_pid = map(int, data)
+ break
+ time.sleep(0.1)
+
+ # cancel jobs
+ self.sc.cancelAllJobs()
+ t.join()
+
+ for i in range(50):
+ try:
+ os.kill(worker_pid, 0)
+ time.sleep(0.1)
+ except OSError:
+ break # worker was killed
+ else:
+ self.fail("worker has not been killed after 5 seconds")
+
+ try:
+ os.kill(daemon_pid, 0)
+ except OSError:
+ self.fail("daemon had been killed")
+
+ def test_fd_leak(self):
+ N = 1100 # fd limit is 1024 by default
+ rdd = self.sc.parallelize(range(N), N)
+ self.assertEquals(N, rdd.count())
+
+
class TestSparkSubmit(unittest.TestCase):
def setUp(self):
self.programDir = tempfile.mkdtemp()
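
test_cancel_task above checks process liveness with os.kill(pid, 0): signal 0 delivers nothing, but the call still fails if the pid no longer exists, which is how the test tells "the worker was killed" apart from "the daemon survived". A small stand-alone illustration of that probe; pid_alive is a hypothetical helper, not part of the patch.

import errno
import os


def pid_alive(pid):
    # Signal 0 performs only the existence/permission check.
    try:
        os.kill(pid, 0)
    except OSError as e:
        # EPERM means the process exists but belongs to someone else.
        return e.errno == errno.EPERM
    return True


if __name__ == "__main__":
    print(pid_alive(os.getpid()))   # True: we are running this line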