aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2014-11-02 00:03:51 -0700
committerMatei Zaharia <matei@databricks.com>2014-11-02 00:03:51 -0700
commit6181577e9935f46b646ba3925b873d031aa3d6ba (patch)
tree84a704bd54be30393f71e744351e2e74903a311c
parent23f966f47523f85ba440b4080eee665271f53b5e (diff)
downloadspark-6181577e9935f46b646ba3925b873d031aa3d6ba.tar.gz
spark-6181577e9935f46b646ba3925b873d031aa3d6ba.tar.bz2
spark-6181577e9935f46b646ba3925b873d031aa3d6ba.zip
[SPARK-3466] Limit size of results that a driver collects for each action
Right now, operations like collect() and take() can crash the driver with an OOM if they bring back too many data. This PR will introduce spark.driver.maxResultSize, after setting it, the driver will abort a job if its result is bigger than it. By default, it's 1g (for backward compatibility for most the cases). In local mode, the driver and executor share the same JVM, the default setting can not protect JVM from OOM. cc mateiz Author: Davies Liu <davies@databricks.com> Closes #3003 from davies/collect and squashes the following commits: 248ed5e [Davies Liu] fix compile 272522e [Davies Liu] address comments 2c35773 [Davies Liu] add sizes in message of abort() 5d62303 [Davies Liu] address comments bc3c077 [Davies Liu] Merge branch 'master' of github.com:apache/spark into collect 11f97c5 [Davies Liu] address comments 47b144f [Davies Liu] check the size of result before send and fetch 3d81af2 [Davies Liu] address comments ca8267d [Davies Liu] limit the size of data by collect
-rw-r--r--core/src/main/scala/org/apache/spark/executor/Executor.scala25
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala20
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala33
-rw-r--r--core/src/main/scala/org/apache/spark/util/Utils.scala5
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala25
-rw-r--r--docs/configuration.md12
8 files changed, 104 insertions, 22 deletions
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index c78e0ffca2..e24a15f015 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -104,6 +104,9 @@ private[spark] class Executor(
// to send the result back.
private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)
+ // Limit of bytes for total size of results (default is 1GB)
+ private val maxResultSize = Utils.getMaxResultSize(conf)
+
// Start worker thread pool
val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker")
@@ -210,25 +213,27 @@ private[spark] class Executor(
val resultSize = serializedDirectResult.limit
// directSend = sending directly back to the driver
- val (serializedResult, directSend) = {
- if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
+ val serializedResult = {
+ if (resultSize > maxResultSize) {
+ logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
+ s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
+ s"dropping it.")
+ ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
+ } else if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
val blockId = TaskResultBlockId(taskId)
env.blockManager.putBytes(
blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
- (ser.serialize(new IndirectTaskResult[Any](blockId)), false)
+ logInfo(
+ s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
+ ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
} else {
- (serializedDirectResult, true)
+ logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
+ serializedDirectResult
}
}
execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
- if (directSend) {
- logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
- } else {
- logInfo(
- s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
- }
} catch {
case ffe: FetchFailedException => {
val reason = ffe.toTaskEndReason
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
index 11c19eeb6e..1f114a0207 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
@@ -31,8 +31,8 @@ import org.apache.spark.util.Utils
private[spark] sealed trait TaskResult[T]
/** A reference to a DirectTaskResult that has been stored in the worker's BlockManager. */
-private[spark]
-case class IndirectTaskResult[T](blockId: BlockId) extends TaskResult[T] with Serializable
+private[spark] case class IndirectTaskResult[T](blockId: BlockId, size: Int)
+ extends TaskResult[T] with Serializable
/** A TaskResult that contains the task's return value and accumulator updates. */
private[spark]
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index 4b5be68ec5..819b51e12a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -47,9 +47,18 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
getTaskResultExecutor.execute(new Runnable {
override def run(): Unit = Utils.logUncaughtExceptions {
try {
- val result = serializer.get().deserialize[TaskResult[_]](serializedData) match {
- case directResult: DirectTaskResult[_] => directResult
- case IndirectTaskResult(blockId) =>
+ val (result, size) = serializer.get().deserialize[TaskResult[_]](serializedData) match {
+ case directResult: DirectTaskResult[_] =>
+ if (!taskSetManager.canFetchMoreResults(serializedData.limit())) {
+ return
+ }
+ (directResult, serializedData.limit())
+ case IndirectTaskResult(blockId, size) =>
+ if (!taskSetManager.canFetchMoreResults(size)) {
+ // dropped by executor if size is larger than maxResultSize
+ sparkEnv.blockManager.master.removeBlock(blockId)
+ return
+ }
logDebug("Fetching indirect task result for TID %s".format(tid))
scheduler.handleTaskGettingResult(taskSetManager, tid)
val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId)
@@ -64,9 +73,10 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]](
serializedTaskResult.get)
sparkEnv.blockManager.master.removeBlock(blockId)
- deserializedResult
+ (deserializedResult, size)
}
- result.metrics.resultSize = serializedData.limit()
+
+ result.metrics.resultSize = size
scheduler.handleSuccessfulTask(taskSetManager, tid, result)
} catch {
case cnf: ClassNotFoundException =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 376821f89c..a976734007 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -23,13 +23,12 @@ import java.util.Arrays
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
-import scala.math.max
-import scala.math.min
+import scala.math.{min, max}
import org.apache.spark._
-import org.apache.spark.TaskState.TaskState
import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.util.{Clock, SystemClock}
+import org.apache.spark.TaskState.TaskState
+import org.apache.spark.util.{Clock, SystemClock, Utils}
/**
* Schedules the tasks within a single TaskSet in the TaskSchedulerImpl. This class keeps track of
@@ -68,6 +67,9 @@ private[spark] class TaskSetManager(
val SPECULATION_QUANTILE = conf.getDouble("spark.speculation.quantile", 0.75)
val SPECULATION_MULTIPLIER = conf.getDouble("spark.speculation.multiplier", 1.5)
+ // Limit of bytes for total size of results (default is 1GB)
+ val maxResultSize = Utils.getMaxResultSize(conf)
+
// Serializer for closures and tasks.
val env = SparkEnv.get
val ser = env.closureSerializer.newInstance()
@@ -89,6 +91,8 @@ private[spark] class TaskSetManager(
var stageId = taskSet.stageId
var name = "TaskSet_" + taskSet.stageId.toString
var parent: Pool = null
+ var totalResultSize = 0L
+ var calculatedTasks = 0
val runningTasksSet = new HashSet[Long]
override def runningTasks = runningTasksSet.size
@@ -515,6 +519,9 @@ private[spark] class TaskSetManager(
index
}
+ /**
+ * Marks the task as getting result and notifies the DAG Scheduler
+ */
def handleTaskGettingResult(tid: Long) = {
val info = taskInfos(tid)
info.markGettingResult()
@@ -522,6 +529,24 @@ private[spark] class TaskSetManager(
}
/**
+ * Check whether has enough quota to fetch the result with `size` bytes
+ */
+ def canFetchMoreResults(size: Long): Boolean = synchronized {
+ totalResultSize += size
+ calculatedTasks += 1
+ if (maxResultSize > 0 && totalResultSize > maxResultSize) {
+ val msg = s"Total size of serialized results of ${calculatedTasks} tasks " +
+ s"(${Utils.bytesToString(totalResultSize)}) is bigger than maxResultSize " +
+ s"(${Utils.bytesToString(maxResultSize)})"
+ logError(msg)
+ abort(msg)
+ false
+ } else {
+ true
+ }
+ }
+
+ /**
* Marks the task as successful and notifies the DAGScheduler that a task has ended.
*/
def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = {
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 68d378f3a2..4e30d0d381 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1720,6 +1720,11 @@ private[spark] object Utils extends Logging {
method.invoke(obj, values.toSeq: _*)
}
+ // Limit of bytes for total size of results (default is 1GB)
+ def getMaxResultSize(conf: SparkConf): Long = {
+ memoryStringToMb(conf.get("spark.driver.maxResultSize", "1g")).toLong << 20
+ }
+
/**
* Return the current system LD_LIBRARY_PATH name
*/
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
index c4e7a4bb7d..5768a3a733 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
@@ -40,7 +40,7 @@ class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedule
// Only remove the result once, since we'd like to test the case where the task eventually
// succeeds.
serializer.get().deserialize[TaskResult[_]](serializedData) match {
- case IndirectTaskResult(blockId) =>
+ case IndirectTaskResult(blockId, size) =>
sparkEnv.blockManager.master.removeBlock(blockId)
case directResult: DirectTaskResult[_] =>
taskSetManager.abort("Internal error: expect only indirect results")
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index c0b07649eb..1809b5396d 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -563,6 +563,31 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
assert(manager.emittedTaskSizeWarning)
}
+ test("abort the job if total size of results is too large") {
+ val conf = new SparkConf().set("spark.driver.maxResultSize", "2m")
+ sc = new SparkContext("local", "test", conf)
+
+ def genBytes(size: Int) = { (x: Int) =>
+ val bytes = Array.ofDim[Byte](size)
+ scala.util.Random.nextBytes(bytes)
+ bytes
+ }
+
+ // multiple 1k result
+ val r = sc.makeRDD(0 until 10, 10).map(genBytes(1024)).collect()
+ assert(10 === r.size )
+
+ // single 10M result
+ val thrown = intercept[SparkException] {sc.makeRDD(genBytes(10 << 20)(0), 1).collect()}
+ assert(thrown.getMessage().contains("bigger than maxResultSize"))
+
+ // multiple 1M results
+ val thrown2 = intercept[SparkException] {
+ sc.makeRDD(0 until 10, 10).map(genBytes(1 << 20)).collect()
+ }
+ assert(thrown2.getMessage().contains("bigger than maxResultSize"))
+ }
+
test("speculative and noPref task should be scheduled after node-local") {
sc = new SparkContext("local", "test")
val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2"), ("execC", "host3"))
diff --git a/docs/configuration.md b/docs/configuration.md
index 3007706a25..099972ca1a 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -112,6 +112,18 @@ of the most common options to set are:
</td>
</tr>
<tr>
+ <td><code>spark.driver.maxResultSize</code></td>
+ <td>1g</td>
+ <td>
+ Limit of total size of serialized results of all partitions for each Spark action (e.g. collect).
+ Should be at least 1M, or 0 for unlimited. Jobs will be aborted if the total size
+ is above this limit.
+ Having a high limit may cause out-of-memory errors in driver (depends on spark.driver.memory
+ and memory overhead of objects in JVM). Setting a proper limit can protect the driver from
+ out-of-memory errors.
+ </td>
+</tr>
+<tr>
<td><code>spark.serializer</code></td>
<td>org.apache.spark.serializer.<br />JavaSerializer</td>
<td>