about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- core/src/main/scala/org/apache/spark/rdd/RDD.scala 30
-rw-r--r-- core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala 21
2 files changed, 41 insertions, 10 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index b3b60578c9..8baf199f21 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -717,7 +717,8 @@ abstract class RDD[T: ClassTag](
def mapPartitionsWithContext[U: ClassTag](
f: (TaskContext, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean = false): RDD[U] = withScope {
- val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(context, iter)
+ val cleanF = sc.clean(f)
+ val func = (context: TaskContext, index: Int, iter: Iterator[T]) => cleanF(context, iter)
new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
}
@@ -741,9 +742,11 @@ abstract class RDD[T: ClassTag](
def mapWith[A, U: ClassTag]
(constructA: Int => A, preservesPartitioning: Boolean = false)
(f: (T, A) => U): RDD[U] = withScope {
+ val cleanF = sc.clean(f)
+ val cleanA = sc.clean(constructA)
mapPartitionsWithIndex((index, iter) => {
- val a = constructA(index)
- iter.map(t => f(t, a))
+ val a = cleanA(index)
+ iter.map(t => cleanF(t, a))
}, preservesPartitioning)
}
@@ -756,9 +759,11 @@ abstract class RDD[T: ClassTag](
def flatMapWith[A, U: ClassTag]
(constructA: Int => A, preservesPartitioning: Boolean = false)
(f: (T, A) => Seq[U]): RDD[U] = withScope {
+ val cleanF = sc.clean(f)
+ val cleanA = sc.clean(constructA)
mapPartitionsWithIndex((index, iter) => {
- val a = constructA(index)
- iter.flatMap(t => f(t, a))
+ val a = cleanA(index)
+ iter.flatMap(t => cleanF(t, a))
}, preservesPartitioning)
}
@@ -769,9 +774,11 @@ abstract class RDD[T: ClassTag](
*/
@deprecated("use mapPartitionsWithIndex and foreach", "1.0.0")
def foreachWith[A](constructA: Int => A)(f: (T, A) => Unit): Unit = withScope {
+ val cleanF = sc.clean(f)
+ val cleanA = sc.clean(constructA)
mapPartitionsWithIndex { (index, iter) =>
- val a = constructA(index)
- iter.map(t => {f(t, a); t})
+ val a = cleanA(index)
+ iter.map(t => {cleanF(t, a); t})
}
}
@@ -782,9 +789,11 @@ abstract class RDD[T: ClassTag](
*/
@deprecated("use mapPartitionsWithIndex and filter", "1.0.0")
def filterWith[A](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = withScope {
+ val cleanP = sc.clean(p)
+ val cleanA = sc.clean(constructA)
mapPartitionsWithIndex((index, iter) => {
- val a = constructA(index)
- iter.filter(t => p(t, a))
+ val a = cleanA(index)
+ iter.filter(t => cleanP(t, a))
}, preservesPartitioning = true)
}
@@ -901,7 +910,8 @@ abstract class RDD[T: ClassTag](
* Return an RDD that contains all matching values by applying `f`.
*/
def collect[U: ClassTag](f: PartialFunction[T, U]): RDD[U] = withScope {
- filter(f.isDefinedAt).map(f)
+ val cleanF = sc.clean(f)
+ filter(cleanF.isDefinedAt).map(cleanF)
}
/**
diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
index 446c3f24a7..e41f6ee277 100644
--- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.util
import java.io.NotSerializableException
+import java.util.Random
import org.scalatest.FunSuite
@@ -92,6 +93,11 @@ class ClosureCleanerSuite extends FunSuite {
expectCorrectException { TestUserClosuresActuallyCleaned.testKeyBy(rdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitions(rdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitionsWithIndex(rdd) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitionsWithContext(rdd) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testFlatMapWith(rdd) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testFilterWith(rdd) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testForEachWith(rdd) }
+ expectCorrectException { TestUserClosuresActuallyCleaned.testMapWith(rdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions2(rdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions3(rdd) }
expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions4(rdd) }
@@ -260,6 +266,21 @@ private object TestUserClosuresActuallyCleaned {
def testMapPartitionsWithIndex(rdd: RDD[Int]): Unit = {
rdd.mapPartitionsWithIndex { (_, it) => return; it }.count()
}
+ def testFlatMapWith(rdd: RDD[Int]): Unit = {
+ rdd.flatMapWith ((index: Int) => new Random(index + 42)){ (_, it) => return; Seq() }.count()
+ }
+ def testMapWith(rdd: RDD[Int]): Unit = {
+ rdd.mapWith ((index: Int) => new Random(index + 42)){ (_, it) => return; 0 }.count()
+ }
+ def testFilterWith(rdd: RDD[Int]): Unit = {
+ rdd.filterWith ((index: Int) => new Random(index + 42)){ (_, it) => return; true }.count()
+ }
+ def testForEachWith(rdd: RDD[Int]): Unit = {
+ rdd.foreachWith ((index: Int) => new Random(index + 42)){ (_, it) => return }
+ }
+ def testMapPartitionsWithContext(rdd: RDD[Int]): Unit = {
+ rdd.mapPartitionsWithContext { (_, it) => return; it }.count()
+ }
def testZipPartitions2(rdd: RDD[Int]): Unit = {
rdd.zipPartitions(rdd) { case (it1, it2) => return; it1 }.count()
}