author     Andrew Or <andrew@databricks.com>    2015-05-05 09:37:04 -0700
committer  Andrew Or <andrew@databricks.com>    2015-05-05 09:37:04 -0700
commit     1fdabf8dcdb31391fec3952d312eb0ac59ece43b (patch)
tree       d098993869ad6b053228bec7d1907c474b1fadd0 /core/src
parent     d4cb38aeb7412a353c6cbca2a9b8f9729afbaba7 (diff)
[SPARK-7237] Many user provided closures are not actually cleaned
Note: ~140 lines are tests.

In a nutshell, we never cleaned closures the user provided through the following operations:
- sortBy
- keyBy
- mapPartitions
- mapPartitionsWithIndex
- aggregateByKey
- foldByKey
- foreachAsync
- one of the aliases for runJob
- runApproximateJob

For more details on a reproduction and why they were not cleaned, please see [SPARK-7237](https://issues.apache.org/jira/browse/SPARK-7237).

Author: Andrew Or <andrew@databricks.com>

Closes #5787 from andrewor14/clean-more and squashes the following commits:

2f1f476 [Andrew Or] Merge branch 'master' of github.com:apache/spark into clean-more
7265865 [Andrew Or] Merge branch 'master' of github.com:apache/spark into clean-more
df3caa3 [Andrew Or] Address comments
7a3cc80 [Andrew Or] Merge branch 'master' of github.com:apache/spark into clean-more
6498f44 [Andrew Or] Add missing test for groupBy
e83699e [Andrew Or] Clean one more
8ac3074 [Andrew Or] Prevent NPE in tests when CC is used outside of an app
9ac5f9b [Andrew Or] Clean closures that are not currently cleaned
19e33b4 [Andrew Or] Add tests for all public RDD APIs that take in closures
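
The user-visible effect can be sketched with a minimal, hypothetical driver program (names like CleanedClosureDemo are made up and not part of the patch). Before this commit, a closure passed to keyBy bypassed ClosureCleaner entirely; after it, the closure is cleaned at the call site, so for example a top-level return statement, which the cleaner rejects, now fails fast instead of slipping through:

    import org.apache.spark.{SparkConf, SparkContext}

    object CleanedClosureDemo {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("demo"))
        try {
          val rdd = sc.parallelize(1 to 10)
          // After this patch, keyBy runs the key function through ClosureCleaner,
          // so the top-level `return` below is detected eagerly and keyBy throws a
          // SparkException ("Return statements aren't allowed in Spark closures")
          // right here. Before the patch, the key function was never cleaned.
          rdd.keyBy { _ => return; 1 }.count()
        } finally {
          sc.stop()
        }
      }
    }

The same return-statement trick is exactly what the new ClosureCleanerSuite tests below use to verify that each public API actually cleans its closure.
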
Diffstat (limited to 'core/src')
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala             |   6
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala     |   7
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDD.scala                  |  20
-rw-r--r--  core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala      |   9
-rw-r--r--  core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala | 148
5 files changed, 174 insertions, 16 deletions
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 7ebee99912..00eb432912 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1676,7 +1676,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
partitions: Seq[Int],
allowLocal: Boolean
): Array[U] = {
- runJob(rdd, (context: TaskContext, iter: Iterator[T]) => func(iter), partitions, allowLocal)
+ val cleanedFunc = clean(func)
+ runJob(rdd, (ctx: TaskContext, it: Iterator[T]) => cleanedFunc(it), partitions, allowLocal)
}
/**
@@ -1730,7 +1731,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
val callSite = getCallSite
logInfo("Starting job: " + callSite.shortForm)
val start = System.nanoTime
- val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout,
+ val cleanedFunc = clean(func)
+ val result = dagScheduler.runApproximateJob(rdd, cleanedFunc, evaluator, callSite, timeout,
localProperties.get)
logInfo(
"Job finished: " + callSite.shortForm + ", took " + (System.nanoTime - start) / 1e9 + " s")
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 93d338fe05..a6d5d2c94e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -131,7 +131,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
lazy val cachedSerializer = SparkEnv.get.serializer.newInstance()
val createZero = () => cachedSerializer.deserialize[U](ByteBuffer.wrap(zeroArray))
- combineByKey[U]((v: V) => seqOp(createZero(), v), seqOp, combOp, partitioner)
+ // We will clean the combiner closure later in `combineByKey`
+ val cleanedSeqOp = self.context.clean(seqOp)
+ combineByKey[U]((v: V) => cleanedSeqOp(createZero(), v), cleanedSeqOp, combOp, partitioner)
}
/**
@@ -179,7 +181,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
lazy val cachedSerializer = SparkEnv.get.serializer.newInstance()
val createZero = () => cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray))
- combineByKey[V]((v: V) => func(createZero(), v), func, func, partitioner)
+ val cleanedFunc = self.context.clean(func)
+ combineByKey[V]((v: V) => cleanedFunc(createZero(), v), cleanedFunc, cleanedFunc, partitioner)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 7f7c7ed144..b3b60578c9 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -678,9 +678,13 @@ abstract class RDD[T: ClassTag](
* should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
*/
def mapPartitions[U: ClassTag](
- f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = withScope {
- val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(iter)
- new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
+ f: Iterator[T] => Iterator[U],
+ preservesPartitioning: Boolean = false): RDD[U] = withScope {
+ val cleanedF = sc.clean(f)
+ new MapPartitionsRDD(
+ this,
+ (context: TaskContext, index: Int, iter: Iterator[T]) => cleanedF(iter),
+ preservesPartitioning)
}
/**
@@ -693,8 +697,11 @@ abstract class RDD[T: ClassTag](
def mapPartitionsWithIndex[U: ClassTag](
f: (Int, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean = false): RDD[U] = withScope {
- val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter)
- new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
+ val cleanedF = sc.clean(f)
+ new MapPartitionsRDD(
+ this,
+ (context: TaskContext, index: Int, iter: Iterator[T]) => cleanedF(index, iter),
+ preservesPartitioning)
}
/**
@@ -1406,7 +1413,8 @@ abstract class RDD[T: ClassTag](
* Creates tuples of the elements in this RDD by applying `f`.
*/
def keyBy[K](f: T => K): RDD[(K, T)] = withScope {
- map(x => (f(x), x))
+ val cleanedF = sc.clean(f)
+ map(x => (cleanedF(x), x))
}
/** A private method for tests, to look at the contents of each partition */
diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
index 4ac0382d80..19fe6cb9de 100644
--- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -312,7 +312,9 @@ private[spark] object ClosureCleaner extends Logging {
private def ensureSerializable(func: AnyRef) {
try {
- SparkEnv.get.closureSerializer.newInstance().serialize(func)
+ if (SparkEnv.get != null) {
+ SparkEnv.get.closureSerializer.newInstance().serialize(func)
+ }
} catch {
case ex: Exception => throw new SparkException("Task not serializable", ex)
}
@@ -347,6 +349,9 @@ private[spark] object ClosureCleaner extends Logging {
}
}
+private[spark] class ReturnStatementInClosureException
+ extends SparkException("Return statements aren't allowed in Spark closures")
+
private class ReturnStatementFinder extends ClassVisitor(ASM4) {
override def visitMethod(access: Int, name: String, desc: String,
sig: String, exceptions: Array[String]): MethodVisitor = {
@@ -354,7 +359,7 @@ private class ReturnStatementFinder extends ClassVisitor(ASM4) {
new MethodVisitor(ASM4) {
override def visitTypeInsn(op: Int, tp: String) {
if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl")) {
- throw new SparkException("Return statements aren't allowed in Spark closures")
+ throw new ReturnStatementInClosureException
}
}
}
diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
index ff1bfe0774..446c3f24a7 100644
--- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
@@ -17,10 +17,14 @@
package org.apache.spark.util
+import java.io.NotSerializableException
+
import org.scalatest.FunSuite
import org.apache.spark.LocalSparkContext._
-import org.apache.spark.{SparkContext, SparkException}
+import org.apache.spark.{TaskContext, SparkContext, SparkException}
+import org.apache.spark.partial.CountEvaluator
+import org.apache.spark.rdd.RDD
class ClosureCleanerSuite extends FunSuite {
test("closures inside an object") {
@@ -52,17 +56,66 @@ class ClosureCleanerSuite extends FunSuite {
}
test("toplevel return statements in closures are identified at cleaning time") {
- val ex = intercept[SparkException] {
+ intercept[ReturnStatementInClosureException] {
TestObjectWithBogusReturns.run()
}
-
- assert(ex.getMessage.contains("Return statements aren't allowed in Spark closures"))
}
test("return statements from named functions nested in closures don't raise exceptions") {
val result = TestObjectWithNestedReturns.run()
assert(result === 1)
}
+
+ test("user provided closures are actually cleaned") {
+
+ // We use return statements as an indication that a closure is actually being cleaned
+ // We expect closure cleaner to find the return statements in the user provided closures
+ def expectCorrectException(body: => Unit): Unit = {
+ try {
+ body
+ } catch {
+ case rse: ReturnStatementInClosureException => // Success!
+ case e @ (_: NotSerializableException | _: SparkException) =>
+ fail(s"Expected ReturnStatementInClosureException, but got $e.\n" +
+ "This means the closure provided by user is not actually cleaned.")
+ }
+ }
+
+ withSpark(new SparkContext("local", "test")) { sc =>
+ val rdd = sc.parallelize(1 to 10)
+ val pairRdd = rdd.map { i => (i, i) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testMap(rdd) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testFlatMap(rdd) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testFilter(rdd) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testSortBy(rdd) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testGroupBy(rdd) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testKeyBy(rdd) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitions(rdd) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitionsWithIndex(rdd) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions2(rdd) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions3(rdd) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions4(rdd) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testForeach(rdd) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testForeachPartition(rdd) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testReduce(rdd) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testTreeReduce(rdd) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testFold(rdd) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testAggregate(rdd) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testTreeAggregate(rdd) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testCombineByKey(pairRdd) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testAggregateByKey(pairRdd) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testFoldByKey(pairRdd) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testReduceByKey(pairRdd) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testMapValues(pairRdd) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testFlatMapValues(pairRdd) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testForeachAsync(rdd) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testForeachPartitionAsync(rdd) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testRunJob1(sc) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testRunJob2(sc) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testRunApproximateJob(sc) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testSubmitJob(sc) }
+ }
+ }
}
// A non-serializable class we create in closures to make sure that we aren't
@@ -187,3 +240,90 @@ class TestClassWithNesting(val y: Int) extends Serializable {
}
}
}
+
+/**
+ * Test whether closures passed in through public APIs are actually cleaned.
+ *
+ * We put a return statement in each of these closures as a mechanism to detect whether the
+ * ClosureCleaner actually cleaned our closure. If it did, then it would throw an appropriate
+ * exception explicitly complaining about the return statement. Otherwise, we know the
+ * ClosureCleaner did not actually clean our closure, in which case we should fail the test.
+ */
+private object TestUserClosuresActuallyCleaned {
+ def testMap(rdd: RDD[Int]): Unit = { rdd.map { _ => return; 0 }.count() }
+ def testFlatMap(rdd: RDD[Int]): Unit = { rdd.flatMap { _ => return; Seq() }.count() }
+ def testFilter(rdd: RDD[Int]): Unit = { rdd.filter { _ => return; true }.count() }
+ def testSortBy(rdd: RDD[Int]): Unit = { rdd.sortBy { _ => return; 1 }.count() }
+ def testKeyBy(rdd: RDD[Int]): Unit = { rdd.keyBy { _ => return; 1 }.count() }
+ def testGroupBy(rdd: RDD[Int]): Unit = { rdd.groupBy { _ => return; 1 }.count() }
+ def testMapPartitions(rdd: RDD[Int]): Unit = { rdd.mapPartitions { it => return; it }.count() }
+ def testMapPartitionsWithIndex(rdd: RDD[Int]): Unit = {
+ rdd.mapPartitionsWithIndex { (_, it) => return; it }.count()
+ }
+ def testZipPartitions2(rdd: RDD[Int]): Unit = {
+ rdd.zipPartitions(rdd) { case (it1, it2) => return; it1 }.count()
+ }
+ def testZipPartitions3(rdd: RDD[Int]): Unit = {
+ rdd.zipPartitions(rdd, rdd) { case (it1, it2, it3) => return; it1 }.count()
+ }
+ def testZipPartitions4(rdd: RDD[Int]): Unit = {
+ rdd.zipPartitions(rdd, rdd, rdd) { case (it1, it2, it3, it4) => return; it1 }.count()
+ }
+ def testForeach(rdd: RDD[Int]): Unit = { rdd.foreach { _ => return } }
+ def testForeachPartition(rdd: RDD[Int]): Unit = { rdd.foreachPartition { _ => return } }
+ def testReduce(rdd: RDD[Int]): Unit = { rdd.reduce { case (_, _) => return; 1 } }
+ def testTreeReduce(rdd: RDD[Int]): Unit = { rdd.treeReduce { case (_, _) => return; 1 } }
+ def testFold(rdd: RDD[Int]): Unit = { rdd.fold(0) { case (_, _) => return; 1 } }
+ def testAggregate(rdd: RDD[Int]): Unit = {
+ rdd.aggregate(0)({ case (_, _) => return; 1 }, { case (_, _) => return; 1 })
+ }
+ def testTreeAggregate(rdd: RDD[Int]): Unit = {
+ rdd.treeAggregate(0)({ case (_, _) => return; 1 }, { case (_, _) => return; 1 })
+ }
+
+ // Test pair RDD functions
+ def testCombineByKey(rdd: RDD[(Int, Int)]): Unit = {
+ rdd.combineByKey(
+ { _ => return; 1 }: Int => Int,
+ { case (_, _) => return; 1 }: (Int, Int) => Int,
+ { case (_, _) => return; 1 }: (Int, Int) => Int
+ ).count()
+ }
+ def testAggregateByKey(rdd: RDD[(Int, Int)]): Unit = {
+ rdd.aggregateByKey(0)({ case (_, _) => return; 1 }, { case (_, _) => return; 1 }).count()
+ }
+ def testFoldByKey(rdd: RDD[(Int, Int)]): Unit = { rdd.foldByKey(0) { case (_, _) => return; 1 } }
+ def testReduceByKey(rdd: RDD[(Int, Int)]): Unit = { rdd.reduceByKey { case (_, _) => return; 1 } }
+ def testMapValues(rdd: RDD[(Int, Int)]): Unit = { rdd.mapValues { _ => return; 1 } }
+ def testFlatMapValues(rdd: RDD[(Int, Int)]): Unit = { rdd.flatMapValues { _ => return; Seq() } }
+
+ // Test async RDD actions
+ def testForeachAsync(rdd: RDD[Int]): Unit = { rdd.foreachAsync { _ => return } }
+ def testForeachPartitionAsync(rdd: RDD[Int]): Unit = { rdd.foreachPartitionAsync { _ => return } }
+
+ // Test SparkContext runJob
+ def testRunJob1(sc: SparkContext): Unit = {
+ val rdd = sc.parallelize(1 to 10, 10)
+ sc.runJob(rdd, { (ctx: TaskContext, iter: Iterator[Int]) => return; 1 } )
+ }
+ def testRunJob2(sc: SparkContext): Unit = {
+ val rdd = sc.parallelize(1 to 10, 10)
+ sc.runJob(rdd, { iter: Iterator[Int] => return; 1 } )
+ }
+ def testRunApproximateJob(sc: SparkContext): Unit = {
+ val rdd = sc.parallelize(1 to 10, 10)
+ val evaluator = new CountEvaluator(1, 0.5)
+ sc.runApproximateJob(
+ rdd, { (ctx: TaskContext, iter: Iterator[Int]) => return; 1L }, evaluator, 1000)
+ }
+ def testSubmitJob(sc: SparkContext): Unit = {
+ val rdd = sc.parallelize(1 to 10, 10)
+ sc.submitJob(
+ rdd,
+ { _ => return; 1 }: Iterator[Int] => Int,
+ Seq.empty,
+ { case (_, _) => return }: (Int, Int) => Unit,
+ { return }
+ )
+ }
+}
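
As background on the detection mechanism used by TestUserClosuresActuallyCleaned: the Scala compiler implements a return statement inside a function literal as a non-local return, i.e. the generated closure code constructs and throws scala.runtime.NonLocalReturnControl, which the enclosing method catches. ReturnStatementFinder scans the closure's bytecode for exactly that `new scala/runtime/NonLocalReturnControl` instruction, which is why a `return` in a test closure is a reliable "was this closure cleaned?" signal. A standalone sketch of the language behavior (hypothetical object name, not part of the patch):

    object NonLocalReturnDemo {
      // The `return` inside the function literal targets enclosing(), not the
      // literal itself. The compiler implements this by throwing
      // scala.runtime.NonLocalReturnControl from the closure body and catching
      // it in enclosing(); that construction in the closure's bytecode is what
      // ReturnStatementFinder looks for at cleaning time.
      def enclosing(): Int = {
        val f = () => { return 42; 0 }
        f()   // unwinds via NonLocalReturnControl; enclosing() returns 42
      }

      def main(args: Array[String]): Unit = {
        println(enclosing())   // prints 42
      }
    }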