 core/src/main/scala/org/apache/spark/SparkEnv.scala             |  2
 core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala | 11
 python/pyspark/daemon.py                                        |  5
 python/pyspark/serializers.py                                   |  1
 python/pyspark/tests.py                                         | 19
 python/pyspark/worker.py                                        | 11
 6 files changed, 44 insertions(+), 5 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index aba713cb42..906a00b0bd 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -68,6 +68,7 @@ class SparkEnv (
val shuffleMemoryManager: ShuffleMemoryManager,
val conf: SparkConf) extends Logging {
+ private[spark] var isStopped = false
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
// A general, soft-reference map for metadata needed during HadoopRDD split computation
@@ -75,6 +76,7 @@ class SparkEnv (
private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]()
private[spark] def stop() {
+ isStopped = true
pythonWorkers.foreach { case(key, worker) => worker.stop() }
Option(httpFileServer).foreach(_.stop())
mapOutputTracker.stop()
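
Note: the new isStopped flag exists so the PythonRDD reader can tell shutdown-time failures apart from real ones (see the `case e: Exception if env.isStopped` handler further down). A minimal Python analogue of that pattern, with made-up names (Env, read_result), not Spark code:

    import socket

    class Env(object):
        """Holds a shared shutdown flag, analogous to SparkEnv.isStopped."""
        def __init__(self):
            self.is_stopped = False

        def stop(self):
            self.is_stopped = True

    def read_result(env, sock):
        try:
            return sock.recv(4)
        except socket.error:
            if env.is_stopped:
                return None  # expected during shutdown: exit silently
            raise            # a real failure: propagate it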
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 29ca751519..163dca6cad 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -75,6 +75,7 @@ private[spark] class PythonRDD(
var complete_cleanly = false
context.addTaskCompletionListener { context =>
writerThread.shutdownOnTaskCompletion()
+ writerThread.join()
if (reuse_worker && complete_cleanly) {
env.releasePythonWorker(pythonExec, envVars.toMap, worker)
} else {
@@ -145,7 +146,9 @@ private[spark] class PythonRDD(
stream.readFully(update)
accumulator += Collections.singletonList(update)
}
- complete_cleanly = true
+ if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
+ complete_cleanly = true
+ }
null
}
} catch {
@@ -154,6 +157,10 @@ private[spark] class PythonRDD(
logDebug("Exception thrown after task interruption", e)
throw new TaskKilledException
+ case e: Exception if env.isStopped =>
+ logDebug("Exception thrown after context is stopped", e)
+ null // exit silently
+
case e: Exception if writerThread.exception.isDefined =>
logError("Python worker exited unexpectedly (crashed)", e)
logError("This may have been caused by a prior exception:", writerThread.exception.get)
@@ -235,6 +242,7 @@ private[spark] class PythonRDD(
// Data values
PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut)
dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+ dataOut.writeInt(SpecialLengths.END_OF_STREAM)
dataOut.flush()
} catch {
case e: Exception if context.isCompleted || context.isInterrupted =>
@@ -306,6 +314,7 @@ private object SpecialLengths {
val END_OF_DATA_SECTION = -1
val PYTHON_EXCEPTION_THROWN = -2
val TIMING_DATA = -3
+ val END_OF_STREAM = -4
}
private[spark] object PythonRDD extends Logging {
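
The reader-side change above boils down to one rule: a worker is only handed back to the pool when its output stream ends with END_OF_STREAM. A rough Python sketch of that decision (finish_task, release_worker and destroy_worker are illustrative names, not Spark APIs):

    END_OF_STREAM = -4

    def finish_task(read_int, release_worker, destroy_worker):
        # read_int returns the next 4-byte signed int from the worker's stream
        complete_cleanly = (read_int() == END_OF_STREAM)
        if complete_cleanly:
            release_worker()  # both sides agree the stream is drained; reuse
        else:
            destroy_worker()  # stream is out of sync; do not reuse this worker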
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index 64d6202acb..dbb34775d9 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -26,7 +26,7 @@ import time
import gc
from errno import EINTR, ECHILD, EAGAIN
from socket import AF_INET, SOCK_STREAM, SOMAXCONN
-from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
+from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT
from pyspark.worker import main as worker_main
from pyspark.serializers import read_int, write_int
@@ -46,6 +46,9 @@ def worker(sock):
signal.signal(SIGHUP, SIG_DFL)
signal.signal(SIGCHLD, SIG_DFL)
signal.signal(SIGTERM, SIG_DFL)
+ # restore the handler for SIGINT,
+ # it's useful for debugging (show the stacktrace before exit)
+ signal.signal(SIGINT, signal.default_int_handler)
# Read the socket using fdopen instead of socket.makefile() because the latter
# seems to be very slow; note that we need to dup() the file descriptor because
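
For context, forked workers inherit the daemon's signal dispositions, where SIGINT may be ignored; restoring Python's default handler makes an interrupt raise KeyboardInterrupt and print a stack trace of whatever the worker was doing. A standalone sketch of the idea:

    import signal

    # A daemon process typically ignores SIGINT so a terminal Ctrl-C does not kill it:
    signal.signal(signal.SIGINT, signal.SIG_IGN)

    # ...after fork(), the child restores the default handler so an interrupt
    # surfaces as KeyboardInterrupt with a traceback, which is what you want
    # when debugging a stuck worker:
    signal.signal(signal.SIGINT, signal.default_int_handler)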
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 08a0f0d8ff..904bd9f265 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -80,6 +80,7 @@ class SpecialLengths(object):
END_OF_DATA_SECTION = -1
PYTHON_EXCEPTION_THROWN = -2
TIMING_DATA = -3
+ END_OF_STREAM = -4
class Serializer(object):
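
All of these markers ride on the same framing: every value on the worker socket is either a 4-byte big-endian length followed by a payload, or a negative control code such as END_OF_STREAM. A self-contained sketch of that framing (minimal stand-ins for read_int/write_int, not the pyspark implementations):

    import struct

    END_OF_DATA_SECTION = -1
    PYTHON_EXCEPTION_THROWN = -2
    TIMING_DATA = -3
    END_OF_STREAM = -4

    def write_int(value, stream):
        stream.write(struct.pack("!i", value))   # 4-byte big-endian signed int

    def read_int(stream):
        data = stream.read(4)
        if len(data) < 4:
            raise EOFError
        return struct.unpack("!i", data)[0]      # negative values act as control codes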
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 1a8e4150e6..7a2107ec32 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -31,7 +31,7 @@ import tempfile
import time
import zipfile
import random
-from platform import python_implementation
+import threading
if sys.version_info[:2] <= (2, 6):
try:
@@ -1380,6 +1380,23 @@ class WorkerTests(PySparkTestCase):
self.assertEqual(sum(range(100)), acc2.value)
self.assertEqual(sum(range(100)), acc1.value)
+ def test_reuse_worker_after_take(self):
+ rdd = self.sc.parallelize(range(100000), 1)
+ self.assertEqual(0, rdd.first())
+
+ def count():
+ try:
+ rdd.count()
+ except Exception:
+ pass
+
+ t = threading.Thread(target=count)
+ t.daemon = True
+ t.start()
+ t.join(5)
+ self.assertTrue(not t.isAlive())
+ self.assertEqual(100000, rdd.count())
+
class SparkSubmitTests(unittest.TestCase):
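
The test above uses a common pattern for detecting a hang: run the suspect call in a daemon thread, join with a timeout, and assert the thread finished. A generic sketch of that pattern (assert_finishes_within and work are placeholder names):

    import threading

    def assert_finishes_within(work, seconds):
        t = threading.Thread(target=work)
        t.daemon = True        # do not keep the interpreter alive if the call hangs
        t.start()
        t.join(seconds)        # wait at most `seconds`
        assert not t.is_alive(), "call did not finish within %s seconds" % seconds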
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 8257dddfee..2bdccb5e93 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -57,7 +57,7 @@ def main(infile, outfile):
boot_time = time.time()
split_index = read_int(infile)
if split_index == -1: # for unit tests
- return
+ exit(-1)
# initialize global state
shuffle.MemoryBytesSpilled = 0
@@ -111,7 +111,6 @@ def main(infile, outfile):
try:
write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
write_with_length(traceback.format_exc(), outfile)
- outfile.flush()
except IOError:
# JVM close the socket
pass
@@ -131,6 +130,14 @@ def main(infile, outfile):
for (aid, accum) in _accumulatorRegistry.items():
pickleSer._write_with_length((aid, accum._value), outfile)
+ # check end of stream
+ if read_int(infile) == SpecialLengths.END_OF_STREAM:
+ write_int(SpecialLengths.END_OF_STREAM, outfile)
+ else:
+ # write a different value to tell JVM to not reuse this worker
+ write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
+ exit(-1)
+
if __name__ == '__main__':
# Read a local port to connect to from stdin
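
Putting the pieces together, the reuse handshake added by this patch can be simulated end to end in plain Python (helper names are assumptions; the real code lives in PythonRDD.scala and worker.py): the driver side writes END_OF_DATA_SECTION and then END_OF_STREAM after the data, and the worker echoes END_OF_STREAM back once it has drained its input, which is the signal that the worker may be reused.

    import struct
    from io import BytesIO

    END_OF_DATA_SECTION, END_OF_STREAM = -1, -4

    def write_int(v, s):
        s.write(struct.pack("!i", v))

    def read_int(s):
        return struct.unpack("!i", s.read(4))[0]

    # "Driver" side: after the data section, terminate the stream explicitly.
    to_worker = BytesIO()
    write_int(END_OF_DATA_SECTION, to_worker)
    write_int(END_OF_STREAM, to_worker)
    to_worker.seek(0)

    # "Worker" side: drain the input, then acknowledge the end of stream.
    from_worker = BytesIO()
    assert read_int(to_worker) == END_OF_DATA_SECTION
    if read_int(to_worker) == END_OF_STREAM:
        write_int(END_OF_STREAM, from_worker)        # clean shutdown: reusable
    else:
        write_int(END_OF_DATA_SECTION, from_worker)  # desynchronized: do not reuse

    # "Driver" side again: only reuse the worker on a clean END_OF_STREAM echo.
    from_worker.seek(0)
    reuse_worker = (read_int(from_worker) == END_OF_STREAM)
    assert reuse_worker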