-rw-r--r--   core/src/main/scala/spark/SparkEnv.scala                    1
-rw-r--r--   core/src/main/scala/spark/api/python/PythonWorker.scala     4
-rw-r--r--   python/pyspark/daemon.py                                    46
-rw-r--r--   python/pyspark/tests.py                                     43
-rw-r--r--   python/pyspark/worker.py                                    2
5 files changed, 74 insertions, 22 deletions
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index 5691e24c32..5b55d45212 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -44,6 +44,7 @@ class SparkEnv (
   private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorker]()
   def stop() {
+    pythonWorkers.foreach { case(key, worker) => worker.stop() }
     httpFileServer.stop()
     mapOutputTracker.stop()
     shuffleFetcher.stop()
diff --git a/core/src/main/scala/spark/api/python/PythonWorker.scala b/core/src/main/scala/spark/api/python/PythonWorker.scala
index 8ee3c6884f..74c8c6d37a 100644
--- a/core/src/main/scala/spark/api/python/PythonWorker.scala
+++ b/core/src/main/scala/spark/api/python/PythonWorker.scala
@@ -33,6 +33,10 @@ private[spark] class PythonWorker(pythonExec: String, envVars: Map[String, Strin
     }
   }
+  def stop() {
+    stopDaemon
+  }
+
   private def startDaemon() {
     synchronized {
       // Is it already running?
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index 642f30b2b9..ab9c19df57 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -12,7 +12,7 @@ try:
 except NotImplementedError:
     POOLSIZE = 4
-should_exit = False
+should_exit = multiprocessing.Event()
 def worker(listen_sock):
@@ -21,14 +21,13 @@ def worker(listen_sock):
     # Manager sends SIGHUP to request termination of workers in the pool
     def handle_sighup(signum, frame):
-        global should_exit
-        should_exit = True
+        assert should_exit.is_set()
     signal(SIGHUP, handle_sighup)
-    while not should_exit:
+    while not should_exit.is_set():
         # Wait until a client arrives or we have to exit
         sock = None
-        while not should_exit and sock is None:
+        while not should_exit.is_set() and sock is None:
             try:
                 sock, addr = listen_sock.accept()
             except EnvironmentError as err:
@@ -36,8 +35,8 @@ def worker(listen_sock):
                     raise
         if sock is not None:
-            # Fork a child to handle the client
-            if os.fork() == 0:
+            # Fork to handle the client
+            if os.fork() != 0:
                 # Leave the worker pool
                 signal(SIGHUP, SIG_DFL)
                 listen_sock.close()
@@ -50,7 +49,7 @@ def worker(listen_sock):
             else:
                 sock.close()
-    assert should_exit
+    assert should_exit.is_set()
     os._exit(0)
@@ -73,9 +72,7 @@ def manager():
     listen_sock.close()
     def shutdown():
-        global should_exit
-        os.kill(0, SIGHUP)
-        should_exit = True
+        should_exit.set()
     # Gracefully exit on SIGTERM, don't die on SIGHUP
     signal(SIGTERM, lambda signum, frame: shutdown())
@@ -85,8 +82,8 @@ def manager():
     def handle_sigchld(signum, frame):
         try:
             pid, status = os.waitpid(0, os.WNOHANG)
-            if (pid, status) != (0, 0) and not should_exit:
-                raise RuntimeError("pool member crashed: %s, %s" % (pid, status))
+            if status != 0 and not should_exit.is_set():
+                raise RuntimeError("worker crashed: %s, %s" % (pid, status))
         except EnvironmentError as err:
             if err.errno not in (ECHILD, EINTR):
                 raise
@@ -94,15 +91,20 @@ def manager():
     # Initialization complete
     sys.stdout.close()
-    while not should_exit:
-        try:
-            # Spark tells us to exit by closing stdin
-            if sys.stdin.read() == '':
-                shutdown()
-        except EnvironmentError as err:
-            if err.errno != EINTR:
-                shutdown()
-                raise
+    try:
+        while not should_exit.is_set():
+            try:
+                # Spark tells us to exit by closing stdin
+                if os.read(0, 512) == '':
+                    shutdown()
+            except EnvironmentError as err:
+                if err.errno != EINTR:
+                    shutdown()
+                    raise
+    finally:
+        should_exit.set()
+        # Send SIGHUP to notify workers of shutdown
+        os.kill(0, SIGHUP)
 if __name__ == '__main__':
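The daemon.py change above replaces the module-level should_exit boolean with a multiprocessing.Event, which survives os.fork() and is visible to every pool worker; the manager sets the event first and only then sends SIGHUP to the process group, which is why a worker's SIGHUP handler can simply assert that the event is already set. Below is a minimal, POSIX-only sketch of that set-then-signal pattern, not the patched daemon itself:

# Minimal sketch (POSIX only), not PySpark's daemon: a parent records its
# shutdown decision in a multiprocessing.Event before signalling the process
# group, so forked children never observe SIGHUP without the event being set.
import os
import signal
import time
import multiprocessing

should_exit = multiprocessing.Event()

def worker_loop():
    # Workers poll the shared event every 100 ms and exit once it is set.
    while not should_exit.is_set():
        time.sleep(0.1)
    os._exit(0)

if __name__ == '__main__':
    os.setpgid(0, 0)                                  # corral the children in our own process group
    signal.signal(signal.SIGHUP, lambda s, f: None)   # harmless handler, inherited across fork()
    for _ in range(2):
        if os.fork() == 0:
            worker_loop()
    should_exit.set()              # 1) record the decision where every child can see it
    os.kill(0, signal.SIGHUP)      # 2) notify the group, mirroring the daemon's SIGHUP broadcast
    for _ in range(2):
        os.wait()                  # reap both workers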
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 6a1962d267..1e34d47365 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -12,6 +12,7 @@ import unittest
 from pyspark.context import SparkContext
 from pyspark.files import SparkFiles
 from pyspark.java_gateway import SPARK_HOME
+from pyspark.serializers import read_int
 class PySparkTestCase(unittest.TestCase):
@@ -117,5 +118,47 @@ class TestIO(PySparkTestCase):
         self.sc.parallelize([1]).foreach(func)
+class TestDaemon(unittest.TestCase):
+    def connect(self, port):
+        from socket import socket, AF_INET, SOCK_STREAM
+        sock = socket(AF_INET, SOCK_STREAM)
+        sock.connect(('127.0.0.1', port))
+        # send a split index of -1 to shutdown the worker
+        sock.send("\xFF\xFF\xFF\xFF")
+        sock.close()
+        return True
+
+    def do_termination_test(self, terminator):
+        from subprocess import Popen, PIPE
+        from errno import ECONNREFUSED
+
+        # start daemon
+        daemon_path = os.path.join(os.path.dirname(__file__), "daemon.py")
+        daemon = Popen([sys.executable, daemon_path], stdin=PIPE, stdout=PIPE)
+
+        # read the port number
+        port = read_int(daemon.stdout)
+
+        # daemon should accept connections
+        self.assertTrue(self.connect(port))
+
+        # request shutdown
+        terminator(daemon)
+        time.sleep(1)
+
+        # daemon should no longer accept connections
+        with self.assertRaises(EnvironmentError) as trap:
+            self.connect(port)
+        self.assertEqual(trap.exception.errno, ECONNREFUSED)
+
+    def test_termination_stdin(self):
+        """Ensure that daemon and workers terminate when stdin is closed."""
+        self.do_termination_test(lambda daemon: daemon.stdin.close())
+
+    def test_termination_sigterm(self):
+        """Ensure that daemon and workers terminate on SIGTERM."""
+        from signal import SIGTERM
+        self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM))
+
 if __name__ == "__main__":
     unittest.main()
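The new TestDaemon cases drive daemon.py as a child process: read_int (imported from pyspark.serializers) pulls the daemon's listening port off its stdout, and connect() writes the four bytes \xFF\xFF\xFF\xFF, i.e. -1 as a 4-byte big-endian integer, which worker.py below treats as a shutdown sentinel. A small sketch of that framing, with read_int/write_int written out as stand-ins under the assumption of struct '>i' packing:

# Sketch of the framing the test relies on; these read_int/write_int are
# stand-ins for pyspark.serializers, assuming 4-byte big-endian integers.
import struct

def write_int(value, stream):
    stream.write(struct.pack(">i", value))

def read_int(stream):
    data = stream.read(4)
    if len(data) < 4:
        raise EOFError
    return struct.unpack(">i", data)[0]

# -1 encodes to the four 0xFF bytes the test sends to shut a worker down.
assert struct.pack(">i", -1) == b"\xff\xff\xff\xff"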
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 94d612ea6e..f76ee3c236 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -30,6 +30,8 @@ def report_times(outfile, boot, init, finish):
 def main(infile, outfile):
     boot_time = time.time()
     split_index = read_int(infile)
+    if split_index == -1:  # for unit tests
+        return
     spark_files_dir = load_pickle(read_with_length(infile))
     SparkFiles._root_directory = spark_files_dir
     SparkFiles._is_running_on_worker = True
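With this worker.py change, a worker that reads a split index of -1 returns before touching the rest of the protocol, which is what lets the test's connect() helper knock workers down one socket at a time. A hypothetical client-side sketch of that handshake (host, port, and function name are illustrative, not part of the patch):

# Hypothetical client: connect to a worker-pool port and send the -1
# split index so the forked worker returns immediately.
import socket
import struct

def shut_down_one_worker(port, host="127.0.0.1"):
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.connect((host, port))
        sock.sendall(struct.pack(">i", -1))   # worker's main() sees -1 and returns
    finally:
        sock.close()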