aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r--python/pyspark/tests.py43
1 files changed, 43 insertions, 0 deletions
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 6a1962d267..1e34d47365 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -12,6 +12,7 @@ import unittest
from pyspark.context import SparkContext
from pyspark.files import SparkFiles
from pyspark.java_gateway import SPARK_HOME
+from pyspark.serializers import read_int
class PySparkTestCase(unittest.TestCase):
@@ -117,5 +118,47 @@ class TestIO(PySparkTestCase):
self.sc.parallelize([1]).foreach(func)
+class TestDaemon(unittest.TestCase):
+ def connect(self, port):
+ from socket import socket, AF_INET, SOCK_STREAM
+ sock = socket(AF_INET, SOCK_STREAM)
+ sock.connect(('127.0.0.1', port))
+ # send a split index of -1 to shutdown the worker
+ sock.send("\xFF\xFF\xFF\xFF")
+ sock.close()
+ return True
+
+ def do_termination_test(self, terminator):
+ from subprocess import Popen, PIPE
+ from errno import ECONNREFUSED
+
+ # start daemon
+ daemon_path = os.path.join(os.path.dirname(__file__), "daemon.py")
+ daemon = Popen([sys.executable, daemon_path], stdin=PIPE, stdout=PIPE)
+
+ # read the port number
+ port = read_int(daemon.stdout)
+
+ # daemon should accept connections
+ self.assertTrue(self.connect(port))
+
+ # request shutdown
+ terminator(daemon)
+ time.sleep(1)
+
+ # daemon should no longer accept connections
+ with self.assertRaises(EnvironmentError) as trap:
+ self.connect(port)
+ self.assertEqual(trap.exception.errno, ECONNREFUSED)
+
+ def test_termination_stdin(self):
+ """Ensure that daemon and workers terminate when stdin is closed."""
+ self.do_termination_test(lambda daemon: daemon.stdin.close())
+
+ def test_termination_sigterm(self):
+ """Ensure that daemon and workers terminate on SIGTERM."""
+ from signal import SIGTERM
+ self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM))
+
if __name__ == "__main__":
unittest.main()