Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/daemon.py  24
-rw-r--r--  python/pyspark/tests.py   51
2 files changed, 69 insertions(+), 6 deletions(-)
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index 9fde0dde0f..b00da833d0 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -26,7 +26,7 @@ from errno import EINTR, ECHILD
 from socket import AF_INET, SOCK_STREAM, SOMAXCONN
 from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
 from pyspark.worker import main as worker_main
-from pyspark.serializers import write_int
+from pyspark.serializers import read_int, write_int
 
 
 def compute_real_exit_code(exit_code):
@@ -67,7 +67,8 @@ def worker(sock):
     outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
     exit_code = 0
     try:
-        write_int(0, outfile)  # Acknowledge that the fork was successful
+        # Acknowledge that the fork was successful
+        write_int(os.getpid(), outfile)
         outfile.flush()
         worker_main(infile, outfile)
     except SystemExit as exc:
@@ -125,14 +126,23 @@ def manager():
                 else:
                     raise
             if 0 in ready_fds:
-                # Spark told us to exit by closing stdin
-                shutdown(0)
+                try:
+                    worker_pid = read_int(sys.stdin)
+                except EOFError:
+                    # Spark told us to exit by closing stdin
+                    shutdown(0)
+                try:
+                    os.kill(worker_pid, signal.SIGKILL)
+                except OSError:
+                    pass  # process already died
+
+
             if listen_sock in ready_fds:
                 sock, addr = listen_sock.accept()
                 # Launch a worker process
                 try:
-                    fork_return_code = os.fork()
-                    if fork_return_code == 0:
+                    pid = os.fork()
+                    if pid == 0:
                         listen_sock.close()
                         try:
                             worker(sock)
@@ -143,11 +153,13 @@ def manager():
                             os._exit(0)
                 else:
                     sock.close()
+
             except OSError as e:
                 print >> sys.stderr, "Daemon failed to fork PySpark worker: %s" % e
                 outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
                 write_int(-1, outfile)  # Signal that the fork failed
                 outfile.flush()
+                outfile.close()
                 sock.close()
     finally:
         shutdown(1)
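
The daemon.py half of this change does two things: a freshly forked worker now acknowledges the fork by writing back its own PID instead of a constant 0, and the manager loop now treats its stdin as a kill channel: each 32-bit integer it reads with read_int names a worker to SIGKILL, while EOF on stdin still means "shut down". Below is a minimal sketch of the other end of that protocol, assuming big-endian "!i" packing (the convention pyspark.serializers uses); daemon_stdin, kill_worker, and stop_daemon are hypothetical names, not part of this patch:

    import struct

    def kill_worker(daemon_stdin, worker_pid):
        # The daemon's read_int() consumes a big-endian 32-bit integer and
        # then SIGKILLs that PID (ignoring OSError if it already exited).
        daemon_stdin.write(struct.pack("!i", worker_pid))
        daemon_stdin.flush()

    def stop_daemon(daemon_stdin):
        # Closing stdin makes read_int() raise EOFError, which the manager
        # loop turns into shutdown(0).
        daemon_stdin.close()

The added outfile.close() in the fork-failure path also plugs a file-descriptor leak, which the new test_fd_leak below exercises.
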
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 16fb5a9256..acc3c30371 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -790,6 +790,57 @@ class TestDaemon(unittest.TestCase):
         self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM))
 
 
+class TestWorker(PySparkTestCase):
+    def test_cancel_task(self):
+        temp = tempfile.NamedTemporaryFile(delete=True)
+        temp.close()
+        path = temp.name
+        def sleep(x):
+            import os, time
+            with open(path, 'w') as f:
+                f.write("%d %d" % (os.getppid(), os.getpid()))
+            time.sleep(100)
+
+        # start job in background thread
+        def run():
+            self.sc.parallelize(range(1)).foreach(sleep)
+        import threading
+        t = threading.Thread(target=run)
+        t.daemon = True
+        t.start()
+
+        daemon_pid, worker_pid = 0, 0
+        while True:
+            if os.path.exists(path):
+                data = open(path).read().split(' ')
+                daemon_pid, worker_pid = map(int, data)
+                break
+            time.sleep(0.1)
+
+        # cancel jobs
+        self.sc.cancelAllJobs()
+        t.join()
+
+        for i in range(50):
+            try:
+                os.kill(worker_pid, 0)
+                time.sleep(0.1)
+            except OSError:
+                break  # worker was killed
+        else:
+            self.fail("worker has not been killed after 5 seconds")
+
+        try:
+            os.kill(daemon_pid, 0)
+        except OSError:
+            self.fail("daemon had been killed")
+
+    def test_fd_leak(self):
+        N = 1100  # fd limit is 1024 by default
+        rdd = self.sc.parallelize(range(N), N)
+        self.assertEquals(N, rdd.count())
+
+
class TestSparkSubmit(unittest.TestCase):
def setUp(self):
self.programDir = tempfile.mkdtemp()
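
Both new tests lean on the POSIX convention that os.kill(pid, 0) delivers no signal and only checks that the PID exists, raising OSError once the process is gone. A sketch of that polling pattern as a reusable helper; wait_for_death is a hypothetical name and, like the test, it does not distinguish the EPERM case (a live process owned by another user) from a dead one:

    import os
    import time

    def wait_for_death(pid, timeout=5.0, interval=0.1):
        # Poll with signal 0 until the PID disappears or the timeout expires.
        deadline = time.time() + timeout
        while time.time() < deadline:
            try:
                os.kill(pid, 0)  # existence check only; nothing is delivered
            except OSError:
                return True  # process is gone
            time.sleep(interval)
        return False

test_cancel_task inlines exactly this loop (50 iterations of 0.1 s, hence the 5-second failure message) for the worker, then inverts the check to assert the daemon itself survived the cancellation; test_fd_leak runs 1100 one-element partitions so that a leaked descriptor per task would exhaust a default 1024 fd limit.
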