-rw-r--r--  core/src/main/scala/org/apache/spark/TaskContext.scala              |   5
-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala     | 217
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala |  10
-rw-r--r--  python/pyspark/context.py                                           |   2
-rw-r--r--  python/pyspark/daemon.py                                            |  14
5 files changed, 141 insertions(+), 107 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index dc012cc381..fc4812753d 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -42,9 +42,13 @@ class TaskContext(
// List of callback functions to execute when the task completes.
@transient private val onCompleteCallbacks = new ArrayBuffer[() => Unit]
+ // Set to true when the task is completed, before the onCompleteCallbacks are executed.
+ @volatile var completed: Boolean = false
+
/**
* Add a callback function to be executed on task completion. An example use
* is for HadoopRDD to register a callback to close the input stream.
+ * Will be called in any situation - success, failure, or cancellation.
* @param f Callback function.
*/
def addOnCompleteCallback(f: () => Unit) {
@@ -52,6 +56,7 @@ class TaskContext(
}
def executeOnCompleteCallbacks() {
+ completed = true
// Process complete callbacks in the reverse order of registration
onCompleteCallbacks.reverse.foreach{_()}
}
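
As a quick illustration of the new contract (the callbacks run for success, failure, and cancellation alike, and `completed` is set before they execute), the hedged sketch below registers an idempotent cleanup callback on a task context; `FakeResource` and `useResource` are made-up names, not part of this patch.

    import org.apache.spark.TaskContext

    // Illustrative only: a resource that must be released exactly once,
    // whether the task succeeded, failed, or was cancelled.
    class FakeResource {
      @volatile var closed = false
      def close(): Unit = { closed = true }  // closing is idempotent
    }

    def useResource(context: TaskContext): FakeResource = {
      val resource = new FakeResource
      // Runs inside executeOnCompleteCallbacks(), i.e. after context.completed
      // has already been flipped to true.
      context.addOnCompleteCallback(() => resource.close())
      resource
    }
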
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 6140700708..fecd9762f3 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -56,122 +56,37 @@ private[spark] class PythonRDD[T: ClassTag](
val env = SparkEnv.get
val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap)
- // Ensure worker socket is closed on task completion. Closing sockets is idempotent.
- context.addOnCompleteCallback(() =>
+ // Start a thread to feed the process input from our parent's iterator
+ val writerThread = new WriterThread(env, worker, split, context)
+
+ context.addOnCompleteCallback { () =>
+ writerThread.shutdownOnTaskCompletion()
+
+ // Cleanup the worker socket. This will also cause the Python worker to exit.
try {
worker.close()
} catch {
case e: Exception => logWarning("Failed to close worker socket", e)
}
- )
-
- @volatile var readerException: Exception = null
-
- // Start a thread to feed the process input from our parent's iterator
- new Thread("stdin writer for " + pythonExec) {
- override def run() {
- try {
- SparkEnv.set(env)
- val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
- val dataOut = new DataOutputStream(stream)
- // Partition index
- dataOut.writeInt(split.index)
- // sparkFilesDir
- PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
- // Broadcast variables
- dataOut.writeInt(broadcastVars.length)
- for (broadcast <- broadcastVars) {
- dataOut.writeLong(broadcast.id)
- dataOut.writeInt(broadcast.value.length)
- dataOut.write(broadcast.value)
- }
- // Python includes (*.zip and *.egg files)
- dataOut.writeInt(pythonIncludes.length)
- for (include <- pythonIncludes) {
- PythonRDD.writeUTF(include, dataOut)
- }
- dataOut.flush()
- // Serialized command:
- dataOut.writeInt(command.length)
- dataOut.write(command)
- // Data values
- PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
- dataOut.flush()
- worker.shutdownOutput()
- } catch {
-
- case e: java.io.FileNotFoundException =>
- readerException = e
- Try(worker.shutdownOutput()) // kill Python worker process
-
- case e: IOException =>
- // This can happen for legitimate reasons if the Python code stops returning data
- // before we are done passing elements through, e.g., for take(). Just log a message to
- // say it happened (as it could also be hiding a real IOException from a data source).
- logInfo("stdin writer to Python finished early (may not be an error)", e)
-
- case e: Exception =>
- // We must avoid throwing exceptions here, because the thread uncaught exception handler
- // will kill the whole executor (see Executor).
- readerException = e
- Try(worker.shutdownOutput()) // kill Python worker process
- }
- }
- }.start()
-
- // Necessary to distinguish between a task that has failed and a task that is finished
- @volatile var complete: Boolean = false
-
- // It is necessary to have a monitor thread for python workers if the user cancels with
- // interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the
- // threads can block indefinitely.
- new Thread(s"Worker Monitor for $pythonExec") {
- override def run() {
- // Kill the worker if it is interrupted or completed
- // When a python task completes, the context is always set to interupted
- while (!context.interrupted) {
- Thread.sleep(2000)
- }
- if (!complete) {
- try {
- logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
- env.destroyPythonWorker(pythonExec, envVars.toMap)
- } catch {
- case e: Exception =>
- logError("Exception when trying to kill worker", e)
- }
- }
- }
- }.start()
-
- /*
- * Partial fix for SPARK-1019: Attempts to stop reading the input stream since
- * other completion callbacks might invalidate the input. Because interruption
- * is not synchronous this still leaves a potential race where the interruption is
- * processed only after the stream becomes invalid.
- */
- context.addOnCompleteCallback{ () =>
- complete = true // Indicate that the task has completed successfully
- context.interrupted = true
}
+ writerThread.start()
+ new MonitorThread(env, worker, context).start()
+
// Return an iterator that reads lines from the process's stdout
val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
val stdoutIterator = new Iterator[Array[Byte]] {
def next(): Array[Byte] = {
val obj = _nextObj
if (hasNext) {
- // FIXME: can deadlock if worker is waiting for us to
- // respond to current message (currently irrelevant because
- // output is shutdown before we read any input)
_nextObj = read()
}
obj
}
private def read(): Array[Byte] = {
- if (readerException != null) {
- throw readerException
+ if (writerThread.exception.isDefined) {
+ throw writerThread.exception.get
}
try {
stream.readInt() match {
@@ -190,13 +105,14 @@ private[spark] class PythonRDD[T: ClassTag](
val total = finishTime - startTime
logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
init, finish))
- read
+ read()
case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
// Signals that an exception has been thrown in python
val exLength = stream.readInt()
val obj = new Array[Byte](exLength)
stream.readFully(obj)
- throw new PythonException(new String(obj, "utf-8"), readerException)
+ throw new PythonException(new String(obj, "utf-8"),
+ writerThread.exception.getOrElse(null))
case SpecialLengths.END_OF_DATA_SECTION =>
// We've finished the data section of the output, but we can still
// read some accumulator updates:
@@ -210,10 +126,15 @@ private[spark] class PythonRDD[T: ClassTag](
Array.empty[Byte]
}
} catch {
- case e: Exception if readerException != null =>
+
+ case e: Exception if context.interrupted =>
+ logDebug("Exception thrown after task interruption", e)
+ throw new TaskKilledException
+
+ case e: Exception if writerThread.exception.isDefined =>
logError("Python worker exited unexpectedly (crashed)", e)
- logError("Python crash may have been caused by prior exception:", readerException)
- throw readerException
+ logError("This may have been caused by a prior exception:", writerThread.exception.get)
+ throw writerThread.exception.get
case eof: EOFException =>
throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
@@ -224,10 +145,100 @@ private[spark] class PythonRDD[T: ClassTag](
def hasNext = _nextObj.length != 0
}
- stdoutIterator
+ new InterruptibleIterator(context, stdoutIterator)
}
val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
+
+ /**
+ * The thread responsible for writing the data from the PythonRDD's parent iterator to the
+ * Python process.
+ */
+ class WriterThread(env: SparkEnv, worker: Socket, split: Partition, context: TaskContext)
+ extends Thread(s"stdout writer for $pythonExec") {
+
+ @volatile private var _exception: Exception = null
+
+ setDaemon(true)
+
+ /** Contains the exception thrown while writing the parent iterator to the Python process. */
+ def exception: Option[Exception] = Option(_exception)
+
+ /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */
+ def shutdownOnTaskCompletion() {
+ assert(context.completed)
+ this.interrupt()
+ }
+
+ override def run() {
+ try {
+ SparkEnv.set(env)
+ val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
+ val dataOut = new DataOutputStream(stream)
+ // Partition index
+ dataOut.writeInt(split.index)
+ // sparkFilesDir
+ PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
+ // Broadcast variables
+ dataOut.writeInt(broadcastVars.length)
+ for (broadcast <- broadcastVars) {
+ dataOut.writeLong(broadcast.id)
+ dataOut.writeInt(broadcast.value.length)
+ dataOut.write(broadcast.value)
+ }
+ // Python includes (*.zip and *.egg files)
+ dataOut.writeInt(pythonIncludes.length)
+ for (include <- pythonIncludes) {
+ PythonRDD.writeUTF(include, dataOut)
+ }
+ dataOut.flush()
+ // Serialized command:
+ dataOut.writeInt(command.length)
+ dataOut.write(command)
+ // Data values
+ PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
+ dataOut.flush()
+ } catch {
+ case e: Exception if context.completed || context.interrupted =>
+ logDebug("Exception thrown after task completion (likely due to cleanup)", e)
+
+ case e: Exception =>
+ // We must avoid throwing exceptions here, because the thread uncaught exception handler
+ // will kill the whole executor (see org.apache.spark.executor.Executor).
+ _exception = e
+ } finally {
+ Try(worker.shutdownOutput()) // kill Python worker process
+ }
+ }
+ }
+
+ /**
+ * It is necessary to have a monitor thread for python workers if the user cancels with
+ * interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the
+ * threads can block indefinitely.
+ */
+ class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext)
+ extends Thread(s"Worker Monitor for $pythonExec") {
+
+ setDaemon(true)
+
+ override def run() {
+ // Kill the worker if it is interrupted, checking until task completion.
+ // TODO: This has a race condition if interruption occurs, as completed may still become true.
+ while (!context.interrupted && !context.completed) {
+ Thread.sleep(2000)
+ }
+ if (!context.completed) {
+ try {
+ logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
+ env.destroyPythonWorker(pythonExec, envVars.toMap)
+ } catch {
+ case e: Exception =>
+ logError("Exception when trying to kill worker", e)
+ }
+ }
+ }
+ }
}
/** Thrown for exceptions in user Python code. */
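
Stripped of the Spark specifics, the refactor above follows one pattern: a daemon writer thread records any failure in a @volatile field, the reading side rethrows it, and output is always shut down so the peer sees EOF. The sketch below restates that pattern with hypothetical `Producer` and `drain` names; it is not the actual PythonRDD API.

    import java.io.{BufferedReader, InputStreamReader, PrintWriter}
    import java.net.Socket
    import scala.util.Try

    // Hypothetical sketch of the writer-thread pattern: write on a daemon thread,
    // stash any failure, and always signal EOF to the peer.
    class Producer(socket: Socket, data: Iterator[String]) extends Thread("producer") {
      @volatile private var _exception: Exception = null
      setDaemon(true)

      def exception: Option[Exception] = Option(_exception)

      override def run(): Unit = {
        try {
          val out = new PrintWriter(socket.getOutputStream)
          data.foreach(line => out.println(line))
          out.flush()
        } catch {
          // Never rethrow here: an uncaught exception would hit the thread's
          // default handler instead of the task that consumes the data.
          case e: Exception => _exception = e
        } finally {
          Try(socket.shutdownOutput())  // lets the reader observe end-of-stream
        }
      }
    }

    // The consumer surfaces a writer-side failure instead of hanging or
    // reporting a confusing EOF, mirroring what the stdout iterator does.
    def drain(socket: Socket, producer: Producer): Iterator[String] = {
      val in = new BufferedReader(new InputStreamReader(socket.getInputStream))
      Iterator.continually(in.readLine()).takeWhile(_ != null).map { line =>
        producer.exception.foreach(e => throw e)
        line
      }
    }
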
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 02b62de7e3..2259df0b56 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -17,11 +17,13 @@
package org.apache.spark.scheduler
+import scala.language.existentials
+
import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import scala.collection.mutable.HashMap
-import scala.language.existentials
+import scala.util.Try
import org.apache.spark._
import org.apache.spark.executor.ShuffleWriteMetrics
@@ -196,7 +198,11 @@ private[spark] class ShuffleMapTask(
} finally {
// Release the writers back to the shuffle block manager.
if (shuffle != null && shuffle.writers != null) {
- shuffle.releaseWriters(success)
+ try {
+ shuffle.releaseWriters(success)
+ } catch {
+ case e: Exception => logError("Failed to release shuffle writers", e)
+ }
}
// Execute the callbacks on task completion.
context.executeOnCompleteCallbacks()
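
The change above is the usual "cleanup in finally must not throw" rule: if releaseWriters fails, the original task failure would otherwise be masked. A minimal, hedged restatement, with releaseWriters and logError passed in as stand-ins for the real ShuffleWriterGroup and Logging members:

    // Illustrative only: run a task body and release writers afterwards without
    // letting a cleanup failure replace the body's own exception.
    def withWriters[T](releaseWriters: () => Unit,
                       logError: (String, Throwable) => Unit)(body: => T): T = {
      try {
        body
      } finally {
        try {
          releaseWriters()
        } catch {
          case e: Exception => logError("Failed to release shuffle writers", e)
        }
      }
    }
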
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index c7dc85ea03..cac133d0fc 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -453,7 +453,7 @@ class SparkContext(object):
>>> lock = threading.Lock()
>>> def map_func(x):
... sleep(100)
- ... return x * x
+ ... raise Exception("Task should have been cancelled")
>>> def start_job(x):
... global result
... try:
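
The doctest now fails loudly if a supposedly cancelled task runs to completion. A comparable check from the Scala API might look like the sketch below; the group name, timings, and local master are illustrative assumptions, not part of this patch.

    import org.apache.spark.{SparkConf, SparkContext}

    // Hedged sketch: submit a slow job in one thread, cancel its job group from
    // another, and rely on the resulting exception to prove it never finished.
    object CancelSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setMaster("local[2]").setAppName("cancel-sketch"))
        val runner = new Thread(new Runnable {
          def run(): Unit = {
            sc.setJobGroup("slow-group", "job that should be cancelled")
            try {
              sc.parallelize(1 to 4, 4).map { x => Thread.sleep(100000); x }.count()
              // If we get here, the cancellation did not take effect.
              println("Task should have been cancelled")
            } catch {
              case e: Exception => println("Job was cancelled: " + e.getMessage)
            }
          }
        })
        runner.start()
        Thread.sleep(5000)               // give the tasks time to start
        sc.cancelJobGroup("slow-group")  // the sleeping tasks must not complete normally
        runner.join()
        sc.stop()
      }
    }
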
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index eb18ec08c9..b2f226a55e 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -74,6 +74,17 @@ def worker(listen_sock):
raise
signal.signal(SIGCHLD, handle_sigchld)
+ # Blocks until the socket is closed by draining the input stream
+ # until it raises an exception or returns EOF.
+ def waitSocketClose(sock):
+ try:
+ while True:
+ # Empty string is returned upon EOF (and only then).
+ if sock.recv(4096) == '':
+ return
+ except:
+ pass
+
# Handle clients
while not should_exit():
# Wait until a client arrives or we have to exit
@@ -105,7 +116,8 @@ def worker(listen_sock):
exit_code = exc.code
finally:
outfile.flush()
- sock.close()
+ # The Scala side will close the socket upon task completion.
+ waitSocketClose(sock)
os._exit(compute_real_exit_code(exit_code))
else:
sock.close()
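
The protocol after this change is: the JVM side closes the worker socket in its task-completion callback, and the Python daemon merely drains its end until it observes EOF before exiting. The same drain-until-EOF loop, rendered in Scala purely for illustration (daemon.py itself keeps the Python version above):

    import java.net.Socket
    import scala.util.control.NonFatal

    // Illustrative Scala rendering of waitSocketClose: read and discard bytes
    // until the peer closes the connection (read returns -1) or any error occurs.
    def waitSocketClose(sock: Socket): Unit = {
      val in = sock.getInputStream
      val buffer = new Array[Byte](4096)
      try {
        while (in.read(buffer) != -1) {
          // Discard the data; we only care about observing end-of-stream.
        }
      } catch {
        case NonFatal(_) => // treat errors like EOF, mirroring the bare `except` in Python
      }
    }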