author     Davies Liu <davies.liu@gmail.com>    2014-08-03 15:52:00 -0700
committer  Josh Rosen <joshrosen@apache.org>    2014-08-03 15:52:00 -0700
commit     55349f9fe81ba5af5e4a5e4908ebf174e63c6cc9 (patch)
tree       91277ef5bfed5d8d3177679ebfe52186350f430b /core
parent     e139e2be60ef23281327744e1b3e74904dfdf63f (diff)
download   spark-55349f9fe81ba5af5e4a5e4908ebf174e63c6cc9.tar.gz
           spark-55349f9fe81ba5af5e4a5e4908ebf174e63c6cc9.tar.bz2
           spark-55349f9fe81ba5af5e4a5e4908ebf174e63c6cc9.zip
[SPARK-1740] [PySpark] kill the python worker
Kill only the python worker related to cancelled tasks. The daemon will start a
background thread to monitor all the opened sockets for all workers. If a socket
is closed by the JVM, this thread will kill the worker. When a task is cancelled,
the socket to the worker is closed, and the worker is then killed by the daemon.

Author: Davies Liu <davies.liu@gmail.com>

Closes #1643 from davies/kill and squashes the following commits:

8ffe9f3 [Davies Liu] kill worker by daemon, because runtime.exec() is too heavy
46ca150 [Davies Liu] address comment
acd751c [Davies Liu] kill the worker when task is canceled
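For orientation, here is a condensed sketch of the JVM-side kill path this change introduces. The names mirror the diff below, but the class is a simplified, hypothetical stand-in rather than the real PythonWorkerFactory API: on connect the daemon reports the forked worker's pid, the JVM remembers the socket-to-pid mapping, and on cancellation it asks the daemon to kill exactly that pid before closing the socket.

import java.io.{DataInputStream, DataOutputStream, OutputStream}
import java.net.Socket
import scala.collection.mutable

// Hypothetical, simplified stand-in for PythonWorkerFactory (illustration only).
class WorkerKillSketch(daemonStdin: OutputStream) {
  // socket -> pid of the daemon-forked worker behind it
  private val daemonWorkers = new mutable.WeakHashMap[Socket, Int]()

  // On connect, the daemon now sends back the worker's pid over the socket.
  def register(socket: Socket): Unit = {
    val pid = new DataInputStream(socket.getInputStream).readInt()
    if (pid < 0) {
      throw new IllegalStateException("Python daemon failed to launch worker")
    }
    daemonWorkers.put(socket, pid)
  }

  // On task cancellation, ask the daemon to kill exactly this worker by pid,
  // then close the socket so the daemon's monitor thread also sees it go away.
  def stopWorker(worker: Socket): Unit = {
    daemonWorkers.get(worker).foreach { pid =>
      val out = new DataOutputStream(daemonStdin)
      out.writeInt(pid)
      out.flush()
    }
    worker.close()
  }
}

The matching daemon-side change described in the commit message (monitoring the worker sockets and reading pids to kill from its stdin) lives on the Python side and is outside this 'core' diff.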
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkEnv.scala                       |  5
-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala           |  9
-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala | 64
3 files changed, 56 insertions, 22 deletions
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 92c809d854..0bce531aab 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -18,6 +18,7 @@
package org.apache.spark
import java.io.File
+import java.net.Socket
import scala.collection.JavaConversions._
import scala.collection.mutable
@@ -102,10 +103,10 @@ class SparkEnv (
}
private[spark]
- def destroyPythonWorker(pythonExec: String, envVars: Map[String, String]) {
+ def destroyPythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) {
synchronized {
val key = (pythonExec, envVars)
- pythonWorkers(key).stop()
+ pythonWorkers.get(key).foreach(_.stopWorker(worker))
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index fe9a9e50ef..0b5322c6fb 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -62,8 +62,8 @@ private[spark] class PythonRDD(
val env = SparkEnv.get
val localdir = env.blockManager.diskBlockManager.localDirs.map(
f => f.getPath()).mkString(",")
- val worker: Socket = env.createPythonWorker(pythonExec,
- envVars.toMap + ("SPARK_LOCAL_DIR" -> localdir))
+ envVars += ("SPARK_LOCAL_DIR" -> localdir) // it's also used in monitor thread
+ val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap)
// Start a thread to feed the process input from our parent's iterator
val writerThread = new WriterThread(env, worker, split, context)
@@ -241,7 +241,7 @@ private[spark] class PythonRDD(
if (!context.completed) {
try {
logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
- env.destroyPythonWorker(pythonExec, envVars.toMap)
+ env.destroyPythonWorker(pythonExec, envVars.toMap, worker)
} catch {
case e: Exception =>
logError("Exception when trying to kill worker", e)
@@ -685,9 +685,8 @@ private[spark] object PythonRDD extends Logging {
/**
* Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions).
- * This function is outdated, PySpark does not use it anymore
*/
- @deprecated
+ @deprecated("PySpark does not use it anymore", "1.1")
def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
pyRDD.rdd.mapPartitions { iter =>
val unpickle = new Unpickler
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 15fe8a9be6..7af260d0b7 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -17,9 +17,11 @@
package org.apache.spark.api.python
-import java.io.{DataInputStream, InputStream, OutputStreamWriter}
+import java.lang.Runtime
+import java.io.{DataOutputStream, DataInputStream, InputStream, OutputStreamWriter}
import java.net.{InetAddress, ServerSocket, Socket, SocketException}
+import scala.collection.mutable
import scala.collection.JavaConversions._
import org.apache.spark._
@@ -39,6 +41,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
var daemon: Process = null
val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
var daemonPort: Int = 0
+ var daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
+
+ var simpleWorkers = new mutable.WeakHashMap[Socket, Process]()
val pythonPath = PythonUtils.mergePythonPaths(
PythonUtils.sparkPythonPath,
@@ -58,25 +63,31 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
* to avoid the high cost of forking from Java. This currently only works on UNIX-based systems.
*/
private def createThroughDaemon(): Socket = {
+
+ def createSocket(): Socket = {
+ val socket = new Socket(daemonHost, daemonPort)
+ val pid = new DataInputStream(socket.getInputStream).readInt()
+ if (pid < 0) {
+ throw new IllegalStateException("Python daemon failed to launch worker")
+ }
+ daemonWorkers.put(socket, pid)
+ socket
+ }
+
synchronized {
// Start the daemon if it hasn't been started
startDaemon()
// Attempt to connect, restart and retry once if it fails
try {
- val socket = new Socket(daemonHost, daemonPort)
- val launchStatus = new DataInputStream(socket.getInputStream).readInt()
- if (launchStatus != 0) {
- throw new IllegalStateException("Python daemon failed to launch worker")
- }
- socket
+ createSocket()
} catch {
case exc: SocketException =>
logWarning("Failed to open socket to Python daemon:", exc)
logWarning("Assuming that daemon unexpectedly quit, attempting to restart")
stopDaemon()
startDaemon()
- new Socket(daemonHost, daemonPort)
+ createSocket()
}
}
}
@@ -107,7 +118,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
// Wait for it to connect to our socket
serverSocket.setSoTimeout(10000)
try {
- return serverSocket.accept()
+ val socket = serverSocket.accept()
+ simpleWorkers.put(socket, worker)
+ return socket
} catch {
case e: Exception =>
throw new SparkException("Python worker did not connect back in time", e)
@@ -189,19 +202,40 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
private def stopDaemon() {
synchronized {
- // Request shutdown of existing daemon by sending SIGTERM
- if (daemon != null) {
- daemon.destroy()
- }
+ if (useDaemon) {
+ // Request shutdown of existing daemon by sending SIGTERM
+ if (daemon != null) {
+ daemon.destroy()
+ }
- daemon = null
- daemonPort = 0
+ daemon = null
+ daemonPort = 0
+ } else {
+ simpleWorkers.mapValues(_.destroy())
+ }
}
}
def stop() {
stopDaemon()
}
+
+ def stopWorker(worker: Socket) {
+ if (useDaemon) {
+ if (daemon != null) {
+ daemonWorkers.get(worker).foreach { pid =>
+ // tell daemon to kill worker by pid
+ val output = new DataOutputStream(daemon.getOutputStream)
+ output.writeInt(pid)
+ output.flush()
+ daemon.getOutputStream.flush()
+ }
+ }
+ } else {
+ simpleWorkers.get(worker).foreach(_.destroy())
+ }
+ worker.close()
+ }
}
private object PythonWorkerFactory {
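Tying this back to the SparkEnv.scala hunk at the top of the diff, a rough usage sketch of the new three-argument destroyPythonWorker follows. The trait and class here are hypothetical stand-ins for illustration; only the keying by (pythonExec, envVars) and the per-worker forwarding follow the real change.

import java.net.Socket
import scala.collection.mutable

// Anything that can stop a single worker; PythonWorkerFactory plays this role
// in the real code. Hypothetical stand-in for illustration.
trait WorkerStopper {
  def stopWorker(worker: Socket): Unit
}

// Simplified stand-in for the SparkEnv change: factories are still keyed by
// (pythonExec, envVars), but the kill request now names one worker socket, so
// other workers forked by the same daemon keep running.
class EnvSketch(
    pythonWorkers: mutable.Map[(String, Map[String, String]), WorkerStopper]) {

  def destroyPythonWorker(
      pythonExec: String,
      envVars: Map[String, String],
      worker: Socket): Unit = synchronized {
    pythonWorkers.get((pythonExec, envVars)).foreach(_.stopWorker(worker))
  }
}

A caller in the task interrupt path would then invoke env.destroyPythonWorker(pythonExec, envVars.toMap, worker), as shown in the PythonRDD.scala hunk above.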