aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
authorReynold Xin <reynoldx@gmail.com>2013-07-23 15:28:20 -0700
committerReynold Xin <reynoldx@gmail.com>2013-07-23 15:30:20 -0700
commitf2422d4f29abb80b9bc76c4596d1cc31d9e6d590 (patch)
tree7caec35eeeff5038cc7481d964218edf908e740e /core
parent5364f645c5bc515a74edf65ca064e4265f97243d (diff)
downloadspark-f2422d4f29abb80b9bc76c4596d1cc31d9e6d590.tar.gz
spark-f2422d4f29abb80b9bc76c4596d1cc31d9e6d590.tar.bz2
spark-f2422d4f29abb80b9bc76c4596d1cc31d9e6d590.zip
SPARK-829: scheduler shouldn't hang if a task contains unserializable objects in its closure.
Diffstat (limited to 'core')
-rw-r--r--core/src/main/scala/spark/scheduler/DAGScheduler.scala29
-rw-r--r--core/src/test/scala/spark/FailureSuite.scala29
2 files changed, 42 insertions, 16 deletions
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index 29e879aa42..5fcd807aff 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -17,19 +17,17 @@
package spark.scheduler
-import cluster.TaskInfo
-import java.util.concurrent.atomic.AtomicInteger
-import java.util.concurrent.LinkedBlockingQueue
-import java.util.concurrent.TimeUnit
+import java.io.NotSerializableException
import java.util.Properties
+import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}
+import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
import spark._
import spark.executor.TaskMetrics
-import spark.partial.ApproximateActionListener
-import spark.partial.ApproximateEvaluator
-import spark.partial.PartialResult
+import spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
+import spark.scheduler.cluster.TaskInfo
import spark.storage.{BlockManager, BlockManagerMaster}
import spark.util.{MetadataCleaner, TimeStampedHashMap}
@@ -258,7 +256,8 @@ class DAGScheduler(
assert(partitions.size > 0)
val waiter = new JobWaiter(partitions.size, resultHandler)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
- val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)
+ val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter,
+ properties)
return (toSubmit, waiter)
}
@@ -283,7 +282,7 @@ class DAGScheduler(
"Total number of partitions: " + maxPartitions)
}
- val (toSubmit, waiter) = prepareJob(
+ val (toSubmit: JobSubmitted, waiter: JobWaiter[_]) = prepareJob(
finalRdd, func, partitions, callSite, allowLocal, resultHandler, properties)
eventQueue.put(toSubmit)
waiter.awaitResult() match {
@@ -466,6 +465,18 @@ class DAGScheduler(
/** Submits stage, but first recursively submits any missing parents. */
private def submitStage(stage: Stage) {
logDebug("submitStage(" + stage + ")")
+
+ // Preemptively serialize the stage RDD to make sure the tasks for this stage will be
+ // serializable. We are catching this exception here because it would be fairly hard to
+ // catch the non-serializable exception down the road, where we have several different
+ // implementations for local scheduler and cluster schedulers.
+ try {
+ SparkEnv.get.closureSerializer.newInstance().serialize(stage.rdd)
+ } catch {
+ case e: NotSerializableException => abortStage(stage, e.toString)
+ return
+ }
+
if (!waiting(stage) && !running(stage) && !failed(stage)) {
val missing = getMissingParentStages(stage).sortBy(_.id)
logDebug("missing: " + missing)
diff --git a/core/src/test/scala/spark/FailureSuite.scala b/core/src/test/scala/spark/FailureSuite.scala
index 6c847b8fef..c3c52f9118 100644
--- a/core/src/test/scala/spark/FailureSuite.scala
+++ b/core/src/test/scala/spark/FailureSuite.scala
@@ -18,9 +18,6 @@
package spark
import org.scalatest.FunSuite
-import org.scalatest.prop.Checkers
-
-import scala.collection.mutable.ArrayBuffer
import SparkContext._
@@ -40,7 +37,7 @@ object FailureSuiteState {
}
class FailureSuite extends FunSuite with LocalSparkContext {
-
+
// Run a 3-task map job in which task 1 deterministically fails once, and check
// whether the job completes successfully and we ran 4 tasks in total.
test("failure in a single-stage job") {
@@ -66,7 +63,7 @@ class FailureSuite extends FunSuite with LocalSparkContext {
test("failure in a two-stage job") {
sc = new SparkContext("local[1,1]", "test")
val results = sc.makeRDD(1 to 3).map(x => (x, x)).groupByKey(3).map {
- case (k, v) =>
+ case (k, v) =>
FailureSuiteState.synchronized {
FailureSuiteState.tasksRun += 1
if (k == 1 && FailureSuiteState.tasksFailed == 0) {
@@ -87,15 +84,33 @@ class FailureSuite extends FunSuite with LocalSparkContext {
sc = new SparkContext("local[1,1]", "test")
val results = sc.makeRDD(1 to 3).map(x => new NonSerializable)
- val thrown = intercept[spark.SparkException] {
+ val thrown = intercept[SparkException] {
results.collect()
}
- assert(thrown.getClass === classOf[spark.SparkException])
+ assert(thrown.getClass === classOf[SparkException])
assert(thrown.getMessage.contains("NotSerializableException"))
FailureSuiteState.clear()
}
+ test("failure because task closure is not serializable") {
+ sc = new SparkContext("local[1,1]", "test")
+ val a = new NonSerializable
+ val thrown = intercept[SparkException] {
+ sc.parallelize(1 to 10, 2).map(x => a).count()
+ }
+ assert(thrown.getClass === classOf[SparkException])
+ assert(thrown.getMessage.contains("NotSerializableException"))
+
+ val thrown1 = intercept[SparkException] {
+ sc.parallelize(1 to 10, 2).map(x => (x, a)).partitionBy(new HashPartitioner(3)).count()
+ }
+ assert(thrown1.getClass === classOf[SparkException])
+ assert(thrown1.getMessage.contains("NotSerializableException"))
+
+ FailureSuiteState.clear()
+ }
+
// TODO: Need to add tests with shuffle fetch failures.
}