aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorDavies Liu <davies.liu@gmail.com>2014-10-23 17:20:00 -0700
committerJosh Rosen <joshrosen@databricks.com>2014-10-23 17:20:00 -0700
commite595c8d08a20a122295af62d5e9cc4116f9727f6 (patch)
treeec0226aecad30372b9ece27e534f4482c24c94bf /python
parent83b7a1c6503adce1826fc537b4db47e534da5cae (diff)
downloadspark-e595c8d08a20a122295af62d5e9cc4116f9727f6.tar.gz
spark-e595c8d08a20a122295af62d5e9cc4116f9727f6.tar.bz2
spark-e595c8d08a20a122295af62d5e9cc4116f9727f6.zip
[SPARK-3993] [PySpark] fix bug while reuse worker after take()
After take(), maybe there are some garbage left in the socket, then next task assigned to this worker will hang because of corrupted data. We should make sure the socket is clean before reuse it, write END_OF_STREAM at the end, and check it after read out all result from python. Author: Davies Liu <davies.liu@gmail.com> Author: Davies Liu <davies@databricks.com> Closes #2838 from davies/fix_reuse and squashes the following commits: 8872914 [Davies Liu] fix tests 660875b [Davies Liu] fix bug while reuse worker after take()
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/daemon.py5
-rw-r--r--python/pyspark/serializers.py1
-rw-r--r--python/pyspark/tests.py19
-rw-r--r--python/pyspark/worker.py11
4 files changed, 32 insertions, 4 deletions
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index 64d6202acb..dbb34775d9 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -26,7 +26,7 @@ import time
import gc
from errno import EINTR, ECHILD, EAGAIN
from socket import AF_INET, SOCK_STREAM, SOMAXCONN
-from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
+from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT
from pyspark.worker import main as worker_main
from pyspark.serializers import read_int, write_int
@@ -46,6 +46,9 @@ def worker(sock):
signal.signal(SIGHUP, SIG_DFL)
signal.signal(SIGCHLD, SIG_DFL)
signal.signal(SIGTERM, SIG_DFL)
+ # restore the handler for SIGINT,
+ # it's useful for debugging (show the stacktrace before exit)
+ signal.signal(SIGINT, signal.default_int_handler)
# Read the socket using fdopen instead of socket.makefile() because the latter
# seems to be very slow; note that we need to dup() the file descriptor because
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 08a0f0d8ff..904bd9f265 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -80,6 +80,7 @@ class SpecialLengths(object):
END_OF_DATA_SECTION = -1
PYTHON_EXCEPTION_THROWN = -2
TIMING_DATA = -3
+ END_OF_STREAM = -4
class Serializer(object):
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 1a8e4150e6..7a2107ec32 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -31,7 +31,7 @@ import tempfile
import time
import zipfile
import random
-from platform import python_implementation
+import threading
if sys.version_info[:2] <= (2, 6):
try:
@@ -1380,6 +1380,23 @@ class WorkerTests(PySparkTestCase):
self.assertEqual(sum(range(100)), acc2.value)
self.assertEqual(sum(range(100)), acc1.value)
+ def test_reuse_worker_after_take(self):
+ rdd = self.sc.parallelize(range(100000), 1)
+ self.assertEqual(0, rdd.first())
+
+ def count():
+ try:
+ rdd.count()
+ except Exception:
+ pass
+
+ t = threading.Thread(target=count)
+ t.daemon = True
+ t.start()
+ t.join(5)
+ self.assertTrue(not t.isAlive())
+ self.assertEqual(100000, rdd.count())
+
class SparkSubmitTests(unittest.TestCase):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 8257dddfee..2bdccb5e93 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -57,7 +57,7 @@ def main(infile, outfile):
boot_time = time.time()
split_index = read_int(infile)
if split_index == -1: # for unit tests
- return
+ exit(-1)
# initialize global state
shuffle.MemoryBytesSpilled = 0
@@ -111,7 +111,6 @@ def main(infile, outfile):
try:
write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
write_with_length(traceback.format_exc(), outfile)
- outfile.flush()
except IOError:
# JVM close the socket
pass
@@ -131,6 +130,14 @@ def main(infile, outfile):
for (aid, accum) in _accumulatorRegistry.items():
pickleSer._write_with_length((aid, accum._value), outfile)
+ # check end of stream
+ if read_int(infile) == SpecialLengths.END_OF_STREAM:
+ write_int(SpecialLengths.END_OF_STREAM, outfile)
+ else:
+ # write a different value to tell JVM to not reuse this worker
+ write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
+ exit(-1)
+
if __name__ == '__main__':
# Read a local port to connect to from stdin