author     William Benton <willb@redhat.com>  2014-06-29 23:27:34 -0700
committer  Reynold Xin <rxin@apache.org>      2014-06-29 23:27:34 -0700
commit     a484030dae9d0d7e4b97cc6307e9e928c07490dc (patch)
tree       8077c51ced4a8a14f93a969798706ff39d948dd2
parent     66135a341d9f8baecc149d13ae5511f14578c395 (diff)
SPARK-897: preemptively serialize closures
These commits cause `ClosureCleaner.clean` to attempt to serialize the cleaned closure with the default closure serializer and throw a `SparkException` if doing so fails. This behavior is enabled by default but can be disabled at individual callsites of `SparkContext.clean`.

Commit 98e01ae8 fixes some no-op assertions in `GraphSuite` that this work exposed; I'm happy to put that in a separate PR if that would be more appropriate.

Author: William Benton <willb@redhat.com>

Closes #143 from willb/spark-897 and squashes the following commits:

bceab8a [William Benton] Commented DStream corner cases for serializability checking.
64d04d2 [William Benton] FailureSuite now checks both messages and causes.
3b3f74a [William Benton] Stylistic and doc cleanups.
b215dea [William Benton] Fixed spurious failures in ImplicitOrderingSuite.
be1ecd6 [William Benton] Don't check serializability of DStream transforms.
abe816b [William Benton] Make proactive serializability checking optional.
5bfff24 [William Benton] Adds proactive closure-serializability checking.
ed2ccf0 [William Benton] Test cases for SPARK-897.
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala                                     12
-rw-r--r--  core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala                              16
-rw-r--r--  core/src/test/scala/org/apache/spark/FailureSuite.scala                                      14
-rw-r--r--  core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala                             75
-rw-r--r--  core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala    90
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala                    25
6 files changed, 196 insertions, 36 deletions
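
For readers skimming the diff, here is a minimal, hypothetical driver program (the class and object names are illustrative and not part of this patch) showing what the proactive check changes in practice: with checking enabled (the default), the `SparkException` now surfaces when the closure is cleaned at the transformation call site, rather than later when an action ships tasks to executors.

```scala
import org.apache.spark.{SparkConf, SparkContext, SparkException}

// Hypothetical example; only SparkContext/SparkException come from Spark itself.
class NotSerializableHolder {                 // deliberately does NOT extend Serializable
  def tag(s: String): String = s + "!"
}

object Spark897Demo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("spark-897-demo"))
    val holder = new NotSerializableHolder
    try {
      // With proactive checking, the closure is test-serialized while it is
      // being cleaned, so the failure is raised right here at map()...
      sc.parallelize(1 to 10).map(x => holder.tag(x.toString))
    } catch {
      case e: SparkException =>
        // ...instead of only after an action like count() triggers a job.
        println(s"Caught at definition time: ${e.getMessage}")  // "Task not serializable"
    } finally {
      sc.stop()
    }
  }
}
```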
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index f9476ff826..8819e73d17 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1203,9 +1203,17 @@ class SparkContext(config: SparkConf) extends Logging {
/**
* Clean a closure to make it ready to be serialized and sent to tasks
* (removes unreferenced variables in $outer's, updates REPL variables)
+ * If <tt>checkSerializable</tt> is set, <tt>clean</tt> will also proactively
+ * check to see if <tt>f</tt> is serializable and throw a <tt>SparkException</tt>
+ * if not.
+ *
+ * @param f the closure to clean
+ * @param checkSerializable whether or not to immediately check <tt>f</tt> for serializability
+ * @throws <tt>SparkException</tt> if <tt>checkSerializable</tt> is set but <tt>f</tt> is not
+ * serializable
*/
- private[spark] def clean[F <: AnyRef](f: F): F = {
- ClosureCleaner.clean(f)
+ private[spark] def clean[F <: AnyRef](f: F, checkSerializable: Boolean = true): F = {
+ ClosureCleaner.clean(f, checkSerializable)
f
}
diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
index 4916d9b86c..e3f52f6ff1 100644
--- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -25,7 +25,7 @@ import scala.collection.mutable.Set
import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._
-import org.apache.spark.{Logging, SparkException}
+import org.apache.spark.{Logging, SparkEnv, SparkException}
private[spark] object ClosureCleaner extends Logging {
// Get an ASM class reader for a given class from the JAR that loaded it
@@ -101,7 +101,7 @@ private[spark] object ClosureCleaner extends Logging {
}
}
- def clean(func: AnyRef) {
+ def clean(func: AnyRef, checkSerializable: Boolean = true) {
// TODO: cache outerClasses / innerClasses / accessedFields
val outerClasses = getOuterClasses(func)
val innerClasses = getInnerClasses(func)
@@ -153,6 +153,18 @@ private[spark] object ClosureCleaner extends Logging {
field.setAccessible(true)
field.set(func, outer)
}
+
+ if (checkSerializable) {
+ ensureSerializable(func)
+ }
+ }
+
+ private def ensureSerializable(func: AnyRef) {
+ try {
+ SparkEnv.get.closureSerializer.newInstance().serialize(func)
+ } catch {
+ case ex: Exception => throw new SparkException("Task not serializable", ex)
+ }
}
private def instantiateClass(cls: Class[_], outer: AnyRef, inInterpreter: Boolean): AnyRef = {
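
The ensureSerializable helper above is just a pre-flight round trip through the closure serializer. A standalone sketch of the same idea, using plain Java serialization in place of SparkEnv.get.closureSerializer so it runs outside Spark (object and method names are invented for illustration):

```scala
import java.io.{ByteArrayOutputStream, ObjectOutputStream}

object PreflightSerializationCheck {
  // Attempt to serialize a closure up front; rethrow any failure with the same
  // "Task not serializable" wording used by ClosureCleaner.ensureSerializable.
  def ensureSerializable(closure: AnyRef): Unit =
    try {
      val out = new ObjectOutputStream(new ByteArrayOutputStream())
      out.writeObject(closure)              // throws NotSerializableException on failure
      out.close()
    } catch {
      case ex: Exception =>
        throw new RuntimeException("Task not serializable", ex)
    }

  def main(args: Array[String]): Unit = {
    ensureSerializable((x: Int) => x + 1)   // fine: captures nothing
    val handle = new Thread()               // java.lang.Thread is not serializable
    ensureSerializable(() => handle.getName) // throws, wrapping NotSerializableException
  }
}
```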
diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala
index 12dbebcb28..e755d2e309 100644
--- a/core/src/test/scala/org/apache/spark/FailureSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala
@@ -22,6 +22,8 @@ import org.scalatest.FunSuite
import org.apache.spark.SparkContext._
import org.apache.spark.util.NonSerializable
+import java.io.NotSerializableException
+
// Common state shared by FailureSuite-launched tasks. We use a global object
// for this because any local variables used in the task closures will rightfully
// be copied for each task, so there's no other way for them to share state.
@@ -102,7 +104,8 @@ class FailureSuite extends FunSuite with LocalSparkContext {
results.collect()
}
assert(thrown.getClass === classOf[SparkException])
- assert(thrown.getMessage.contains("NotSerializableException"))
+ assert(thrown.getMessage.contains("NotSerializableException") ||
+ thrown.getCause.getClass === classOf[NotSerializableException])
FailureSuiteState.clear()
}
@@ -116,21 +119,24 @@ class FailureSuite extends FunSuite with LocalSparkContext {
sc.parallelize(1 to 10, 2).map(x => a).count()
}
assert(thrown.getClass === classOf[SparkException])
- assert(thrown.getMessage.contains("NotSerializableException"))
+ assert(thrown.getMessage.contains("NotSerializableException") ||
+ thrown.getCause.getClass === classOf[NotSerializableException])
// Non-serializable closure in an earlier stage
val thrown1 = intercept[SparkException] {
sc.parallelize(1 to 10, 2).map(x => (x, a)).partitionBy(new HashPartitioner(3)).count()
}
assert(thrown1.getClass === classOf[SparkException])
- assert(thrown1.getMessage.contains("NotSerializableException"))
+ assert(thrown1.getMessage.contains("NotSerializableException") ||
+ thrown1.getCause.getClass === classOf[NotSerializableException])
// Non-serializable closure in foreach function
val thrown2 = intercept[SparkException] {
sc.parallelize(1 to 10, 2).foreach(x => println(a))
}
assert(thrown2.getClass === classOf[SparkException])
- assert(thrown2.getMessage.contains("NotSerializableException"))
+ assert(thrown2.getMessage.contains("NotSerializableException") ||
+ thrown2.getCause.getClass === classOf[NotSerializableException])
FailureSuiteState.clear()
}
diff --git a/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala b/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala
index 4bd8891356..8e4a9e2c9f 100644
--- a/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala
@@ -19,9 +19,29 @@ package org.apache.spark
import org.scalatest.FunSuite
+import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
class ImplicitOrderingSuite extends FunSuite with LocalSparkContext {
+ // Tests that PairRDDFunctions grabs an implicit Ordering in various cases where it should.
+ test("basic inference of Orderings"){
+ sc = new SparkContext("local", "test")
+ val rdd = sc.parallelize(1 to 10)
+
+ // These RDD methods are in the companion object so that the unserializable ScalaTest Engine
+ // won't be reachable from the closure object
+
+ // Infer orderings after basic maps to particular types
+ val basicMapExpectations = ImplicitOrderingSuite.basicMapExpectations(rdd)
+ basicMapExpectations.map({case (met, explain) => assert(met, explain)})
+
+ // Infer orderings for other RDD methods
+ val otherRDDMethodExpectations = ImplicitOrderingSuite.otherRDDMethodExpectations(rdd)
+ otherRDDMethodExpectations.map({case (met, explain) => assert(met, explain)})
+ }
+}
+
+private object ImplicitOrderingSuite {
class NonOrderedClass {}
class ComparableClass extends Comparable[ComparableClass] {
@@ -31,27 +51,36 @@ class ImplicitOrderingSuite extends FunSuite with LocalSparkContext {
class OrderedClass extends Ordered[OrderedClass] {
override def compare(o: OrderedClass): Int = ???
}
-
- // Tests that PairRDDFunctions grabs an implicit Ordering in various cases where it should.
- test("basic inference of Orderings"){
- sc = new SparkContext("local", "test")
- val rdd = sc.parallelize(1 to 10)
-
- // Infer orderings after basic maps to particular types
- assert(rdd.map(x => (x, x)).keyOrdering.isDefined)
- assert(rdd.map(x => (1, x)).keyOrdering.isDefined)
- assert(rdd.map(x => (x.toString, x)).keyOrdering.isDefined)
- assert(rdd.map(x => (null, x)).keyOrdering.isDefined)
- assert(rdd.map(x => (new NonOrderedClass, x)).keyOrdering.isEmpty)
- assert(rdd.map(x => (new ComparableClass, x)).keyOrdering.isDefined)
- assert(rdd.map(x => (new OrderedClass, x)).keyOrdering.isDefined)
-
- // Infer orderings for other RDD methods
- assert(rdd.groupBy(x => x).keyOrdering.isDefined)
- assert(rdd.groupBy(x => new NonOrderedClass).keyOrdering.isEmpty)
- assert(rdd.groupBy(x => new ComparableClass).keyOrdering.isDefined)
- assert(rdd.groupBy(x => new OrderedClass).keyOrdering.isDefined)
- assert(rdd.groupBy((x: Int) => x, 5).keyOrdering.isDefined)
- assert(rdd.groupBy((x: Int) => x, new HashPartitioner(5)).keyOrdering.isDefined)
+
+ def basicMapExpectations(rdd: RDD[Int]) = {
+ List((rdd.map(x => (x, x)).keyOrdering.isDefined,
+ "rdd.map(x => (x, x)).keyOrdering.isDefined"),
+ (rdd.map(x => (1, x)).keyOrdering.isDefined,
+ "rdd.map(x => (1, x)).keyOrdering.isDefined"),
+ (rdd.map(x => (x.toString, x)).keyOrdering.isDefined,
+ "rdd.map(x => (x.toString, x)).keyOrdering.isDefined"),
+ (rdd.map(x => (null, x)).keyOrdering.isDefined,
+ "rdd.map(x => (null, x)).keyOrdering.isDefined"),
+ (rdd.map(x => (new NonOrderedClass, x)).keyOrdering.isEmpty,
+ "rdd.map(x => (new NonOrderedClass, x)).keyOrdering.isEmpty"),
+ (rdd.map(x => (new ComparableClass, x)).keyOrdering.isDefined,
+ "rdd.map(x => (new ComparableClass, x)).keyOrdering.isDefined"),
+ (rdd.map(x => (new OrderedClass, x)).keyOrdering.isDefined,
+ "rdd.map(x => (new OrderedClass, x)).keyOrdering.isDefined"))
}
-}
+
+ def otherRDDMethodExpectations(rdd: RDD[Int]) = {
+ List((rdd.groupBy(x => x).keyOrdering.isDefined,
+ "rdd.groupBy(x => x).keyOrdering.isDefined"),
+ (rdd.groupBy(x => new NonOrderedClass).keyOrdering.isEmpty,
+ "rdd.groupBy(x => new NonOrderedClass).keyOrdering.isEmpty"),
+ (rdd.groupBy(x => new ComparableClass).keyOrdering.isDefined,
+ "rdd.groupBy(x => new ComparableClass).keyOrdering.isDefined"),
+ (rdd.groupBy(x => new OrderedClass).keyOrdering.isDefined,
+ "rdd.groupBy(x => new OrderedClass).keyOrdering.isDefined"),
+ (rdd.groupBy((x: Int) => x, 5).keyOrdering.isDefined,
+ "rdd.groupBy((x: Int) => x, 5).keyOrdering.isDefined"),
+ (rdd.groupBy((x: Int) => x, new HashPartitioner(5)).keyOrdering.isDefined,
+ "rdd.groupBy((x: Int) => x, new HashPartitioner(5)).keyOrdering.isDefined"))
+ }
+}
\ No newline at end of file
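
The ImplicitOrderingSuite change above is an instance of the general pattern this patch makes visible: a closure built inside a class can capture the enclosing instance (here, the suite with its unserializable ScalaTest Engine), while one built in a top-level object has nothing to capture. A minimal, Spark-free sketch of that contrast (the names are invented for illustration):

```scala
object CapturePatternSketch {
  class HasUnserializableState {
    private val handle = new Thread()        // java.lang.Thread is not serializable
    // This function reads a field, so it captures the enclosing instance,
    // which is not serializable -- a pre-flight check would reject it.
    def capturing: Int => Int = x => x + handle.getPriority
  }

  // Defined in an object: no enclosing instance, nothing unserializable captured.
  def freeStanding: Int => Int = x => x + 1
}
```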
diff --git a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
new file mode 100644
index 0000000000..5d15a68ac7
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.serializer;
+
+import java.io.NotSerializableException
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkException
+import org.apache.spark.SharedSparkContext
+
+/* A trivial (but unserializable) container for trivial functions */
+class UnserializableClass {
+ def op[T](x: T) = x.toString
+
+ def pred[T](x: T) = x.toString.length % 2 == 0
+}
+
+class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContext {
+
+ def fixture = (sc.parallelize(0 until 1000).map(_.toString), new UnserializableClass)
+
+ test("throws expected serialization exceptions on actions") {
+ val (data, uc) = fixture
+
+ val ex = intercept[SparkException] {
+ data.map(uc.op(_)).count
+ }
+
+ assert(ex.getMessage.contains("Task not serializable"))
+ }
+
+ // There is probably a cleaner way to eliminate boilerplate here, but we're
+ // iterating over a map from transformation names to functions that perform that
+ // transformation on a given RDD, creating one test case for each
+
+ for (transformation <-
+ Map("map" -> xmap _, "flatMap" -> xflatMap _, "filter" -> xfilter _,
+ "mapWith" -> xmapWith _, "mapPartitions" -> xmapPartitions _,
+ "mapPartitionsWithIndex" -> xmapPartitionsWithIndex _,
+ "mapPartitionsWithContext" -> xmapPartitionsWithContext _,
+ "filterWith" -> xfilterWith _)) {
+ val (name, xf) = transformation
+
+ test(s"$name transformations throw proactive serialization exceptions") {
+ val (data, uc) = fixture
+
+ val ex = intercept[SparkException] {
+ xf(data, uc)
+ }
+
+ assert(ex.getMessage.contains("Task not serializable"),
+ s"RDD.$name doesn't proactively throw NotSerializableException")
+ }
+ }
+
+ private def xmap(x: RDD[String], uc: UnserializableClass): RDD[String] =
+ x.map(y=>uc.op(y))
+ private def xmapWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
+ x.mapWith(x => x.toString)((x,y)=>x + uc.op(y))
+ private def xflatMap(x: RDD[String], uc: UnserializableClass): RDD[String] =
+ x.flatMap(y=>Seq(uc.op(y)))
+ private def xfilter(x: RDD[String], uc: UnserializableClass): RDD[String] =
+ x.filter(y=>uc.pred(y))
+ private def xfilterWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
+ x.filterWith(x => x.toString)((x,y)=>uc.pred(y))
+ private def xmapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] =
+ x.mapPartitions(_.map(y=>uc.op(y)))
+ private def xmapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] =
+ x.mapPartitionsWithIndex((_, it) => it.map(y=>uc.op(y)))
+ private def xmapPartitionsWithContext(x: RDD[String], uc: UnserializableClass): RDD[String] =
+ x.mapPartitionsWithContext((_, it) => it.map(y=>uc.op(y)))
+
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
index 4709a62381..e05db236ad 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
@@ -532,7 +532,10 @@ abstract class DStream[T: ClassTag] (
* 'this' DStream will be registered as an output stream and therefore materialized.
*/
def foreachRDD(foreachFunc: (RDD[T], Time) => Unit) {
- new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register()
+ // because the DStream is reachable from the outer object here, and because
+ // DStreams can't be serialized with closures, we can't proactively check
+ // it for serializability and so we pass checkSerializable = false to SparkContext.clean
+ new ForEachDStream(this, context.sparkContext.clean(foreachFunc, false)).register()
}
/**
@@ -540,7 +543,10 @@ abstract class DStream[T: ClassTag] (
* on each RDD of 'this' DStream.
*/
def transform[U: ClassTag](transformFunc: RDD[T] => RDD[U]): DStream[U] = {
- transform((r: RDD[T], t: Time) => context.sparkContext.clean(transformFunc(r)))
+ // because the DStream is reachable from the outer object here, and because
+ // DStreams can't be serialized with closures, we can't proactively check
+ // it for serializability and so we pass checkSerializable = false to SparkContext.clean
+ transform((r: RDD[T], t: Time) => context.sparkContext.clean(transformFunc(r), false))
}
/**
@@ -548,7 +554,10 @@ abstract class DStream[T: ClassTag] (
* on each RDD of 'this' DStream.
*/
def transform[U: ClassTag](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = {
- val cleanedF = context.sparkContext.clean(transformFunc)
+ // because the DStream is reachable from the outer object here, and because
+ // DStreams can't be serialized with closures, we can't proactively check
+ // it for serializability and so we pass checkSerializable = false to SparkContext.clean
+ val cleanedF = context.sparkContext.clean(transformFunc, false)
val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => {
assert(rdds.length == 1)
cleanedF(rdds.head.asInstanceOf[RDD[T]], time)
@@ -563,7 +572,10 @@ abstract class DStream[T: ClassTag] (
def transformWith[U: ClassTag, V: ClassTag](
other: DStream[U], transformFunc: (RDD[T], RDD[U]) => RDD[V]
): DStream[V] = {
- val cleanedF = ssc.sparkContext.clean(transformFunc)
+ // because the DStream is reachable from the outer object here, and because
+ // DStreams can't be serialized with closures, we can't proactively check
+ // it for serializability and so we pass checkSerializable = false to SparkContext.clean
+ val cleanedF = ssc.sparkContext.clean(transformFunc, false)
transformWith(other, (rdd1: RDD[T], rdd2: RDD[U], time: Time) => cleanedF(rdd1, rdd2))
}
@@ -574,7 +586,10 @@ abstract class DStream[T: ClassTag] (
def transformWith[U: ClassTag, V: ClassTag](
other: DStream[U], transformFunc: (RDD[T], RDD[U], Time) => RDD[V]
): DStream[V] = {
- val cleanedF = ssc.sparkContext.clean(transformFunc)
+ // because the DStream is reachable from the outer object here, and because
+ // DStreams can't be serialized with closures, we can't proactively check
+ // it for serializability and so we pass checkSerializable = false to SparkContext.clean
+ val cleanedF = ssc.sparkContext.clean(transformFunc, false)
val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => {
assert(rdds.length == 2)
val rdd1 = rdds(0).asInstanceOf[RDD[T]]
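
The repeated comment in DStream.scala describes the corner case noted in the "Commented DStream corner cases" commit: functions handed to foreachRDD or transform routinely have the DStream itself reachable from them, so test-serializing them with the closure serializer at clean time would reject valid programs. A hypothetical illustration (the stream wiring and names are not from this patch):

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.Time
import org.apache.spark.streaming.dstream.DStream

object DStreamCornerCaseSketch {
  // `counts` is reachable from the closure below (it is referenced directly),
  // so an eager serializability check in clean() would fail here even though
  // the streaming program is valid; the DStream graph is serialized separately.
  def wire(counts: DStream[(String, Long)]): Unit =
    counts.foreachRDD { (rdd: RDD[(String, Long)], time: Time) =>
      println(s"$time (slide ${counts.slideDuration}): ${rdd.count()} keys")
    }
}
```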