Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkEnv.scala                      |  8
-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala          | 58
-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala | 85
3 files changed, 128 insertions, 23 deletions
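Note: the PythonRDD change below reads a new configuration key, spark.python.worker.reuse, which defaults to true. As a minimal illustration only (not part of this patch), the flag could be turned off from an application's driver code with a standard SparkConf:

    import org.apache.spark.SparkConf

    // Illustration only: disable the worker reuse introduced by this patch.
    // The key and its default of "true" come from the diff below.
    val conf = new SparkConf()
      .setAppName("example")
      .set("spark.python.worker.reuse", "false")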
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index dd95e406f2..009ed64775 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -108,6 +108,14 @@ class SparkEnv (
pythonWorkers.get(key).foreach(_.stopWorker(worker))
}
}
+
+ private[spark]
+ def releasePythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) {
+ synchronized {
+ val key = (pythonExec, envVars)
+ pythonWorkers.get(key).foreach(_.releaseWorker(worker))
+ }
+ }
}
object SparkEnv extends Logging {
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index ae8010300a..ca8eef5f99 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -23,6 +23,7 @@ import java.nio.charset.Charset
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
import scala.collection.JavaConversions._
+import scala.collection.mutable
import scala.language.existentials
import scala.reflect.ClassTag
import scala.util.{Try, Success, Failure}
@@ -52,6 +53,7 @@ private[spark] class PythonRDD(
extends RDD[Array[Byte]](parent) {
val bufferSize = conf.getInt("spark.buffer.size", 65536)
+ val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true)
override def getPartitions = parent.partitions
@@ -63,19 +65,26 @@ private[spark] class PythonRDD(
val localdir = env.blockManager.diskBlockManager.localDirs.map(
f => f.getPath()).mkString(",")
envVars += ("SPARK_LOCAL_DIRS" -> localdir) // it's also used in monitor thread
+ if (reuse_worker) {
+ envVars += ("SPARK_REUSE_WORKER" -> "1")
+ }
val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap)
// Start a thread to feed the process input from our parent's iterator
val writerThread = new WriterThread(env, worker, split, context)
+ var complete_cleanly = false
context.addTaskCompletionListener { context =>
writerThread.shutdownOnTaskCompletion()
-
- // Cleanup the worker socket. This will also cause the Python worker to exit.
- try {
- worker.close()
- } catch {
- case e: Exception => logWarning("Failed to close worker socket", e)
+ if (reuse_worker && complete_cleanly) {
+ env.releasePythonWorker(pythonExec, envVars.toMap, worker)
+ } else {
+ try {
+ worker.close()
+ } catch {
+ case e: Exception =>
+ logWarning("Failed to close worker socket", e)
+ }
}
}
@@ -133,6 +142,7 @@ private[spark] class PythonRDD(
stream.readFully(update)
accumulator += Collections.singletonList(update)
}
+ complete_cleanly = true
null
}
} catch {
@@ -195,11 +205,26 @@ private[spark] class PythonRDD(
PythonRDD.writeUTF(include, dataOut)
}
// Broadcast variables
- dataOut.writeInt(broadcastVars.length)
+ val oldBids = PythonRDD.getWorkerBroadcasts(worker)
+ val newBids = broadcastVars.map(_.id).toSet
+ // number of different broadcasts
+ val cnt = oldBids.diff(newBids).size + newBids.diff(oldBids).size
+ dataOut.writeInt(cnt)
+ for (bid <- oldBids) {
+ if (!newBids.contains(bid)) {
+ // remove the broadcast from worker
+ dataOut.writeLong(- bid - 1) // bid >= 0
+ oldBids.remove(bid)
+ }
+ }
for (broadcast <- broadcastVars) {
- dataOut.writeLong(broadcast.id)
- dataOut.writeInt(broadcast.value.length)
- dataOut.write(broadcast.value)
+ if (!oldBids.contains(broadcast.id)) {
+ // send new broadcast
+ dataOut.writeLong(broadcast.id)
+ dataOut.writeInt(broadcast.value.length)
+ dataOut.write(broadcast.value)
+ oldBids.add(broadcast.id)
+ }
}
dataOut.flush()
// Serialized command:
@@ -207,17 +232,18 @@ private[spark] class PythonRDD(
dataOut.write(command)
// Data values
PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
+ dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
dataOut.flush()
} catch {
case e: Exception if context.isCompleted || context.isInterrupted =>
logDebug("Exception thrown after task completion (likely due to cleanup)", e)
+ worker.shutdownOutput()
case e: Exception =>
// We must avoid throwing exceptions here, because the thread uncaught exception handler
// will kill the whole executor (see org.apache.spark.executor.Executor).
_exception = e
- } finally {
- Try(worker.shutdownOutput()) // kill Python worker process
+ worker.shutdownOutput()
}
}
}
@@ -278,6 +304,14 @@ private object SpecialLengths {
private[spark] object PythonRDD extends Logging {
val UTF8 = Charset.forName("UTF-8")
+ // remember the broadcasts sent to each worker
+ private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]()
+ private def getWorkerBroadcasts(worker: Socket) = {
+ synchronized {
+ workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]())
+ }
+ }
+
/**
* Adapter for calling SparkContext#runJob from Python.
*
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 4c4796f6c5..71bdf0fe1b 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -40,7 +40,10 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
var daemon: Process = null
val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
var daemonPort: Int = 0
- var daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
+ val daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
+ val idleWorkers = new mutable.Queue[Socket]()
+ var lastActivity = 0L
+ new MonitorThread().start()
var simpleWorkers = new mutable.WeakHashMap[Socket, Process]()
@@ -51,6 +54,11 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
def create(): Socket = {
if (useDaemon) {
+ synchronized {
+ if (idleWorkers.size > 0) {
+ return idleWorkers.dequeue()
+ }
+ }
createThroughDaemon()
} else {
createSimpleWorker()
@@ -199,9 +207,44 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
}
}
+ /**
+ * Monitor all the idle workers, kill them after timeout.
+ */
+ private class MonitorThread extends Thread(s"Idle Worker Monitor for $pythonExec") {
+
+ setDaemon(true)
+
+ override def run() {
+ while (true) {
+ synchronized {
+ if (lastActivity + IDLE_WORKER_TIMEOUT_MS < System.currentTimeMillis()) {
+ cleanupIdleWorkers()
+ lastActivity = System.currentTimeMillis()
+ }
+ }
+ Thread.sleep(10000)
+ }
+ }
+ }
+
+ private def cleanupIdleWorkers() {
+ while (idleWorkers.length > 0) {
+ val worker = idleWorkers.dequeue()
+ try {
+ // the worker will exit after closing the socket
+ worker.close()
+ } catch {
+ case e: Exception =>
+ logWarning("Failed to close worker socket", e)
+ }
+ }
+ }
+
private def stopDaemon() {
synchronized {
if (useDaemon) {
+ cleanupIdleWorkers()
+
// Request shutdown of existing daemon by sending SIGTERM
if (daemon != null) {
daemon.destroy()
@@ -220,23 +263,43 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
}
def stopWorker(worker: Socket) {
- if (useDaemon) {
- if (daemon != null) {
- daemonWorkers.get(worker).foreach { pid =>
- // tell daemon to kill worker by pid
- val output = new DataOutputStream(daemon.getOutputStream)
- output.writeInt(pid)
- output.flush()
- daemon.getOutputStream.flush()
+ synchronized {
+ if (useDaemon) {
+ if (daemon != null) {
+ daemonWorkers.get(worker).foreach { pid =>
+ // tell daemon to kill worker by pid
+ val output = new DataOutputStream(daemon.getOutputStream)
+ output.writeInt(pid)
+ output.flush()
+ daemon.getOutputStream.flush()
+ }
}
+ } else {
+ simpleWorkers.get(worker).foreach(_.destroy())
}
- } else {
- simpleWorkers.get(worker).foreach(_.destroy())
}
worker.close()
}
+
+ def releaseWorker(worker: Socket) {
+ if (useDaemon) {
+ synchronized {
+ lastActivity = System.currentTimeMillis()
+ idleWorkers.enqueue(worker)
+ }
+ } else {
+ // Cleanup the worker socket. This will also cause the Python worker to exit.
+ try {
+ worker.close()
+ } catch {
+ case e: Exception =>
+ logWarning("Failed to close worker socket", e)
+ }
+ }
+ }
}
private object PythonWorkerFactory {
val PROCESS_WAIT_TIMEOUT_MS = 10000
+ val IDLE_WORKER_TIMEOUT_MS = 60000 // kill idle workers after 1 minute
}