Diffstat (limited to 'python/pyspark')
-rw-r--r--  python/pyspark/daemon.py         38
-rw-r--r--  python/pyspark/mllib/_common.py  12
-rw-r--r--  python/pyspark/serializers.py     4
-rw-r--r--  python/pyspark/tests.py          35
-rw-r--r--  python/pyspark/worker.py          9
5 files changed, 70 insertions, 28 deletions
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index 15445abf67..64d6202acb 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -23,6 +23,7 @@ import socket
 import sys
 import traceback
 import time
+import gc
 from errno import EINTR, ECHILD, EAGAIN
 from socket import AF_INET, SOCK_STREAM, SOMAXCONN
 from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
@@ -46,17 +47,6 @@ def worker(sock):
     signal.signal(SIGCHLD, SIG_DFL)
     signal.signal(SIGTERM, SIG_DFL)
-    # Blocks until the socket is closed by draining the input stream
-    # until it raises an exception or returns EOF.
-    def waitSocketClose(sock):
-        try:
-            while True:
-                # Empty string is returned upon EOF (and only then).
-                if sock.recv(4096) == '':
-                    return
-        except:
-            pass
-
     # Read the socket using fdopen instead of socket.makefile() because the latter
     # seems to be very slow; note that we need to dup() the file descriptor because
     # otherwise writes also cause a seek that makes us miss data on the read side.
@@ -64,17 +54,13 @@ def worker(sock):
     outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
     exit_code = 0
     try:
-        # Acknowledge that the fork was successful
-        write_int(os.getpid(), outfile)
-        outfile.flush()
         worker_main(infile, outfile)
     except SystemExit as exc:
-        exit_code = exc.code
+        exit_code = compute_real_exit_code(exc.code)
     finally:
         outfile.flush()
-        # The Scala side will close the socket upon task completion.
-        waitSocketClose(sock)
-        os._exit(compute_real_exit_code(exit_code))
+        if exit_code:
+            os._exit(exit_code)
# Cleanup zombie children
@@ -111,6 +97,8 @@ def manager():
     signal.signal(SIGTERM, handle_sigterm)  # Gracefully exit on SIGTERM
     signal.signal(SIGHUP, SIG_IGN)  # Don't die on SIGHUP
+    reuse = os.environ.get("SPARK_REUSE_WORKER")
+
     # Initialization complete
     try:
         while True:
@@ -163,7 +151,19 @@ def manager():
                     # in child process
                     listen_sock.close()
                     try:
-                        worker(sock)
+                        # Acknowledge that the fork was successful
+                        outfile = sock.makefile("w")
+                        write_int(os.getpid(), outfile)
+                        outfile.flush()
+                        outfile.close()
+                        while True:
+                            worker(sock)
+                            if not reuse:
+                                # wait for closing
+                                while sock.recv(1024):
+                                    pass
+                                break
+                            gc.collect()
                     except:
                         traceback.print_exc()
                         os._exit(1)
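
Note on the daemon.py change above: the fork acknowledgement moves from worker() into the child branch of manager(), and worker(sock) now runs inside a loop so one forked process can serve many tasks when SPARK_REUSE_WORKER is set, with gc.collect() reclaiming memory between tasks. Below is a minimal, self-contained sketch of that reuse loop, not Spark's real protocol: handle_one_task and the echo wire format are made up for illustration, and only the control flow (loop, SPARK_REUSE_WORKER check, drain-then-exit, gc.collect) mirrors the patch.

import gc
import os
import socket


def handle_one_task(sock):
    # Stand-in for worker_main(): echo one request back in upper case.
    data = sock.recv(1024)
    if not data:              # peer closed its end: nothing more to do
        raise SystemExit(0)
    sock.sendall(data.upper())


def worker_loop(sock):
    reuse = os.environ.get("SPARK_REUSE_WORKER")
    try:
        while True:
            handle_one_task(sock)
            if not reuse:
                # wait for the peer to close before exiting, as in the patch
                while sock.recv(1024):
                    pass
                break
            gc.collect()      # reclaim memory between tasks when reused
    except SystemExit:
        pass


if __name__ == "__main__":
    parent, child = socket.socketpair()
    parent.sendall(b"ping")
    parent.shutdown(socket.SHUT_WR)   # no more requests from this side
    worker_loop(child)
    print(parent.recv(1024))          # b'PING'
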
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py
index bb60d3d0c8..68f6033616 100644
--- a/python/pyspark/mllib/_common.py
+++ b/python/pyspark/mllib/_common.py
@@ -21,7 +21,7 @@ import numpy
 from numpy import ndarray, float64, int64, int32, array_equal, array
 from pyspark import SparkContext, RDD
 from pyspark.mllib.linalg import SparseVector
-from pyspark.serializers import Serializer
+from pyspark.serializers import FramedSerializer
 """
@@ -451,18 +451,16 @@ def _serialize_rating(r):
     return ba
-class RatingDeserializer(Serializer):
+class RatingDeserializer(FramedSerializer):
-    def loads(self, stream):
-        length = struct.unpack("!i", stream.read(4))[0]
-        ba = stream.read(length)
-        res = ndarray(shape=(3, ), buffer=ba, dtype=float64, offset=4)
+    def loads(self, string):
+        res = ndarray(shape=(3, ), buffer=string, dtype=float64, offset=4)
         return int(res[0]), int(res[1]), res[2]
     def load_stream(self, stream):
         while True:
             try:
-                yield self.loads(stream)
+                yield self._read_with_length(stream)
             except struct.error:
                 return
             except EOFError:
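
The _common.py change above moves RatingDeserializer onto FramedSerializer, so the framing (length prefix, then payload) lives in _read_with_length and loads() only decodes one complete byte string. A rough sketch of that division of labour, using a simplified frame reader rather than PySpark's actual FramedSerializer; ToyFramedDeserializer and PairDeserializer are illustrative names, and the real class also handles special negative lengths that this toy version ignores.

import struct
from io import BytesIO


class ToyFramedDeserializer(object):
    def _read_with_length(self, stream):
        header = stream.read(4)
        if len(header) < 4:
            raise EOFError
        (length,) = struct.unpack("!i", header)
        obj = stream.read(length)
        if len(obj) < length:
            raise EOFError
        return self.loads(obj)

    def load_stream(self, stream):
        while True:
            try:
                yield self._read_with_length(stream)
            except EOFError:
                return

    def loads(self, string):
        # Subclasses decode one complete, already-framed record here.
        raise NotImplementedError


class PairDeserializer(ToyFramedDeserializer):
    def loads(self, string):
        return struct.unpack("!ii", string)


if __name__ == "__main__":
    buf = BytesIO(struct.pack("!i", 8) + struct.pack("!ii", 1, 2))
    print(list(PairDeserializer().load_stream(buf)))  # [(1, 2)]
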
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index a5f9341e81..ec3c6f0554 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -144,6 +144,8 @@ class FramedSerializer(Serializer):
     def _read_with_length(self, stream):
         length = read_int(stream)
+        if length == SpecialLengths.END_OF_DATA_SECTION:
+            raise EOFError
         obj = stream.read(length)
         if obj == "":
             raise EOFError
@@ -438,6 +440,8 @@ class UTF8Deserializer(Serializer):
     def loads(self, stream):
         length = read_int(stream)
+        if length == SpecialLengths.END_OF_DATA_SECTION:
+            raise EOFError
         s = stream.read(length)
         return s.decode("utf-8") if self.use_unicode else s
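
Both framed readers in serializers.py now map the END_OF_DATA_SECTION marker to EOFError instead of attempting stream.read() with a negative length, which is what lets a reused worker stop cleanly at the end of one task's data. A sketch of the idea with a made-up marker value; the real constant is defined in pyspark.serializers.SpecialLengths, and read_framed below is an illustrative helper, not PySpark code.

import struct
from io import BytesIO

END_OF_DATA_SECTION = -1  # hypothetical stand-in for SpecialLengths.END_OF_DATA_SECTION


def read_int(stream):
    data = stream.read(4)
    if len(data) < 4:
        raise EOFError
    return struct.unpack("!i", data)[0]


def read_framed(stream):
    length = read_int(stream)
    if length == END_OF_DATA_SECTION:
        # The writer signals "no more records" with a negative length
        # instead of closing the stream, so surface it as EOF.
        raise EOFError
    obj = stream.read(length)
    if obj == b"":
        raise EOFError
    return obj


if __name__ == "__main__":
    buf = BytesIO(struct.pack("!i", 5) + b"hello" + struct.pack("!i", END_OF_DATA_SECTION))
    print(read_framed(buf))          # b'hello'
    try:
        read_framed(buf)
    except EOFError:
        print("end of data section")
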
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index b687d695b0..747cd1767d 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -1222,11 +1222,46 @@ class TestWorker(PySparkTestCase):
         except OSError:
             self.fail("daemon had been killed")
+        # run a normal job
+        rdd = self.sc.parallelize(range(100), 1)
+        self.assertEqual(100, rdd.map(str).count())
+
     def test_fd_leak(self):
         N = 1100  # fd limit is 1024 by default
         rdd = self.sc.parallelize(range(N), N)
         self.assertEquals(N, rdd.count())
+    def test_after_exception(self):
+        def raise_exception(_):
+            raise Exception()
+        rdd = self.sc.parallelize(range(100), 1)
+        self.assertRaises(Exception, lambda: rdd.foreach(raise_exception))
+        self.assertEqual(100, rdd.map(str).count())
+
+    def test_after_jvm_exception(self):
+        tempFile = tempfile.NamedTemporaryFile(delete=False)
+        tempFile.write("Hello World!")
+        tempFile.close()
+        data = self.sc.textFile(tempFile.name, 1)
+        filtered_data = data.filter(lambda x: True)
+        self.assertEqual(1, filtered_data.count())
+        os.unlink(tempFile.name)
+        self.assertRaises(Exception, lambda: filtered_data.count())
+
+        rdd = self.sc.parallelize(range(100), 1)
+        self.assertEqual(100, rdd.map(str).count())
+
+    def test_accumulator_when_reuse_worker(self):
+        from pyspark.accumulators import INT_ACCUMULATOR_PARAM
+        acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM)
+        self.sc.parallelize(range(100), 20).foreach(lambda x: acc1.add(x))
+        self.assertEqual(sum(range(100)), acc1.value)
+
+        acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM)
+        self.sc.parallelize(range(100), 20).foreach(lambda x: acc2.add(x))
+        self.assertEqual(sum(range(100)), acc2.value)
+        self.assertEqual(sum(range(100)), acc1.value)
+
class TestSparkSubmit(unittest.TestCase):
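
The new tests above cover the reuse path end to end: a job after a Python exception, a job after a JVM-side failure on a deleted input file, and accumulator isolation across two jobs on reused workers. To run just those cases, something along these lines should work, assuming the usual PySpark layout with python/ on sys.path and a working local Spark; the dotted test names below are an assumption about where the tests live.

import unittest

if __name__ == "__main__":
    suite = unittest.TestLoader().loadTestsFromNames([
        "pyspark.tests.TestWorker.test_after_exception",
        "pyspark.tests.TestWorker.test_after_jvm_exception",
        "pyspark.tests.TestWorker.test_accumulator_when_reuse_worker",
    ])
    unittest.TextTestRunner(verbosity=2).run(suite)
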
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 6805063e06..61b8a74d06 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -69,9 +69,14 @@ def main(infile, outfile):
         ser = CompressedSerializer(pickleSer)
         for _ in range(num_broadcast_variables):
             bid = read_long(infile)
-            value = ser._read_with_length(infile)
-            _broadcastRegistry[bid] = Broadcast(bid, value)
+            if bid >= 0:
+                value = ser._read_with_length(infile)
+                _broadcastRegistry[bid] = Broadcast(bid, value)
+            else:
+                bid = - bid - 1
+                _broadcastRegistry.remove(bid)
+        _accumulatorRegistry.clear()
         command = pickleSer._read_with_length(infile)
         (func, deserializer, serializer) = command
         init_time = time.time()
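
The worker.py change above multiplexes broadcast additions and removals over the same id stream: a non-negative bid is followed by a serialized value to register, while a removal is encoded as -id - 1 so that id 0 remains representable; _accumulatorRegistry.clear() then resets per-task accumulator state on a reused worker. A small sketch of that signed-id encoding, with a plain dict standing in for pyspark.broadcast._broadcastRegistry; encode_remove and apply_update are illustrative helpers, not PySpark APIs.

registry = {}


def encode_remove(bid):
    # A removal of broadcast `bid` is sent as -bid - 1, so that removing
    # id 0 (which would otherwise collide with "add id 0") stays representable.
    return -bid - 1


def apply_update(bid, value=None):
    if bid >= 0:
        registry[bid] = value          # add or refresh a broadcast value
    else:
        registry.pop(-bid - 1, None)   # decode and drop the broadcast


if __name__ == "__main__":
    apply_update(0, "zero")
    apply_update(7, "seven")
    apply_update(encode_remove(0))     # encodes as -1, removes id 0
    print(sorted(registry))            # [7]
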