From 54e6fa0563ffa8788ec2fd1b8740445ef3c2ce5a Mon Sep 17 00:00:00 2001
From: tedyu
Date: Fri, 8 May 2015 17:16:38 -0700
Subject: [SPARK-7237] Clean function in several RDD methods

Author: tedyu

Closes #5959 from ted-yu/master and squashes the following commits:

f83d445 [tedyu] Move cleaning outside of mapPartitionsWithIndex
56d7c92 [tedyu] Consolidate import of Random
f6014c0 [tedyu] Remove cleaning in RDD#filterWith
36feb6c [tedyu] Try to get correct syntax
55d01eb [tedyu] Try to get correct syntax
c2786df [tedyu] Correct syntax
d92bfcf [tedyu] Correct syntax in test
164d3e4 [tedyu] Correct variable name
8b50d93 [tedyu] Address Andrew's review comments
0c8d47e [tedyu] Add test for mapWith()
6846e40 [tedyu] Add test for flatMapWith()
6c124a9 [tedyu] Clean function in several RDD methods
---
 core/src/main/scala/org/apache/spark/rdd/RDD.scala | 30 ++++++++++++++--------
 .../apache/spark/util/ClosureCleanerSuite.scala    | 21 +++++++++++++++
 2 files changed, 41 insertions(+), 10 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index b3b60578c9..8baf199f21 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -717,7 +717,8 @@ abstract class RDD[T: ClassTag](
   def mapPartitionsWithContext[U: ClassTag](
       f: (TaskContext, Iterator[T]) => Iterator[U],
       preservesPartitioning: Boolean = false): RDD[U] = withScope {
-    val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(context, iter)
+    val cleanF = sc.clean(f)
+    val func = (context: TaskContext, index: Int, iter: Iterator[T]) => cleanF(context, iter)
     new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
   }

@@ -741,9 +742,11 @@ abstract class RDD[T: ClassTag](
   def mapWith[A, U: ClassTag]
       (constructA: Int => A, preservesPartitioning: Boolean = false)
       (f: (T, A) => U): RDD[U] = withScope {
+    val cleanF = sc.clean(f)
+    val cleanA = sc.clean(constructA)
     mapPartitionsWithIndex((index, iter) => {
-      val a = constructA(index)
-      iter.map(t => f(t, a))
+      val a = cleanA(index)
+      iter.map(t => cleanF(t, a))
     }, preservesPartitioning)
   }

@@ -756,9 +759,11 @@ abstract class RDD[T: ClassTag](
   def flatMapWith[A, U: ClassTag]
       (constructA: Int => A, preservesPartitioning: Boolean = false)
       (f: (T, A) => Seq[U]): RDD[U] = withScope {
+    val cleanF = sc.clean(f)
+    val cleanA = sc.clean(constructA)
     mapPartitionsWithIndex((index, iter) => {
-      val a = constructA(index)
-      iter.flatMap(t => f(t, a))
+      val a = cleanA(index)
+      iter.flatMap(t => cleanF(t, a))
     }, preservesPartitioning)
   }

@@ -769,9 +774,11 @@ abstract class RDD[T: ClassTag](
    */
   @deprecated("use mapPartitionsWithIndex and foreach", "1.0.0")
   def foreachWith[A](constructA: Int => A)(f: (T, A) => Unit): Unit = withScope {
+    val cleanF = sc.clean(f)
+    val cleanA = sc.clean(constructA)
     mapPartitionsWithIndex { (index, iter) =>
-      val a = constructA(index)
-      iter.map(t => {f(t, a); t})
+      val a = cleanA(index)
+      iter.map(t => {cleanF(t, a); t})
     }
   }

@@ -782,9 +789,11 @@ abstract class RDD[T: ClassTag](
    */
   @deprecated("use mapPartitionsWithIndex and filter", "1.0.0")
   def filterWith[A](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = withScope {
+    val cleanP = sc.clean(p)
+    val cleanA = sc.clean(constructA)
     mapPartitionsWithIndex((index, iter) => {
-      val a = constructA(index)
-      iter.filter(t => p(t, a))
+      val a = cleanA(index)
+      iter.filter(t => cleanP(t, a))
     }, preservesPartitioning = true)
   }

@@ -901,7 +910,8 @@ abstract class RDD[T: ClassTag](
    * Return an RDD that contains all matching values by applying `f`.
    */
   def collect[U: ClassTag](f: PartialFunction[T, U]): RDD[U] = withScope {
-    filter(f.isDefinedAt).map(f)
+    val cleanF = sc.clean(f)
+    filter(cleanF.isDefinedAt).map(cleanF)
   }

   /**
diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
index 446c3f24a7..e41f6ee277 100644
--- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.util

 import java.io.NotSerializableException
+import java.util.Random

 import org.scalatest.FunSuite

@@ -92,6 +93,11 @@ class ClosureCleanerSuite extends FunSuite {
     expectCorrectException { TestUserClosuresActuallyCleaned.testKeyBy(rdd) }
     expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitions(rdd) }
     expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitionsWithIndex(rdd) }
+    expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitionsWithContext(rdd) }
+    expectCorrectException { TestUserClosuresActuallyCleaned.testFlatMapWith(rdd) }
+    expectCorrectException { TestUserClosuresActuallyCleaned.testFilterWith(rdd) }
+    expectCorrectException { TestUserClosuresActuallyCleaned.testForEachWith(rdd) }
+    expectCorrectException { TestUserClosuresActuallyCleaned.testMapWith(rdd) }
     expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions2(rdd) }
     expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions3(rdd) }
     expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions4(rdd) }
@@ -260,6 +266,21 @@ private object TestUserClosuresActuallyCleaned {
   def testMapPartitionsWithIndex(rdd: RDD[Int]): Unit = {
     rdd.mapPartitionsWithIndex { (_, it) => return; it }.count()
   }
+  def testFlatMapWith(rdd: RDD[Int]): Unit = {
+    rdd.flatMapWith ((index: Int) => new Random(index + 42)){ (_, it) => return; Seq() }.count()
+  }
+  def testMapWith(rdd: RDD[Int]): Unit = {
+    rdd.mapWith ((index: Int) => new Random(index + 42)){ (_, it) => return; 0 }.count()
+  }
+  def testFilterWith(rdd: RDD[Int]): Unit = {
+    rdd.filterWith ((index: Int) => new Random(index + 42)){ (_, it) => return; true }.count()
+  }
+  def testForEachWith(rdd: RDD[Int]): Unit = {
+    rdd.foreachWith ((index: Int) => new Random(index + 42)){ (_, it) => return }
+  }
+  def testMapPartitionsWithContext(rdd: RDD[Int]): Unit = {
+    rdd.mapPartitionsWithContext { (_, it) => return; it }.count()
+  }
   def testZipPartitions2(rdd: RDD[Int]): Unit = {
     rdd.zipPartitions(rdd) { case (it1, it2) => return; it1 }.count()
   }
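
For readers unfamiliar with the helper this patch leans on: `sc.clean` hands a
closure to `ClosureCleaner`, which nulls out enclosing references the closure
does not actually use, optionally checks that the result is serializable, and
returns the closure so it can be used in place of the original. A minimal
sketch of that wrapper as it looks on the Spark 1.x line (an assumption from
memory, not shown in this diff; the exact signature may differ by version):

    // Sketch of SparkContext#clean (assumed, not part of this patch):
    // delegate to ClosureCleaner, then hand back the same closure so call
    // sites can write `val cleanF = sc.clean(f)`.
    private[spark] def clean[F <: AnyRef](f: F, checkSerializable: Boolean = true): F = {
      ClosureCleaner.clean(f, checkSerializable)
      f
    }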
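
The tests added to TestUserClosuresActuallyCleaned all place a `return`
statement inside the user closure. The closure cleaner refuses to clean such
closures, so if mapWith, flatMapWith, filterWith, foreachWith, and
mapPartitionsWithContext really route their arguments through `sc.clean`, each
call fails with the cleaner's dedicated exception rather than a serialization
error. The diff does not show the suite's `expectCorrectException` helper; a
sketch of how such a helper could be written inside the suite (the failure
message is illustrative, and `fail` comes from FunSuite):

    // Passes iff the closure cleaner saw, and rejected, the user closure.
    // ReturnStatementInClosureException extends SparkException, so it must
    // be matched before the generic cases below.
    private def expectCorrectException(body: => Unit): Unit = {
      try {
        body
      } catch {
        case _: ReturnStatementInClosureException => // expected: closure was cleaned
        case e @ (_: NotSerializableException | _: SparkException) =>
          fail(s"Expected ReturnStatementInClosureException, got $e; " +
            "the user closure was probably not cleaned.")
      }
    }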
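
Since every method touched here except `collect` is deprecated, the
@deprecated messages in the hunks above are worth making concrete. A small
usage sketch of `filterWith` next to its `mapPartitionsWithIndex` equivalent,
mirroring how filterWith is implemented in this patch (the data and seed are
made up, and an in-scope SparkContext named `sc` is assumed):

    import java.util.Random

    val rdd = sc.parallelize(1 to 100, 4)

    // Deprecated API: one Random per partition, seeded by the partition
    // index; keep each element with probability ~0.5.
    val sampledOld = rdd.filterWith((index: Int) => new Random(index + 42)) {
      (x, rng) => rng.nextBoolean()
    }

    // Equivalent per the deprecation note "use mapPartitionsWithIndex and
    // filter"; preservesPartitioning = true matches filterWith's body above.
    val sampledNew = rdd.mapPartitionsWithIndex({ (index, iter) =>
      val rng = new Random(index + 42)
      iter.filter(_ => rng.nextBoolean())
    }, preservesPartitioning = true)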