diff options
author | Davies Liu <davies@databricks.com> | 2014-11-02 00:03:51 -0700 |
---|---|---|
committer | Matei Zaharia <matei@databricks.com> | 2014-11-02 00:03:51 -0700 |
commit | 6181577e9935f46b646ba3925b873d031aa3d6ba (patch) | |
tree | 84a704bd54be30393f71e744351e2e74903a311c | |
parent | 23f966f47523f85ba440b4080eee665271f53b5e (diff) | |
download | spark-6181577e9935f46b646ba3925b873d031aa3d6ba.tar.gz spark-6181577e9935f46b646ba3925b873d031aa3d6ba.tar.bz2 spark-6181577e9935f46b646ba3925b873d031aa3d6ba.zip |
[SPARK-3466] Limit size of results that a driver collects for each action
Right now, operations like collect() and take() can crash the driver with an OOM if they bring back too many data.
This PR will introduce spark.driver.maxResultSize, after setting it, the driver will abort a job if its result is bigger than it.
By default, it's 1g (for backward compatibility for most the cases).
In local mode, the driver and executor share the same JVM, the default setting can not protect JVM from OOM.
cc mateiz
Author: Davies Liu <davies@databricks.com>
Closes #3003 from davies/collect and squashes the following commits:
248ed5e [Davies Liu] fix compile
272522e [Davies Liu] address comments
2c35773 [Davies Liu] add sizes in message of abort()
5d62303 [Davies Liu] address comments
bc3c077 [Davies Liu] Merge branch 'master' of github.com:apache/spark into collect
11f97c5 [Davies Liu] address comments
47b144f [Davies Liu] check the size of result before send and fetch
3d81af2 [Davies Liu] address comments
ca8267d [Davies Liu] limit the size of data by collect
8 files changed, 104 insertions, 22 deletions
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index c78e0ffca2..e24a15f015 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -104,6 +104,9 @@ private[spark] class Executor( // to send the result back. private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) + // Limit of bytes for total size of results (default is 1GB) + private val maxResultSize = Utils.getMaxResultSize(conf) + // Start worker thread pool val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker") @@ -210,25 +213,27 @@ private[spark] class Executor( val resultSize = serializedDirectResult.limit // directSend = sending directly back to the driver - val (serializedResult, directSend) = { - if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { + val serializedResult = { + if (resultSize > maxResultSize) { + logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " + + s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " + + s"dropping it.") + ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize)) + } else if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { val blockId = TaskResultBlockId(taskId) env.blockManager.putBytes( blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER) - (ser.serialize(new IndirectTaskResult[Any](blockId)), false) + logInfo( + s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)") + ser.serialize(new IndirectTaskResult[Any](blockId, resultSize)) } else { - (serializedDirectResult, true) + logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver") + serializedDirectResult } } execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult) - if (directSend) { - logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver") - } else { - logInfo( - s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)") - } } catch { case ffe: FetchFailedException => { val reason = ffe.toTaskEndReason diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala index 11c19eeb6e..1f114a0207 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala @@ -31,8 +31,8 @@ import org.apache.spark.util.Utils private[spark] sealed trait TaskResult[T] /** A reference to a DirectTaskResult that has been stored in the worker's BlockManager. */ -private[spark] -case class IndirectTaskResult[T](blockId: BlockId) extends TaskResult[T] with Serializable +private[spark] case class IndirectTaskResult[T](blockId: BlockId, size: Int) + extends TaskResult[T] with Serializable /** A TaskResult that contains the task's return value and accumulator updates. */ private[spark] diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 4b5be68ec5..819b51e12a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -47,9 +47,18 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul getTaskResultExecutor.execute(new Runnable { override def run(): Unit = Utils.logUncaughtExceptions { try { - val result = serializer.get().deserialize[TaskResult[_]](serializedData) match { - case directResult: DirectTaskResult[_] => directResult - case IndirectTaskResult(blockId) => + val (result, size) = serializer.get().deserialize[TaskResult[_]](serializedData) match { + case directResult: DirectTaskResult[_] => + if (!taskSetManager.canFetchMoreResults(serializedData.limit())) { + return + } + (directResult, serializedData.limit()) + case IndirectTaskResult(blockId, size) => + if (!taskSetManager.canFetchMoreResults(size)) { + // dropped by executor if size is larger than maxResultSize + sparkEnv.blockManager.master.removeBlock(blockId) + return + } logDebug("Fetching indirect task result for TID %s".format(tid)) scheduler.handleTaskGettingResult(taskSetManager, tid) val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId) @@ -64,9 +73,10 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]]( serializedTaskResult.get) sparkEnv.blockManager.master.removeBlock(blockId) - deserializedResult + (deserializedResult, size) } - result.metrics.resultSize = serializedData.limit() + + result.metrics.resultSize = size scheduler.handleSuccessfulTask(taskSetManager, tid, result) } catch { case cnf: ClassNotFoundException => diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 376821f89c..a976734007 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -23,13 +23,12 @@ import java.util.Arrays import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet -import scala.math.max -import scala.math.min +import scala.math.{min, max} import org.apache.spark._ -import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.TaskMetrics -import org.apache.spark.util.{Clock, SystemClock} +import org.apache.spark.TaskState.TaskState +import org.apache.spark.util.{Clock, SystemClock, Utils} /** * Schedules the tasks within a single TaskSet in the TaskSchedulerImpl. This class keeps track of @@ -68,6 +67,9 @@ private[spark] class TaskSetManager( val SPECULATION_QUANTILE = conf.getDouble("spark.speculation.quantile", 0.75) val SPECULATION_MULTIPLIER = conf.getDouble("spark.speculation.multiplier", 1.5) + // Limit of bytes for total size of results (default is 1GB) + val maxResultSize = Utils.getMaxResultSize(conf) + // Serializer for closures and tasks. val env = SparkEnv.get val ser = env.closureSerializer.newInstance() @@ -89,6 +91,8 @@ private[spark] class TaskSetManager( var stageId = taskSet.stageId var name = "TaskSet_" + taskSet.stageId.toString var parent: Pool = null + var totalResultSize = 0L + var calculatedTasks = 0 val runningTasksSet = new HashSet[Long] override def runningTasks = runningTasksSet.size @@ -515,6 +519,9 @@ private[spark] class TaskSetManager( index } + /** + * Marks the task as getting result and notifies the DAG Scheduler + */ def handleTaskGettingResult(tid: Long) = { val info = taskInfos(tid) info.markGettingResult() @@ -522,6 +529,24 @@ private[spark] class TaskSetManager( } /** + * Check whether has enough quota to fetch the result with `size` bytes + */ + def canFetchMoreResults(size: Long): Boolean = synchronized { + totalResultSize += size + calculatedTasks += 1 + if (maxResultSize > 0 && totalResultSize > maxResultSize) { + val msg = s"Total size of serialized results of ${calculatedTasks} tasks " + + s"(${Utils.bytesToString(totalResultSize)}) is bigger than maxResultSize " + + s"(${Utils.bytesToString(maxResultSize)})" + logError(msg) + abort(msg) + false + } else { + true + } + } + + /** * Marks the task as successful and notifies the DAGScheduler that a task has ended. */ def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 68d378f3a2..4e30d0d381 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1720,6 +1720,11 @@ private[spark] object Utils extends Logging { method.invoke(obj, values.toSeq: _*) } + // Limit of bytes for total size of results (default is 1GB) + def getMaxResultSize(conf: SparkConf): Long = { + memoryStringToMb(conf.get("spark.driver.maxResultSize", "1g")).toLong << 20 + } + /** * Return the current system LD_LIBRARY_PATH name */ diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index c4e7a4bb7d..5768a3a733 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -40,7 +40,7 @@ class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedule // Only remove the result once, since we'd like to test the case where the task eventually // succeeds. serializer.get().deserialize[TaskResult[_]](serializedData) match { - case IndirectTaskResult(blockId) => + case IndirectTaskResult(blockId, size) => sparkEnv.blockManager.master.removeBlock(blockId) case directResult: DirectTaskResult[_] => taskSetManager.abort("Internal error: expect only indirect results") diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index c0b07649eb..1809b5396d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -563,6 +563,31 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(manager.emittedTaskSizeWarning) } + test("abort the job if total size of results is too large") { + val conf = new SparkConf().set("spark.driver.maxResultSize", "2m") + sc = new SparkContext("local", "test", conf) + + def genBytes(size: Int) = { (x: Int) => + val bytes = Array.ofDim[Byte](size) + scala.util.Random.nextBytes(bytes) + bytes + } + + // multiple 1k result + val r = sc.makeRDD(0 until 10, 10).map(genBytes(1024)).collect() + assert(10 === r.size ) + + // single 10M result + val thrown = intercept[SparkException] {sc.makeRDD(genBytes(10 << 20)(0), 1).collect()} + assert(thrown.getMessage().contains("bigger than maxResultSize")) + + // multiple 1M results + val thrown2 = intercept[SparkException] { + sc.makeRDD(0 until 10, 10).map(genBytes(1 << 20)).collect() + } + assert(thrown2.getMessage().contains("bigger than maxResultSize")) + } + test("speculative and noPref task should be scheduled after node-local") { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) diff --git a/docs/configuration.md b/docs/configuration.md index 3007706a25..099972ca1a 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -112,6 +112,18 @@ of the most common options to set are: </td> </tr> <tr> + <td><code>spark.driver.maxResultSize</code></td> + <td>1g</td> + <td> + Limit of total size of serialized results of all partitions for each Spark action (e.g. collect). + Should be at least 1M, or 0 for unlimited. Jobs will be aborted if the total size + is above this limit. + Having a high limit may cause out-of-memory errors in driver (depends on spark.driver.memory + and memory overhead of objects in JVM). Setting a proper limit can protect the driver from + out-of-memory errors. + </td> +</tr> +<tr> <td><code>spark.serializer</code></td> <td>org.apache.spark.serializer.<br />JavaSerializer</td> <td> |