diff options
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/daemon.py | 5
-rw-r--r-- | python/pyspark/serializers.py | 1
-rw-r--r-- | python/pyspark/tests.py | 19
-rw-r--r-- | python/pyspark/worker.py | 11
4 files changed, 32 insertions, 4 deletions
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 64d6202acb..dbb34775d9 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -26,7 +26,7 @@ import time import gc from errno import EINTR, ECHILD, EAGAIN from socket import AF_INET, SOCK_STREAM, SOMAXCONN -from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN +from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT from pyspark.worker import main as worker_main from pyspark.serializers import read_int, write_int @@ -46,6 +46,9 @@ def worker(sock): signal.signal(SIGHUP, SIG_DFL) signal.signal(SIGCHLD, SIG_DFL) signal.signal(SIGTERM, SIG_DFL) + # restore the handler for SIGINT, + # it's useful for debugging (show the stacktrace before exit) + signal.signal(SIGINT, signal.default_int_handler) # Read the socket using fdopen instead of socket.makefile() because the latter # seems to be very slow; note that we need to dup() the file descriptor because diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 08a0f0d8ff..904bd9f265 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -80,6 +80,7 @@ class SpecialLengths(object): END_OF_DATA_SECTION = -1 PYTHON_EXCEPTION_THROWN = -2 TIMING_DATA = -3 + END_OF_STREAM = -4 class Serializer(object): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 1a8e4150e6..7a2107ec32 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -31,7 +31,7 @@ import tempfile import time import zipfile import random -from platform import python_implementation +import threading if sys.version_info[:2] <= (2, 6): try: @@ -1380,6 +1380,23 @@ class WorkerTests(PySparkTestCase): self.assertEqual(sum(range(100)), acc2.value) self.assertEqual(sum(range(100)), acc1.value) + def test_reuse_worker_after_take(self): + rdd = self.sc.parallelize(range(100000), 1) + self.assertEqual(0, rdd.first()) + + def count(): + try: + rdd.count() + except Exception: + pass + + t = threading.Thread(target=count) + t.daemon = True + t.start() + t.join(5) + self.assertTrue(not t.isAlive()) + self.assertEqual(100000, rdd.count()) + class SparkSubmitTests(unittest.TestCase):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 8257dddfee..2bdccb5e93 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -57,7 +57,7 @@ def main(infile, outfile): boot_time = time.time() split_index = read_int(infile) if split_index == -1: # for unit tests - return + exit(-1) # initialize global state shuffle.MemoryBytesSpilled = 0 @@ -111,7 +111,6 @@ def main(infile, outfile): try: write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile) write_with_length(traceback.format_exc(), outfile) - outfile.flush() except IOError: # JVM close the socket pass @@ -131,6 +130,14 @@ def main(infile, outfile): for (aid, accum) in _accumulatorRegistry.items(): pickleSer._write_with_length((aid, accum._value), outfile) + # check end of stream + if read_int(infile) == SpecialLengths.END_OF_STREAM: + write_int(SpecialLengths.END_OF_STREAM, outfile) + else: + # write a different value to tell JVM to not reuse this worker + write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) + exit(-1) + if __name__ == '__main__': # Read a local port to connect to from stdin