aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/scala/org/apache/spark/SparkContext.scala16
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/RDD.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala21
-rw-r--r--core/src/test/scala/org/apache/spark/FailureSuite.scala17
-rw-r--r--core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala94
-rw-r--r--core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala68
-rw-r--r--graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala2
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala8
8 files changed, 218 insertions, 14 deletions
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 76305237b0..545807ffbc 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1002,7 +1002,9 @@ class SparkContext(config: SparkConf) extends Logging {
require(p >= 0 && p < rdd.partitions.size, s"Invalid partition requested: $p")
}
val callSite = getCallSite
- val cleanedFunc = clean(func)
+ // There's no need to check this function for serializability,
+ // since it will be run right away.
+ val cleanedFunc = clean(func, false)
logInfo("Starting job: " + callSite)
val start = System.nanoTime
dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
@@ -1135,14 +1137,18 @@ class SparkContext(config: SparkConf) extends Logging {
def cancelAllJobs() {
dagScheduler.cancelAllJobs()
}
-
+
/**
* Clean a closure to make it ready to serialized and send to tasks
* (removes unreferenced variables in $outer's, updates REPL variables)
+ *
+ * @param f closure to be cleaned and optionally serialized
+ * @param captureNow whether or not to serialize this closure and capture any free
+ * variables immediately; defaults to true. If this is set and f is not serializable,
+ * it will raise an exception.
*/
- private[spark] def clean[F <: AnyRef](f: F): F = {
- ClosureCleaner.clean(f)
- f
+ private[spark] def clean[F <: AnyRef : ClassTag](f: F, captureNow: Boolean = true): F = {
+ ClosureCleaner.clean(f, captureNow)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 3437b2cac1..e363ea777d 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -660,14 +660,16 @@ abstract class RDD[T: ClassTag](
* Applies a function f to all elements of this RDD.
*/
def foreach(f: T => Unit) {
- sc.runJob(this, (iter: Iterator[T]) => iter.foreach(f))
+ val cleanF = sc.clean(f)
+ sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF))
}
/**
* Applies a function f to each partition of this RDD.
*/
def foreachPartition(f: Iterator[T] => Unit) {
- sc.runJob(this, (iter: Iterator[T]) => f(iter))
+ val cleanF = sc.clean(f)
+ sc.runJob(this, (iter: Iterator[T]) => cleanF(iter))
}
/**
diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
index cdbbc65292..e474b1a850 100644
--- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -22,10 +22,14 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import scala.collection.mutable.Map
import scala.collection.mutable.Set
+import scala.reflect.ClassTag
+
import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._
import org.apache.spark.Logging
+import org.apache.spark.SparkEnv
+import org.apache.spark.SparkException
private[spark] object ClosureCleaner extends Logging {
// Get an ASM class reader for a given class from the JAR that loaded it
@@ -101,7 +105,7 @@ private[spark] object ClosureCleaner extends Logging {
}
}
- def clean(func: AnyRef) {
+ def clean[F <: AnyRef : ClassTag](func: F, captureNow: Boolean = true): F = {
// TODO: cache outerClasses / innerClasses / accessedFields
val outerClasses = getOuterClasses(func)
val innerClasses = getInnerClasses(func)
@@ -150,6 +154,21 @@ private[spark] object ClosureCleaner extends Logging {
field.setAccessible(true)
field.set(func, outer)
}
+
+ if (captureNow) {
+ cloneViaSerializing(func)
+ } else {
+ func
+ }
+ }
+
+ private def cloneViaSerializing[T: ClassTag](func: T): T = {
+ try {
+ val serializer = SparkEnv.get.closureSerializer.newInstance()
+ serializer.deserialize[T](serializer.serialize[T](func))
+ } catch {
+ case ex: Exception => throw new SparkException("Task not serializable: " + ex.toString)
+ }
}
private def instantiateClass(cls: Class[_], outer: AnyRef, inInterpreter: Boolean): AnyRef = {
diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala
index 12dbebcb28..4f9300419e 100644
--- a/core/src/test/scala/org/apache/spark/FailureSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala
@@ -107,7 +107,7 @@ class FailureSuite extends FunSuite with LocalSparkContext {
FailureSuiteState.clear()
}
- test("failure because task closure is not serializable") {
+ test("failure because closure in final-stage task is not serializable") {
sc = new SparkContext("local[1,1]", "test")
val a = new NonSerializable
@@ -118,6 +118,13 @@ class FailureSuite extends FunSuite with LocalSparkContext {
assert(thrown.getClass === classOf[SparkException])
assert(thrown.getMessage.contains("NotSerializableException"))
+ FailureSuiteState.clear()
+ }
+
+ test("failure because closure in early-stage task is not serializable") {
+ sc = new SparkContext("local[1,1]", "test")
+ val a = new NonSerializable
+
// Non-serializable closure in an earlier stage
val thrown1 = intercept[SparkException] {
sc.parallelize(1 to 10, 2).map(x => (x, a)).partitionBy(new HashPartitioner(3)).count()
@@ -125,6 +132,13 @@ class FailureSuite extends FunSuite with LocalSparkContext {
assert(thrown1.getClass === classOf[SparkException])
assert(thrown1.getMessage.contains("NotSerializableException"))
+ FailureSuiteState.clear()
+ }
+
+ test("failure because closure in foreach task is not serializable") {
+ sc = new SparkContext("local[1,1]", "test")
+ val a = new NonSerializable
+
// Non-serializable closure in foreach function
val thrown2 = intercept[SparkException] {
sc.parallelize(1 to 10, 2).foreach(x => println(a))
@@ -135,5 +149,6 @@ class FailureSuite extends FunSuite with LocalSparkContext {
FailureSuiteState.clear()
}
+
// TODO: Need to add tests with shuffle fetch failures.
}
diff --git a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
new file mode 100644
index 0000000000..76662264e7
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.serializer;
+
+import java.io.NotSerializableException
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkException
+import org.apache.spark.SharedSparkContext
+
+/* A trivial (but unserializable) container for trivial functions */
+class UnserializableClass {
+ def op[T](x: T) = x.toString
+
+ def pred[T](x: T) = x.toString.length % 2 == 0
+}
+
+class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContext {
+
+ def fixture = (sc.parallelize(0 until 1000).map(_.toString), new UnserializableClass)
+
+ test("throws expected serialization exceptions on actions") {
+ val (data, uc) = fixture
+
+ val ex = intercept[SparkException] {
+ data.map(uc.op(_)).count
+ }
+
+ assert(ex.getMessage.matches(".*Task not serializable.*"))
+ }
+
+ // There is probably a cleaner way to eliminate boilerplate here, but we're
+ // iterating over a map from transformation names to functions that perform that
+ // transformation on a given RDD, creating one test case for each
+
+ for (transformation <-
+ Map("map" -> map _, "flatMap" -> flatMap _, "filter" -> filter _, "mapWith" -> mapWith _,
+ "mapPartitions" -> mapPartitions _, "mapPartitionsWithIndex" -> mapPartitionsWithIndex _,
+ "mapPartitionsWithContext" -> mapPartitionsWithContext _, "filterWith" -> filterWith _)) {
+ val (name, xf) = transformation
+
+ test(s"$name transformations throw proactive serialization exceptions") {
+ val (data, uc) = fixture
+
+ val ex = intercept[SparkException] {
+ xf(data, uc)
+ }
+
+ assert(ex.getMessage.matches(".*Task not serializable.*"), s"RDD.$name doesn't proactively throw NotSerializableException")
+ }
+ }
+
+ def map(x: RDD[String], uc: UnserializableClass): RDD[String] =
+ x.map(y => uc.op(y))
+
+ def mapWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
+ x.mapWith(x => x.toString)((x,y) => x + uc.op(y))
+
+ def flatMap(x: RDD[String], uc: UnserializableClass): RDD[String] =
+ x.flatMap(y=>Seq(uc.op(y)))
+
+ def filter(x: RDD[String], uc: UnserializableClass): RDD[String] =
+ x.filter(y=>uc.pred(y))
+
+ def filterWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
+ x.filterWith(x => x.toString)((x,y) => uc.pred(y))
+
+ def mapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] =
+ x.mapPartitions(_.map(y => uc.op(y)))
+
+ def mapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] =
+ x.mapPartitionsWithIndex((_, it) => it.map(y => uc.op(y)))
+
+ def mapPartitionsWithContext(x: RDD[String], uc: UnserializableClass): RDD[String] =
+ x.mapPartitionsWithContext((_, it) => it.map(y => uc.op(y)))
+
+}
diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
index 439e5644e2..c635da6cac 100644
--- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
@@ -50,6 +50,27 @@ class ClosureCleanerSuite extends FunSuite {
val obj = new TestClassWithNesting(1)
assert(obj.run() === 96) // 4 * (1+2+3+4) + 4 * (1+2+3+4) + 16 * 1
}
+
+ test("capturing free variables in closures at RDD definition") {
+ val obj = new TestCaptureVarClass()
+ val (ones, onesPlusZeroes) = obj.run()
+
+ assert(ones === onesPlusZeroes)
+ }
+
+ test("capturing free variable fields in closures at RDD definition") {
+ val obj = new TestCaptureFieldClass()
+ val (ones, onesPlusZeroes) = obj.run()
+
+ assert(ones === onesPlusZeroes)
+ }
+
+ test("capturing arrays in closures at RDD definition") {
+ val obj = new TestCaptureArrayEltClass()
+ val (observed, expected) = obj.run()
+
+ assert(observed === expected)
+ }
}
// A non-serializable class we create in closures to make sure that we aren't
@@ -143,3 +164,50 @@ class TestClassWithNesting(val y: Int) extends Serializable {
}
}
}
+
+class TestCaptureFieldClass extends Serializable {
+ class ZeroBox extends Serializable {
+ var zero = 0
+ }
+
+ def run(): (Int, Int) = {
+ val zb = new ZeroBox
+
+ withSpark(new SparkContext("local", "test")) {sc =>
+ val ones = sc.parallelize(Array(1, 1, 1, 1, 1))
+ val onesPlusZeroes = ones.map(_ + zb.zero)
+
+ zb.zero = 5
+
+ (ones.reduce(_ + _), onesPlusZeroes.reduce(_ + _))
+ }
+ }
+}
+
+class TestCaptureArrayEltClass extends Serializable {
+ def run(): (Int, Int) = {
+ withSpark(new SparkContext("local", "test")) {sc =>
+ val rdd = sc.parallelize(1 to 10)
+ val data = Array(1, 2, 3)
+ val expected = data(0)
+ val mapped = rdd.map(x => data(0))
+ data(0) = 4
+ (mapped.first, expected)
+ }
+ }
+}
+
+class TestCaptureVarClass extends Serializable {
+ def run(): (Int, Int) = {
+ var zero = 0
+
+ withSpark(new SparkContext("local", "test")) {sc =>
+ val ones = sc.parallelize(Array(1, 1, 1, 1, 1))
+ val onesPlusZeroes = ones.map(_ + zero)
+
+ zero = 5
+
+ (ones.reduce(_ + _), onesPlusZeroes.reduce(_ + _))
+ }
+ }
+}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
index 28d34dd9a1..c65e36636f 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
@@ -62,7 +62,7 @@ class GraphSuite extends FunSuite with LocalSparkContext {
assert( graph.edges.count() === rawEdges.size )
// Vertices not explicitly provided but referenced by edges should be created automatically
assert( graph.vertices.count() === 100)
- graph.triplets.map { et =>
+ graph.triplets.collect.map { et =>
assert((et.srcId < 10 && et.srcAttr) || (et.srcId >= 10 && !et.srcAttr))
assert((et.dstId < 10 && et.dstAttr) || (et.dstId >= 10 && !et.dstAttr))
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
index d043200f71..4759b629a9 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
@@ -539,7 +539,7 @@ abstract class DStream[T: ClassTag] (
* on each RDD of 'this' DStream.
*/
def transform[U: ClassTag](transformFunc: RDD[T] => RDD[U]): DStream[U] = {
- transform((r: RDD[T], t: Time) => context.sparkContext.clean(transformFunc(r)))
+ transform((r: RDD[T], t: Time) => context.sparkContext.clean(transformFunc(r), false))
}
/**
@@ -547,7 +547,7 @@ abstract class DStream[T: ClassTag] (
* on each RDD of 'this' DStream.
*/
def transform[U: ClassTag](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = {
- val cleanedF = context.sparkContext.clean(transformFunc)
+ val cleanedF = context.sparkContext.clean(transformFunc, false)
val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => {
assert(rdds.length == 1)
cleanedF(rdds.head.asInstanceOf[RDD[T]], time)
@@ -562,7 +562,7 @@ abstract class DStream[T: ClassTag] (
def transformWith[U: ClassTag, V: ClassTag](
other: DStream[U], transformFunc: (RDD[T], RDD[U]) => RDD[V]
): DStream[V] = {
- val cleanedF = ssc.sparkContext.clean(transformFunc)
+ val cleanedF = ssc.sparkContext.clean(transformFunc, false)
transformWith(other, (rdd1: RDD[T], rdd2: RDD[U], time: Time) => cleanedF(rdd1, rdd2))
}
@@ -573,7 +573,7 @@ abstract class DStream[T: ClassTag] (
def transformWith[U: ClassTag, V: ClassTag](
other: DStream[U], transformFunc: (RDD[T], RDD[U], Time) => RDD[V]
): DStream[V] = {
- val cleanedF = ssc.sparkContext.clean(transformFunc)
+ val cleanedF = ssc.sparkContext.clean(transformFunc, false)
val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => {
assert(rdds.length == 2)
val rdd1 = rdds(0).asInstanceOf[RDD[T]]