author     Davies Liu <davies.liu@gmail.com>  2014-09-13 16:22:04 -0700
committer  Josh Rosen <joshrosen@apache.org>  2014-09-13 16:22:04 -0700
commit     2aea0da84c58a179917311290083456dfa043db7 (patch)
tree       6cda208e50f24c31883f1fdf2f51b7a6a8399ff1 /core
parent     0f8c4edf4e750e3d11da27cc22c40b0489da7f37 (diff)
download   spark-2aea0da84c58a179917311290083456dfa043db7.tar.gz
           spark-2aea0da84c58a179917311290083456dfa043db7.tar.bz2
           spark-2aea0da84c58a179917311290083456dfa043db7.zip
[SPARK-3030] [PySpark] Reuse Python worker
Reuse the Python worker to avoid the overhead of forking a Python process for each task. The factory also tracks the broadcasts held by each worker, so repeated broadcasts are not re-sent. This reduces the time for a dummy task from 22ms to 13ms (-40%) and helps lower latency for Spark Streaming.

For a job with a broadcast (43M after compression):
```
b = sc.broadcast(set(range(30000000)))
print sc.parallelize(range(24000), 100).filter(lambda x: x in b.value).count()
```
It finishes in 281s without worker reuse and in 65s with worker reuse (4 CPUs). After reusing the worker, about 9 seconds of broadcast transfer and deserialization are saved per task.

Worker reuse is enabled by default and can be disabled with `spark.python.worker.reuse = false`.

Author: Davies Liu <davies.liu@gmail.com>

Closes #2259 from davies/reuse-worker and squashes the following commits:

f11f617 [Davies Liu] Merge branch 'master' into reuse-worker
3939f20 [Davies Liu] fix bug in serializer in mllib
cf1c55e [Davies Liu] address comments
3133a60 [Davies Liu] fix accumulator with reused worker
760ab1f [Davies Liu] do not reuse worker if there are any exceptions
7abb224 [Davies Liu] refactor: sychronized with itself
ac3206e [Davies Liu] renaming
8911f44 [Davies Liu] synchronized getWorkerBroadcasts()
6325fc1 [Davies Liu] bugfix: bid >= 0
e0131a2 [Davies Liu] fix name of config
583716e [Davies Liu] only reuse completed and not interrupted worker
ace2917 [Davies Liu] kill python worker after timeout
6123d0f [Davies Liu] track broadcasts for each worker
8d2f08c [Davies Liu] reuse python worker
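As a driver-side reference for the new setting, here is a minimal sketch of disabling worker reuse; the config key is taken from this commit, while the master URL and app name are placeholder values:
```
import org.apache.spark.{SparkConf, SparkContext}

// Sketch: turn worker reuse off explicitly; after this commit the
// "spark.python.worker.reuse" property defaults to true.
val conf = new SparkConf()
  .setMaster("local[4]")                      // placeholder master
  .setAppName("reuse-worker-demo")            // placeholder app name
  .set("spark.python.worker.reuse", "false")
val sc = new SparkContext(conf)
```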
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkEnv.scala                        |  8
-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala            | 58
-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala  | 85
3 files changed, 128 insertions(+), 23 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index dd95e406f2..009ed64775 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -108,6 +108,14 @@ class SparkEnv (
pythonWorkers.get(key).foreach(_.stopWorker(worker))
}
}
+
+ private[spark]
+ def releasePythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) {
+ synchronized {
+ val key = (pythonExec, envVars)
+ pythonWorkers.get(key).foreach(_.releaseWorker(worker))
+ }
+ }
}
object SparkEnv extends Logging {
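Later in this commit, PythonRDD's task-completion listener is the caller of this new releasePythonWorker entry point: a worker goes back to its factory only when reuse is enabled and the task finished cleanly; otherwise the socket is closed, which makes the Python worker exit. A hedged, self-contained sketch of that decision (the function and parameter names are illustrative, and `release` stands in for SparkEnv.releasePythonWorker):
```
import java.net.Socket

// Sketch of the decision made in PythonRDD's task-completion listener.
def finishTask(worker: Socket,
               reuseWorker: Boolean,
               completedCleanly: Boolean,
               release: Socket => Unit): Unit = {
  if (reuseWorker && completedCleanly) {
    // Clean finish with reuse enabled: hand the socket back so the next
    // task can reuse the same Python process.
    release(worker)
  } else {
    // Failure, interruption, or reuse disabled: closing the socket also
    // causes the Python worker process to exit.
    try worker.close() catch {
      case e: Exception => Console.err.println(s"Failed to close worker socket: $e")
    }
  }
}
```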
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index ae8010300a..ca8eef5f99 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -23,6 +23,7 @@ import java.nio.charset.Charset
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
import scala.collection.JavaConversions._
+import scala.collection.mutable
import scala.language.existentials
import scala.reflect.ClassTag
import scala.util.{Try, Success, Failure}
@@ -52,6 +53,7 @@ private[spark] class PythonRDD(
extends RDD[Array[Byte]](parent) {
val bufferSize = conf.getInt("spark.buffer.size", 65536)
+ val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true)
override def getPartitions = parent.partitions
@@ -63,19 +65,26 @@ private[spark] class PythonRDD(
val localdir = env.blockManager.diskBlockManager.localDirs.map(
f => f.getPath()).mkString(",")
envVars += ("SPARK_LOCAL_DIRS" -> localdir) // it's also used in monitor thread
+ if (reuse_worker) {
+ envVars += ("SPARK_REUSE_WORKER" -> "1")
+ }
val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap)
// Start a thread to feed the process input from our parent's iterator
val writerThread = new WriterThread(env, worker, split, context)
+ var complete_cleanly = false
context.addTaskCompletionListener { context =>
writerThread.shutdownOnTaskCompletion()
-
- // Cleanup the worker socket. This will also cause the Python worker to exit.
- try {
- worker.close()
- } catch {
- case e: Exception => logWarning("Failed to close worker socket", e)
+ if (reuse_worker && complete_cleanly) {
+ env.releasePythonWorker(pythonExec, envVars.toMap, worker)
+ } else {
+ try {
+ worker.close()
+ } catch {
+ case e: Exception =>
+ logWarning("Failed to close worker socket", e)
+ }
}
}
@@ -133,6 +142,7 @@ private[spark] class PythonRDD(
stream.readFully(update)
accumulator += Collections.singletonList(update)
}
+ complete_cleanly = true
null
}
} catch {
@@ -195,11 +205,26 @@ private[spark] class PythonRDD(
PythonRDD.writeUTF(include, dataOut)
}
// Broadcast variables
- dataOut.writeInt(broadcastVars.length)
+ val oldBids = PythonRDD.getWorkerBroadcasts(worker)
+ val newBids = broadcastVars.map(_.id).toSet
+ // number of different broadcasts
+ val cnt = oldBids.diff(newBids).size + newBids.diff(oldBids).size
+ dataOut.writeInt(cnt)
+ for (bid <- oldBids) {
+ if (!newBids.contains(bid)) {
+ // remove the broadcast from worker
+ dataOut.writeLong(- bid - 1) // bid >= 0
+ oldBids.remove(bid)
+ }
+ }
for (broadcast <- broadcastVars) {
- dataOut.writeLong(broadcast.id)
- dataOut.writeInt(broadcast.value.length)
- dataOut.write(broadcast.value)
+ if (!oldBids.contains(broadcast.id)) {
+ // send new broadcast
+ dataOut.writeLong(broadcast.id)
+ dataOut.writeInt(broadcast.value.length)
+ dataOut.write(broadcast.value)
+ oldBids.add(broadcast.id)
+ }
}
dataOut.flush()
// Serialized command:
@@ -207,17 +232,18 @@ private[spark] class PythonRDD(
dataOut.write(command)
// Data values
PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
+ dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
dataOut.flush()
} catch {
case e: Exception if context.isCompleted || context.isInterrupted =>
logDebug("Exception thrown after task completion (likely due to cleanup)", e)
+ worker.shutdownOutput()
case e: Exception =>
// We must avoid throwing exceptions here, because the thread uncaught exception handler
// will kill the whole executor (see org.apache.spark.executor.Executor).
_exception = e
- } finally {
- Try(worker.shutdownOutput()) // kill Python worker process
+ worker.shutdownOutput()
}
}
}
@@ -278,6 +304,14 @@ private object SpecialLengths {
private[spark] object PythonRDD extends Logging {
val UTF8 = Charset.forName("UTF-8")
+ // remember the broadcasts sent to each worker
+ private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]()
+ private def getWorkerBroadcasts(worker: Socket) = {
+ synchronized {
+ workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]())
+ }
+ }
+
/**
* Adapter for calling SparkContext#runJob from Python.
*
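The broadcast bookkeeping in the hunks above ships only the delta between what a reused worker already holds and what the next task needs: stale ids are retired with the negative encoding `-bid - 1` (unambiguous because bid >= 0), and only unseen broadcast payloads are written. A self-contained sketch of that wire format, with illustrative names:
```
import java.io.DataOutputStream
import scala.collection.mutable

// Sketch of the per-task broadcast delta written to a reused Python worker.
def writeBroadcastDelta(dataOut: DataOutputStream,
                        oldBids: mutable.Set[Long],            // ids already on this worker
                        broadcasts: Seq[(Long, Array[Byte])]): Unit = {
  val newBids = broadcasts.map(_._1).toSet
  // Number of entries the worker should expect: removals plus additions.
  val cnt = oldBids.diff(newBids).size + newBids.diff(oldBids).size
  dataOut.writeInt(cnt)
  // Retire broadcasts the worker holds but this task does not need.
  for (bid <- oldBids.toSeq if !newBids.contains(bid)) {
    dataOut.writeLong(-bid - 1)  // negative marker is safe since bid >= 0
    oldBids.remove(bid)
  }
  // Ship only the broadcasts the worker has not seen yet.
  for ((bid, value) <- broadcasts if !oldBids.contains(bid)) {
    dataOut.writeLong(bid)
    dataOut.writeInt(value.length)
    dataOut.write(value)
    oldBids.add(bid)
  }
  dataOut.flush()
}
```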
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 4c4796f6c5..71bdf0fe1b 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -40,7 +40,10 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
var daemon: Process = null
val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
var daemonPort: Int = 0
- var daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
+ val daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
+ val idleWorkers = new mutable.Queue[Socket]()
+ var lastActivity = 0L
+ new MonitorThread().start()
var simpleWorkers = new mutable.WeakHashMap[Socket, Process]()
@@ -51,6 +54,11 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
def create(): Socket = {
if (useDaemon) {
+ synchronized {
+ if (idleWorkers.size > 0) {
+ return idleWorkers.dequeue()
+ }
+ }
createThroughDaemon()
} else {
createSimpleWorker()
@@ -199,9 +207,44 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
}
}
+ /**
+ * Monitor all the idle workers, kill them after timeout.
+ */
+ private class MonitorThread extends Thread(s"Idle Worker Monitor for $pythonExec") {
+
+ setDaemon(true)
+
+ override def run() {
+ while (true) {
+ synchronized {
+ if (lastActivity + IDLE_WORKER_TIMEOUT_MS < System.currentTimeMillis()) {
+ cleanupIdleWorkers()
+ lastActivity = System.currentTimeMillis()
+ }
+ }
+ Thread.sleep(10000)
+ }
+ }
+ }
+
+ private def cleanupIdleWorkers() {
+ while (idleWorkers.length > 0) {
+ val worker = idleWorkers.dequeue()
+ try {
+ // the worker will exit after closing the socket
+ worker.close()
+ } catch {
+ case e: Exception =>
+ logWarning("Failed to close worker socket", e)
+ }
+ }
+ }
+
private def stopDaemon() {
synchronized {
if (useDaemon) {
+ cleanupIdleWorkers()
+
// Request shutdown of existing daemon by sending SIGTERM
if (daemon != null) {
daemon.destroy()
@@ -220,23 +263,43 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
}
def stopWorker(worker: Socket) {
- if (useDaemon) {
- if (daemon != null) {
- daemonWorkers.get(worker).foreach { pid =>
- // tell daemon to kill worker by pid
- val output = new DataOutputStream(daemon.getOutputStream)
- output.writeInt(pid)
- output.flush()
- daemon.getOutputStream.flush()
+ synchronized {
+ if (useDaemon) {
+ if (daemon != null) {
+ daemonWorkers.get(worker).foreach { pid =>
+ // tell daemon to kill worker by pid
+ val output = new DataOutputStream(daemon.getOutputStream)
+ output.writeInt(pid)
+ output.flush()
+ daemon.getOutputStream.flush()
+ }
}
+ } else {
+ simpleWorkers.get(worker).foreach(_.destroy())
}
- } else {
- simpleWorkers.get(worker).foreach(_.destroy())
}
worker.close()
}
+
+ def releaseWorker(worker: Socket) {
+ if (useDaemon) {
+ synchronized {
+ lastActivity = System.currentTimeMillis()
+ idleWorkers.enqueue(worker)
+ }
+ } else {
+ // Cleanup the worker socket. This will also cause the Python worker to exit.
+ try {
+ worker.close()
+ } catch {
+ case e: Exception =>
+ logWarning("Failed to close worker socket", e)
+ }
+ }
+ }
}
private object PythonWorkerFactory {
val PROCESS_WAIT_TIMEOUT_MS = 10000
+ val IDLE_WORKER_TIMEOUT_MS = 60000 // kill idle workers after 1 minute
}
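Taken together, the PythonWorkerFactory changes amount to a small idle-worker pool: released sockets are parked in a queue, create() prefers an idle worker over forking a new one, and a daemon monitor thread drains the pool after IDLE_WORKER_TIMEOUT_MS of inactivity. A condensed, self-contained sketch of that scheme, assuming the caller supplies the fork step (names are illustrative; the real factory also talks to the Python daemon):
```
import java.net.Socket
import scala.collection.mutable

// Condensed sketch of the idle-worker pool added to PythonWorkerFactory.
class IdleWorkerPool(forkWorker: () => Socket, idleTimeoutMs: Long = 60000L) {
  private val idleWorkers = mutable.Queue[Socket]()
  private var lastActivity = 0L

  // Daemon thread that reaps the idle pool after a quiet period.
  private val monitor = new Thread("idle-worker-monitor") {
    setDaemon(true)
    override def run(): Unit = {
      while (true) {
        IdleWorkerPool.this.synchronized {
          if (lastActivity + idleTimeoutMs < System.currentTimeMillis()) {
            // Close every idle socket; the worker exits when its socket closes.
            while (idleWorkers.nonEmpty) {
              try idleWorkers.dequeue().close()
              catch { case e: Exception => Console.err.println(s"Failed to close worker socket: $e") }
            }
            lastActivity = System.currentTimeMillis()
          }
        }
        Thread.sleep(10000)  // check roughly every ten seconds, as the real monitor does
      }
    }
  }
  monitor.start()

  // Prefer an already-running worker; fork only when the pool is empty.
  def acquire(): Socket = synchronized {
    if (idleWorkers.nonEmpty) idleWorkers.dequeue() else forkWorker()
  }

  // Park a cleanly finished worker for the next task and record the activity.
  def release(worker: Socket): Unit = synchronized {
    lastActivity = System.currentTimeMillis()
    idleWorkers.enqueue(worker)
  }
}
```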