author | William Benton <willb@redhat.com> | 2014-06-29 23:27:34 -0700
committer | Reynold Xin <rxin@apache.org> | 2014-06-29 23:27:34 -0700
commit | a484030dae9d0d7e4b97cc6307e9e928c07490dc (patch)
tree | 8077c51ced4a8a14f93a969798706ff39d948dd2 /core
parent | 66135a341d9f8baecc149d13ae5511f14578c395 (diff)
SPARK-897: preemptively serialize closures
These commits cause `ClosureCleaner.clean` to attempt to serialize the cleaned closure with the default closure serializer and throw a `SparkException` if doing so fails. This behavior is enabled by default but can be disabled at individual callsites of `SparkContext.clean`.
Commit 98e01ae8 fixes some no-op assertions in `GraphSuite` that this work exposed; I'm happy to put that in a separate PR if that would be more appropriate.
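To make the new behavior concrete, here is a minimal sketch of a driver program (hypothetical, not part of this patch) that trips the proactive check. Because `RDD.map` calls `SparkContext.clean`, the failure now surfaces when the transformation is defined rather than later, when tasks are serialized and launched:

```scala
// Hypothetical driver program (not from the patch) illustrating the proactive
// check: the cleaned closure is test-serialized eagerly, so the error appears
// at definition time instead of at job-execution time.
import org.apache.spark.{SparkConf, SparkContext, SparkException}

object ProactiveCheckDemo {
  // Not Serializable, so closures capturing an instance cannot be shipped.
  class Unserializable { def inc(x: Int): Int = x + 1 }

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("demo"))
    val u = new Unserializable
    try {
      // Throws SparkException("Task not serializable") here, before any job runs.
      sc.parallelize(1 to 10).map(x => u.inc(x))
    } catch {
      case e: SparkException => println(s"caught at definition time: ${e.getMessage}")
    } finally {
      sc.stop()
    }
  }
}
```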
Author: William Benton <willb@redhat.com>
Closes #143 from willb/spark-897 and squashes the following commits:
bceab8a [William Benton] Commented DStream corner cases for serializability checking.
64d04d2 [William Benton] FailureSuite now checks both messages and causes.
3b3f74a [William Benton] Stylistic and doc cleanups.
b215dea [William Benton] Fixed spurious failures in ImplicitOrderingSuite
be1ecd6 [William Benton] Don't check serializability of DStream transforms.
abe816b [William Benton] Make proactive serializability checking optional.
5bfff24 [William Benton] Adds proactive closure-serializability checking
ed2ccf0 [William Benton] Test cases for SPARK-897.
Diffstat (limited to 'core')
5 files changed, 176 insertions, 31 deletions
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index f9476ff826..8819e73d17 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1203,9 +1203,17 @@ class SparkContext(config: SparkConf) extends Logging {
   /**
    * Clean a closure to make it ready to serialized and send to tasks
    * (removes unreferenced variables in $outer's, updates REPL variables)
+   * If <tt>checkSerializable</tt> is set, <tt>clean</tt> will also proactively
+   * check to see if <tt>f</tt> is serializable and throw a <tt>SparkException</tt>
+   * if not.
+   *
+   * @param f the closure to clean
+   * @param checkSerializable whether or not to immediately check <tt>f</tt> for serializability
+   * @throws <tt>SparkException</tt> if <tt>checkSerializable</tt> is set but <tt>f</tt> is not
+   *   serializable
    */
-  private[spark] def clean[F <: AnyRef](f: F): F = {
-    ClosureCleaner.clean(f)
+  private[spark] def clean[F <: AnyRef](f: F, checkSerializable: Boolean = true): F = {
+    ClosureCleaner.clean(f, checkSerializable)
     f
   }
 
diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
index 4916d9b86c..e3f52f6ff1 100644
--- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -25,7 +25,7 @@ import scala.collection.mutable.Set
 import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
 import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._
 
-import org.apache.spark.{Logging, SparkException}
+import org.apache.spark.{Logging, SparkEnv, SparkException}
 
 private[spark] object ClosureCleaner extends Logging {
   // Get an ASM class reader for a given class from the JAR that loaded it
@@ -101,7 +101,7 @@ private[spark] object ClosureCleaner extends Logging {
     }
   }
 
-  def clean(func: AnyRef) {
+  def clean(func: AnyRef, checkSerializable: Boolean = true) {
     // TODO: cache outerClasses / innerClasses / accessedFields
     val outerClasses = getOuterClasses(func)
     val innerClasses = getInnerClasses(func)
@@ -153,6 +153,18 @@ private[spark] object ClosureCleaner extends Logging {
       field.setAccessible(true)
       field.set(func, outer)
     }
+
+    if (checkSerializable) {
+      ensureSerializable(func)
+    }
+  }
+
+  private def ensureSerializable(func: AnyRef) {
+    try {
+      SparkEnv.get.closureSerializer.newInstance().serialize(func)
+    } catch {
+      case ex: Exception => throw new SparkException("Task not serializable", ex)
+    }
   }
 
   private def instantiateClass(cls: Class[_], outer: AnyRef, inInterpreter: Boolean): AnyRef = {
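The heart of the change is `ensureSerializable`, which round-trips the cleaned closure through the configured closure serializer and rewraps any failure. The same fail-fast pattern can be sketched outside Spark with plain JDK serialization; the helper name below is illustrative only, not from the patch:

```scala
// Standalone sketch of the fail-fast pattern in ensureSerializable, using
// JDK serialization in place of SparkEnv.get.closureSerializer.
import java.io.{ByteArrayOutputStream, ObjectOutputStream}

object SerializabilityCheck {
  def checkClosure(func: AnyRef): Unit = {
    try {
      val oos = new ObjectOutputStream(new ByteArrayOutputStream())
      oos.writeObject(func) // walks the whole captured object graph, as serialize() does
      oos.close()
    } catch {
      // Preserve the underlying NotSerializableException as the cause,
      // mirroring SparkException("Task not serializable", ex) above.
      case ex: Exception => throw new RuntimeException("Task not serializable", ex)
    }
  }
}
```

A plain function literal such as `(x: Int) => x + 1` passes the check, since Scala function objects are `Serializable`; a closure that captures a non-serializable field fails inside `writeObject` and is rethrown with the original exception attached as the cause.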
diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala
index 12dbebcb28..e755d2e309 100644
--- a/core/src/test/scala/org/apache/spark/FailureSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala
@@ -22,6 +22,8 @@ import org.scalatest.FunSuite
 import org.apache.spark.SparkContext._
 import org.apache.spark.util.NonSerializable
 
+import java.io.NotSerializableException
+
 // Common state shared by FailureSuite-launched tasks. We use a global object
 // for this because any local variables used in the task closures will rightfully
 // be copied for each task, so there's no other way for them to share state.
@@ -102,7 +104,8 @@ class FailureSuite extends FunSuite with LocalSparkContext {
       results.collect()
     }
     assert(thrown.getClass === classOf[SparkException])
-    assert(thrown.getMessage.contains("NotSerializableException"))
+    assert(thrown.getMessage.contains("NotSerializableException") ||
+      thrown.getCause.getClass === classOf[NotSerializableException])
 
     FailureSuiteState.clear()
   }
@@ -116,21 +119,24 @@ class FailureSuite extends FunSuite with LocalSparkContext {
       sc.parallelize(1 to 10, 2).map(x => a).count()
     }
     assert(thrown.getClass === classOf[SparkException])
-    assert(thrown.getMessage.contains("NotSerializableException"))
+    assert(thrown.getMessage.contains("NotSerializableException") ||
+      thrown.getCause.getClass === classOf[NotSerializableException])
 
     // Non-serializable closure in an earlier stage
     val thrown1 = intercept[SparkException] {
       sc.parallelize(1 to 10, 2).map(x => (x, a)).partitionBy(new HashPartitioner(3)).count()
     }
     assert(thrown1.getClass === classOf[SparkException])
-    assert(thrown1.getMessage.contains("NotSerializableException"))
+    assert(thrown1.getMessage.contains("NotSerializableException") ||
+      thrown1.getCause.getClass === classOf[NotSerializableException])
 
     // Non-serializable closure in foreach function
     val thrown2 = intercept[SparkException] {
       sc.parallelize(1 to 10, 2).foreach(x => println(a))
     }
     assert(thrown2.getClass === classOf[SparkException])
-    assert(thrown2.getMessage.contains("NotSerializableException"))
+    assert(thrown2.getMessage.contains("NotSerializableException") ||
+      thrown2.getCause.getClass === classOf[NotSerializableException])
 
     FailureSuiteState.clear()
   }
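These assertions loosen because a non-serializable closure can now fail in two ways: proactively, as a `SparkException` carrying the `NotSerializableException` as its cause, or in the legacy path with the exception name embedded in the message. A hypothetical helper (not in the patch) capturing the repeated pattern:

```scala
// Hypothetical helper, not part of the patch: accept either the legacy
// message-embedded form or the new wrapped-cause form of the failure.
import java.io.NotSerializableException
import org.scalatest.Assertions
import org.apache.spark.SparkException

object SerializationFailureAssertions extends Assertions {
  def assertNotSerializableFailure(thrown: SparkException): Unit = {
    assert(thrown.getMessage.contains("NotSerializableException") ||
      thrown.getCause.isInstanceOf[NotSerializableException])
  }
}
```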
- test("basic inference of Orderings"){ - sc = new SparkContext("local", "test") - val rdd = sc.parallelize(1 to 10) - - // Infer orderings after basic maps to particular types - assert(rdd.map(x => (x, x)).keyOrdering.isDefined) - assert(rdd.map(x => (1, x)).keyOrdering.isDefined) - assert(rdd.map(x => (x.toString, x)).keyOrdering.isDefined) - assert(rdd.map(x => (null, x)).keyOrdering.isDefined) - assert(rdd.map(x => (new NonOrderedClass, x)).keyOrdering.isEmpty) - assert(rdd.map(x => (new ComparableClass, x)).keyOrdering.isDefined) - assert(rdd.map(x => (new OrderedClass, x)).keyOrdering.isDefined) - - // Infer orderings for other RDD methods - assert(rdd.groupBy(x => x).keyOrdering.isDefined) - assert(rdd.groupBy(x => new NonOrderedClass).keyOrdering.isEmpty) - assert(rdd.groupBy(x => new ComparableClass).keyOrdering.isDefined) - assert(rdd.groupBy(x => new OrderedClass).keyOrdering.isDefined) - assert(rdd.groupBy((x: Int) => x, 5).keyOrdering.isDefined) - assert(rdd.groupBy((x: Int) => x, new HashPartitioner(5)).keyOrdering.isDefined) + + def basicMapExpectations(rdd: RDD[Int]) = { + List((rdd.map(x => (x, x)).keyOrdering.isDefined, + "rdd.map(x => (x, x)).keyOrdering.isDefined"), + (rdd.map(x => (1, x)).keyOrdering.isDefined, + "rdd.map(x => (1, x)).keyOrdering.isDefined"), + (rdd.map(x => (x.toString, x)).keyOrdering.isDefined, + "rdd.map(x => (x.toString, x)).keyOrdering.isDefined"), + (rdd.map(x => (null, x)).keyOrdering.isDefined, + "rdd.map(x => (null, x)).keyOrdering.isDefined"), + (rdd.map(x => (new NonOrderedClass, x)).keyOrdering.isEmpty, + "rdd.map(x => (new NonOrderedClass, x)).keyOrdering.isEmpty"), + (rdd.map(x => (new ComparableClass, x)).keyOrdering.isDefined, + "rdd.map(x => (new ComparableClass, x)).keyOrdering.isDefined"), + (rdd.map(x => (new OrderedClass, x)).keyOrdering.isDefined, + "rdd.map(x => (new OrderedClass, x)).keyOrdering.isDefined")) } -} + + def otherRDDMethodExpectations(rdd: RDD[Int]) = { + List((rdd.groupBy(x => x).keyOrdering.isDefined, + "rdd.groupBy(x => x).keyOrdering.isDefined"), + (rdd.groupBy(x => new NonOrderedClass).keyOrdering.isEmpty, + "rdd.groupBy(x => new NonOrderedClass).keyOrdering.isEmpty"), + (rdd.groupBy(x => new ComparableClass).keyOrdering.isDefined, + "rdd.groupBy(x => new ComparableClass).keyOrdering.isDefined"), + (rdd.groupBy(x => new OrderedClass).keyOrdering.isDefined, + "rdd.groupBy(x => new OrderedClass).keyOrdering.isDefined"), + (rdd.groupBy((x: Int) => x, 5).keyOrdering.isDefined, + "rdd.groupBy((x: Int) => x, 5).keyOrdering.isDefined"), + (rdd.groupBy((x: Int) => x, new HashPartitioner(5)).keyOrdering.isDefined, + "rdd.groupBy((x: Int) => x, new HashPartitioner(5)).keyOrdering.isDefined")) + } +}
diff --git a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
new file mode 100644
index 0000000000..5d15a68ac7
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.serializer;
+
+import java.io.NotSerializableException
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkException
+import org.apache.spark.SharedSparkContext
+
+/* A trivial (but unserializable) container for trivial functions */
+class UnserializableClass {
+  def op[T](x: T) = x.toString
+
+  def pred[T](x: T) = x.toString.length % 2 == 0
+}
+
+class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContext {
+
+  def fixture = (sc.parallelize(0 until 1000).map(_.toString), new UnserializableClass)
+
+  test("throws expected serialization exceptions on actions") {
+    val (data, uc) = fixture
+
+    val ex = intercept[SparkException] {
+      data.map(uc.op(_)).count
+    }
+
+    assert(ex.getMessage.contains("Task not serializable"))
+  }
+
+  // There is probably a cleaner way to eliminate boilerplate here, but we're
+  // iterating over a map from transformation names to functions that perform that
+  // transformation on a given RDD, creating one test case for each
+
+  for (transformation <-
+      Map("map" -> xmap _, "flatMap" -> xflatMap _, "filter" -> xfilter _,
+          "mapWith" -> xmapWith _, "mapPartitions" -> xmapPartitions _,
+          "mapPartitionsWithIndex" -> xmapPartitionsWithIndex _,
+          "mapPartitionsWithContext" -> xmapPartitionsWithContext _,
+          "filterWith" -> xfilterWith _)) {
+    val (name, xf) = transformation
+
+    test(s"$name transformations throw proactive serialization exceptions") {
+      val (data, uc) = fixture
+
+      val ex = intercept[SparkException] {
+        xf(data, uc)
+      }
+
+      assert(ex.getMessage.contains("Task not serializable"),
+        s"RDD.$name doesn't proactively throw NotSerializableException")
+    }
+  }
+
+  private def xmap(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.map(y=>uc.op(y))
+  private def xmapWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.mapWith(x => x.toString)((x,y)=>x + uc.op(y))
+  private def xflatMap(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.flatMap(y=>Seq(uc.op(y)))
+  private def xfilter(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.filter(y=>uc.pred(y))
+  private def xfilterWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.filterWith(x => x.toString)((x,y)=>uc.pred(y))
+  private def xmapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.mapPartitions(_.map(y=>uc.op(y)))
+  private def xmapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.mapPartitionsWithIndex((_, it) => it.map(y=>uc.op(y)))
+  private def xmapPartitionsWithContext(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.mapPartitionsWithContext((_, it) => it.map(y=>uc.op(y)))
+
+}
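For reference, the usual way user code avoids the exception these tests expect is to keep unserializable objects out of the captured scope, for example by hoisting what the closure needs into a local value on the driver. A hedged sketch of hypothetical user code, not from the patch, assuming the `UnserializableClass` defined in the suite above is in scope:

```scala
// Hypothetical user-side fix (not part of the patch): capture only what the
// closure needs. The driver precomputes a plain String, so the closure no
// longer references the UnserializableClass instance at all.
import org.apache.spark.rdd.RDD

object ClosureFixes {
  def withLocalCopy(data: RDD[String], uc: UnserializableClass): RDD[String] = {
    val tag = uc.op("prefix") // evaluated on the driver; a serializable String
    data.map(y => tag + y)    // only `tag` is captured, so clean() succeeds
  }
}
```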