diff options
author | Matei Zaharia <matei@eecs.berkeley.edu> | 2013-03-23 07:13:21 -0700 |
---|---|---|
committer | Matei Zaharia <matei@eecs.berkeley.edu> | 2013-03-23 07:13:21 -0700 |
commit | fd53f2fc7be5faa5d7b1aa58226f7f74f44247d2 (patch) | |
tree | fdd00a72dd39ffaffdbbad818d20927fc2776422 | |
parent | 4c5efcf600f75d3020d4e83a27c072e4e77607ac (diff) | |
parent | ab33e27cc9769157c62d940124b5091301bdc69a (diff) | |
download | spark-fd53f2fc7be5faa5d7b1aa58226f7f74f44247d2.tar.gz spark-fd53f2fc7be5faa5d7b1aa58226f7f74f44247d2.tar.bz2 spark-fd53f2fc7be5faa5d7b1aa58226f7f74f44247d2.zip |
Merge pull request #510 from markhamstra/WithThing
mapWith, flatMapWith and filterWith
-rw-r--r-- | core/src/main/scala/spark/RDD.scala | 66 | ||||
-rw-r--r-- | core/src/test/scala/spark/RDDSuite.scala | 60 |
2 files changed, 125 insertions, 1 deletions
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 9bd8a0f98d..ed39732f13 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -365,6 +365,62 @@ abstract class RDD[T: ClassManifest]( new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning) /** + * Maps f over this RDD, where f takes an additional parameter of type A. This + * additional parameter is produced by constructA, which is called in each + * partition with the index of that partition. + */ + def mapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false) + (f:(T, A) => U): RDD[U] = { + def iterF(index: Int, iter: Iterator[T]): Iterator[U] = { + val a = constructA(index) + iter.map(t => f(t, a)) + } + new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning) + } + + /** + * FlatMaps f over this RDD, where f takes an additional parameter of type A. This + * additional parameter is produced by constructA, which is called in each + * partition with the index of that partition. + */ + def flatMapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false) + (f:(T, A) => Seq[U]): RDD[U] = { + def iterF(index: Int, iter: Iterator[T]): Iterator[U] = { + val a = constructA(index) + iter.flatMap(t => f(t, a)) + } + new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning) + } + + /** + * Applies f to each element of this RDD, where f takes an additional parameter of type A. + * This additional parameter is produced by constructA, which is called in each + * partition with the index of that partition. + */ + def foreachWith[A: ClassManifest](constructA: Int => A) + (f:(T, A) => Unit) { + def iterF(index: Int, iter: Iterator[T]): Iterator[T] = { + val a = constructA(index) + iter.map(t => {f(t, a); t}) + } + (new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)).foreach(_ => {}) + } + + /** + * Filters this RDD with p, where p takes an additional parameter of type A. This + * additional parameter is produced by constructA, which is called in each + * partition with the index of that partition. + */ + def filterWith[A: ClassManifest](constructA: Int => A) + (p:(T, A) => Boolean): RDD[T] = { + def iterF(index: Int, iter: Iterator[T]): Iterator[T] = { + val a = constructA(index) + iter.filter(t => p(t, a)) + } + new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true) + } + + /** * Zips this RDD with another one, returning key-value pairs with the first element in each RDD, * second element in each RDD, etc. Assumes that the two RDDs have the *same number of * partitions* and the *same number of elements in each partition* (e.g. one was made through @@ -383,6 +439,14 @@ abstract class RDD[T: ClassManifest]( } /** + * Applies a function f to each partition of this RDD. + */ + def foreachPartition(f: Iterator[T] => Unit) { + val cleanF = sc.clean(f) + sc.runJob(this, (iter: Iterator[T]) => f(iter)) + } + + /** * Return an array that contains all of the elements in this RDD. */ def collect(): Array[T] = { @@ -404,7 +468,7 @@ abstract class RDD[T: ClassManifest]( /** * Return an RDD with the elements from `this` that are not in `other`. - * + * * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting * RDD will be <= us. */ diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 0ff4a47998..53635b1de6 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -208,4 +208,64 @@ class RDDSuite extends FunSuite with LocalSparkContext { assert(prunedData.size === 1) assert(prunedData(0) === 10) } + + test("mapWith") { + import java.util.Random + sc = new SparkContext("local", "test") + val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2) + val randoms = ones.mapWith( + (index: Int) => new Random(index + 42)) + {(t: Int, prng: Random) => prng.nextDouble * t}.collect() + val prn42_3 = { + val prng42 = new Random(42) + prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble() + } + val prn43_3 = { + val prng43 = new Random(43) + prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble() + } + assert(randoms(2) === prn42_3) + assert(randoms(5) === prn43_3) + } + + test("flatMapWith") { + import java.util.Random + sc = new SparkContext("local", "test") + val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2) + val randoms = ones.flatMapWith( + (index: Int) => new Random(index + 42)) + {(t: Int, prng: Random) => + val random = prng.nextDouble() + Seq(random * t, random * t * 10)}. + collect() + val prn42_3 = { + val prng42 = new Random(42) + prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble() + } + val prn43_3 = { + val prng43 = new Random(43) + prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble() + } + assert(randoms(5) === prn42_3 * 10) + assert(randoms(11) === prn43_3 * 10) + } + + test("filterWith") { + import java.util.Random + sc = new SparkContext("local", "test") + val ints = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2) + val sample = ints.filterWith( + (index: Int) => new Random(index + 42)) + {(t: Int, prng: Random) => prng.nextInt(3) == 0}. + collect() + val checkSample = { + val prng42 = new Random(42) + val prng43 = new Random(43) + Array(1, 2, 3, 4, 5, 6).filter{i => + if (i < 4) 0 == prng42.nextInt(3) + else 0 == prng43.nextInt(3)} + } + assert(sample.size === checkSample.size) + for (i <- 0 until sample.size) assert(sample(i) === checkSample(i)) + } } |