aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/tests.py
diff options
context:
space:
mode:
authorDavies Liu <davies.liu@gmail.com>2014-08-03 15:52:00 -0700
committerJosh Rosen <joshrosen@apache.org>2014-08-03 15:52:00 -0700
commit55349f9fe81ba5af5e4a5e4908ebf174e63c6cc9 (patch)
tree91277ef5bfed5d8d3177679ebfe52186350f430b /python/pyspark/tests.py
parente139e2be60ef23281327744e1b3e74904dfdf63f (diff)
downloadspark-55349f9fe81ba5af5e4a5e4908ebf174e63c6cc9.tar.gz
spark-55349f9fe81ba5af5e4a5e4908ebf174e63c6cc9.tar.bz2
spark-55349f9fe81ba5af5e4a5e4908ebf174e63c6cc9.zip
[SPARK-1740] [PySpark] kill the python worker
Kill only the python worker related to cancelled tasks. The daemon will start a background thread to monitor all the opened sockets for all workers. If the socket is closed by the JVM, this thread will kill the worker. When a task is cancelled, the socket to the worker will be closed, then the worker will be killed by the daemon. Author: Davies Liu <davies.liu@gmail.com> Closes #1643 from davies/kill and squashes the following commits: 8ffe9f3 [Davies Liu] kill worker by deamon, because runtime.exec() is too heavy 46ca150 [Davies Liu] address comment acd751c [Davies Liu] kill the worker when task is canceled
Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r--python/pyspark/tests.py51
1 files changed, 51 insertions, 0 deletions
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 16fb5a9256..acc3c30371 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -790,6 +790,57 @@ class TestDaemon(unittest.TestCase):
self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM))
class TestWorker(PySparkTestCase):
    """Regression tests for [SPARK-1740]: cancelled tasks must kill only
    the Python worker process that runs them, while the daemon survives.
    """

    def test_cancel_task(self):
        """Cancel a running job and verify the worker process dies but the
        daemon process stays alive.

        The sleeping task writes "<daemon_pid> <worker_pid>" to a temp file
        so the test can probe both processes with ``os.kill(pid, 0)``
        (signal 0 performs an existence check without sending a signal).
        """
        # NamedTemporaryFile is used only to obtain a unique path; the
        # worker task re-creates the file itself.
        temp = tempfile.NamedTemporaryFile(delete=True)
        temp.close()
        path = temp.name

        def sleep(x):
            import os, time
            with open(path, 'w') as f:
                f.write("%d %d" % (os.getppid(), os.getpid()))
            time.sleep(100)

        # start job in background thread so this test thread can cancel it
        def run():
            self.sc.parallelize(range(1)).foreach(sleep)
        import threading
        t = threading.Thread(target=run)
        t.daemon = True
        t.start()

        # Wait (bounded) for the worker to report its pids. Previously this
        # was `while True`, which could hang the whole suite forever if the
        # job failed to start. Also guard against reading the file between
        # its creation (open(..., 'w') truncates/creates immediately) and
        # the pid write — require both fields to be present.
        daemon_pid, worker_pid = 0, 0
        for _ in range(100):  # at most 10 seconds
            if os.path.exists(path):
                with open(path) as f:  # `with` avoids leaking the fd
                    data = f.read().split(' ')
                if len(data) == 2:
                    daemon_pid, worker_pid = map(int, data)
                    break
            time.sleep(0.1)
        else:
            self.fail("worker did not report its pids within 10 seconds")

        # cancel jobs; the JVM closes the worker's socket, and the daemon's
        # monitor thread is expected to kill the worker in response
        self.sc.cancelAllJobs()
        t.join()

        for i in range(50):
            try:
                os.kill(worker_pid, 0)  # still alive: keep polling
                time.sleep(0.1)
            except OSError:
                break  # worker was killed
        else:
            self.fail("worker has not been killed after 5 seconds")

        # the daemon itself must NOT have been killed
        try:
            os.kill(daemon_pid, 0)
        except OSError:
            self.fail("daemon had been killed")

    def test_fd_leak(self):
        """Running many tiny tasks must not exhaust file descriptors."""
        N = 1100  # fd limit is 1024 by default
        rdd = self.sc.parallelize(range(N), N)
        # assertEqual: assertEquals is a deprecated alias
        self.assertEqual(N, rdd.count())
+
class TestSparkSubmit(unittest.TestCase):
def setUp(self):
self.programDir = tempfile.mkdtemp()