aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/daemon.py5
-rw-r--r--python/pyspark/serializers.py1
-rw-r--r--python/pyspark/tests.py19
-rw-r--r--python/pyspark/worker.py11
4 files changed, 32 insertions, 4 deletions
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index 64d6202acb..dbb34775d9 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -26,7 +26,7 @@ import time
import gc
from errno import EINTR, ECHILD, EAGAIN
from socket import AF_INET, SOCK_STREAM, SOMAXCONN
-from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
+from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT
from pyspark.worker import main as worker_main
from pyspark.serializers import read_int, write_int
@@ -46,6 +46,9 @@ def worker(sock):
signal.signal(SIGHUP, SIG_DFL)
signal.signal(SIGCHLD, SIG_DFL)
signal.signal(SIGTERM, SIG_DFL)
+ # restore the handler for SIGINT,
+ # it's useful for debugging (show the stacktrace before exit)
+ signal.signal(SIGINT, signal.default_int_handler)
# Read the socket using fdopen instead of socket.makefile() because the latter
# seems to be very slow; note that we need to dup() the file descriptor because
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 08a0f0d8ff..904bd9f265 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -80,6 +80,7 @@ class SpecialLengths(object):
END_OF_DATA_SECTION = -1
PYTHON_EXCEPTION_THROWN = -2
TIMING_DATA = -3
+ END_OF_STREAM = -4
class Serializer(object):
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 1a8e4150e6..7a2107ec32 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -31,7 +31,7 @@ import tempfile
import time
import zipfile
import random
-from platform import python_implementation
+import threading
if sys.version_info[:2] <= (2, 6):
try:
@@ -1380,6 +1380,23 @@ class WorkerTests(PySparkTestCase):
self.assertEqual(sum(range(100)), acc2.value)
self.assertEqual(sum(range(100)), acc1.value)
+ def test_reuse_worker_after_take(self):
+ rdd = self.sc.parallelize(range(100000), 1)
+ self.assertEqual(0, rdd.first())
+
+ def count():
+ try:
+ rdd.count()
+ except Exception:
+ pass
+
+ t = threading.Thread(target=count)
+ t.daemon = True
+ t.start()
+ t.join(5)
+ self.assertTrue(not t.isAlive())
+ self.assertEqual(100000, rdd.count())
+
class SparkSubmitTests(unittest.TestCase):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 8257dddfee..2bdccb5e93 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -57,7 +57,7 @@ def main(infile, outfile):
boot_time = time.time()
split_index = read_int(infile)
if split_index == -1: # for unit tests
- return
+ exit(-1)
# initialize global state
shuffle.MemoryBytesSpilled = 0
@@ -111,7 +111,6 @@ def main(infile, outfile):
try:
write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
write_with_length(traceback.format_exc(), outfile)
- outfile.flush()
except IOError:
# JVM close the socket
pass
@@ -131,6 +130,14 @@ def main(infile, outfile):
for (aid, accum) in _accumulatorRegistry.items():
pickleSer._write_with_length((aid, accum._value), outfile)
+ # check end of stream
+ if read_int(infile) == SpecialLengths.END_OF_STREAM:
+ write_int(SpecialLengths.END_OF_STREAM, outfile)
+ else:
+ # write a different value to tell JVM to not reuse this worker
+ write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
+ exit(-1)
+
if __name__ == '__main__':
# Read a local port to connect to from stdin