author    Mark Hamstra <markhamstra@gmail.com>  2013-03-10 22:27:13 -0700
committer Mark Hamstra <markhamstra@gmail.com>  2013-03-10 22:27:13 -0700
commit    1289e7176bc1ad4eb3a7089acb59bcb8220eddab (patch)
tree      d42b226e16a85f9403c4a8edcf5325612753394e
parent    b57df1f5e399c8071871d68d8a0ae0793fc8f731 (diff)
refactored _With API and added foreachPartition
core/src/main/scala/spark/RDD.scala      | 79
core/src/test/scala/spark/RDDSuite.scala | 34
2 files changed, 57 insertions(+), 56 deletions(-)
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 0a901a251d..2ad11bc604 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -365,60 +365,59 @@ abstract class RDD[T: ClassManifest](
new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
/**
- * Maps f over this RDD where f takes an additional parameter of type A. This
- * additional parameter is produced by a factory method T => A which is called
- * on each invocation of f. This factory method is produced by the factoryBuilder,
- * an instance of which is constructed in each partition from the partition index
- * and a seed value of type B.
- */
- def mapWith[A: ClassManifest, B: ClassManifest, U: ClassManifest](
- factoryBuilder: (Int, B) => (T => A),
- factorySeed: B,
- preservesPartitioning: Boolean = false)
+ * Maps f over this RDD, where f takes an additional parameter of type A. This
+ * additional parameter is produced by constructorOfA, which is called in each
+ * partition with the index of that partition.
+ */
+ def mapWith[A: ClassManifest, U: ClassManifest](constructorOfA: Int => A, preservesPartitioning: Boolean = false)
(f:(A, T) => U): RDD[U] = {
def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
- val factory = factoryBuilder(index, factorySeed)
- iter.map(t => f(factory(t), t))
+ val a = constructorOfA(index)
+ iter.map(t => f(a, t))
}
new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
}
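
For context, a minimal sketch of a call site under the refactored API (the RDD and seed below are illustrative, not part of this patch): the old factoryBuilder/factorySeed pair is replaced by a single Int => A constructor, invoked once per partition, so per-partition state such as a PRNG is built directly from the partition index:

    val rdd = sc.parallelize(1 to 100, 4)
    // One Random per partition, seeded from the partition index.
    val scaled = rdd.mapWith((index: Int) => new java.util.Random(index + 42)) {
      (prng, t) => prng.nextDouble * t
    }
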
- /**
- * FlatMaps f over this RDD where f takes an additional parameter of type A. This
- * additional parameter is produced by a factory method T => A which is called
- * on each invocation of f. This factory method is produced by the factoryBuilder,
- * an instance of which is constructed in each partition from the partition index
- * and a seed value of type B.
+ /**
+ * FlatMaps f over this RDD, where f takes an additional parameter of type A. This
+ * additional parameter is produced by constructorOfA, which is called in each
+ * partition with the index of that partition.
*/
- def flatMapWith[A: ClassManifest, B: ClassManifest, U: ClassManifest](
- factoryBuilder: (Int, B) => (T => A),
- factorySeed: B,
- preservesPartitioning: Boolean = false)
+ def flatMapWith[A: ClassManifest, U: ClassManifest](constructorOfA: Int => A, preservesPartitioning: Boolean = false)
(f:(A, T) => Seq[U]): RDD[U] = {
def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
- val factory = factoryBuilder(index, factorySeed)
- iter.flatMap(t => f(factory(t), t))
+ val a = constructorOfA(index)
+ iter.flatMap(t => f(a, t))
}
new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
}
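
Similarly, a sketch of flatMapWith producing more than one output per input element, reusing the illustrative rdd from the sketch above:

    // Two jittered copies of each element, sharing one per-partition Random.
    val jittered = rdd.flatMapWith((index: Int) => new java.util.Random(index)) {
      (prng, t) => Seq(t + prng.nextDouble, t - prng.nextDouble)
    }
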
/**
+ * Applies f to each element of this RDD, where f takes an additional parameter of type A.
+ * This additional parameter is produced by constructorOfA, which is called in each
+ * partition with the index of that partition.
+ */
+ def foreachWith[A: ClassManifest](constructorOfA: Int => A)
+ (f:(A, T) => Unit) {
+ def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
+ val a = constructorOfA(index)
+ iter.map(t => {f(a, t); t})
+ }
+ (new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)).foreach(_ => {})
+ }
+
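
A small self-contained sketch of foreachWith: here the per-partition value is just the partition index itself, but any resource built by constructorOfA (a connection, a buffer, a PRNG) would be shared by every element of that partition. Note how the implementation above maps the side effect over a MapPartitionsWithIndexRDD and then forces evaluation with foreach(_ => {}), since RDD transformations are lazy:

    // Tag each element with the partition it came from.
    rdd.foreachWith((index: Int) => index) {
      (partIndex, t) => println("partition " + partIndex + ": " + t)
    }
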
+ /**
* Filters this RDD with p, where p takes an additional parameter of type A. This
- * additional parameter is produced by a factory method T => A which is called
- * on each invocation of p. This factory method is produced by the factoryBuilder,
- * an instance of which is constructed in each partition from the partition index
- * and a seed value of type B.
- */
- def filterWith[A: ClassManifest, B: ClassManifest](
- factoryBuilder: (Int, B) => (T => A),
- factorySeed: B,
- preservesPartitioning: Boolean = false)
+ * additional parameter is produced by constructorOfA, which is called in each
+ * partition with the index of that partition.
+ */
+ def filterWith[A: ClassManifest](constructorOfA: Int => A)
(p:(A, T) => Boolean): RDD[T] = {
def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
- val factory = factoryBuilder(index, factorySeed)
- iter.filter(t => p(factory(t), t))
+ val a = constructorOfA(index)
+ iter.filter(t => p(a, t))
}
- new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
+ new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)
}
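
A sketch of the new filterWith call shape, mirroring the test further below: deterministic per-partition sampling with a Random seeded from the partition index. The refactor also drops filterWith's preservesPartitioning parameter and hardcodes true, which is safe because filtering removes elements without changing them:

    // Keep roughly one element in three from each partition.
    val sampled = rdd.filterWith((index: Int) => new java.util.Random(index + 42)) {
      (prng, t) => prng.nextInt(3) == 0
    }
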
/**
@@ -440,6 +439,14 @@ abstract class RDD[T: ClassManifest](
}
/**
+ * Applies a function f to each partition of this RDD.
+ */
+ def foreachPartition(f: Iterator[T] => Unit) {
+ val cleanF = sc.clean(f)
+ sc.runJob(this, (iter: Iterator[T]) => cleanF(iter))
+ }
+
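
And a minimal sketch of the new foreachPartition, which hands the whole partition iterator to f so batched per-partition work can be done without the _With machinery (assumes the illustrative RDD[Int] from the earlier sketches):

    rdd.foreachPartition { iter =>
      // Consume the entire partition in one pass.
      println("partition sum: " + iter.sum)
    }
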
+ /**
* Return an array that contains all of the elements in this RDD.
*/
def collect(): Array[T] = {
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index 2a182e0d6c..d260191dd7 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -180,21 +180,18 @@ class RDDSuite extends FunSuite with LocalSparkContext {
}
test("mapWith") {
+ import java.util.Random
sc = new SparkContext("local", "test")
val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
val randoms = ones.mapWith(
- (index: Int, seed: Int) => {
- val prng = new java.util.Random(index + seed)
- (_ => prng.nextDouble)},
- 42)
- {(random: Double, t: Int) => random * t}.
- collect()
+ (index: Int) => new Random(index + 42))
+ {(prng: Random, t: Int) => prng.nextDouble * t}.collect()
val prn42_3 = {
- val prng42 = new java.util.Random(42)
+ val prng42 = new Random(42)
prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble()
}
val prn43_3 = {
- val prng43 = new java.util.Random(43)
+ val prng43 = new Random(43)
prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble()
}
assert(randoms(2) === prn42_3)
@@ -202,21 +199,21 @@ class RDDSuite extends FunSuite with LocalSparkContext {
}
test("flatMapWith") {
+ import java.util.Random
sc = new SparkContext("local", "test")
val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
val randoms = ones.flatMapWith(
- (index: Int, seed: Int) => {
- val prng = new java.util.Random(index + seed)
- (_ => prng.nextDouble)},
- 42)
- {(random: Double, t: Int) => Seq(random * t, random * t * 10)}.
+ (index: Int) => new Random(index + 42))
+ {(prng: Random, t: Int) => {
+ val random = prng.nextDouble()
+ Seq(random * t, random * t * 10)}}.
collect()
val prn42_3 = {
- val prng42 = new java.util.Random(42)
+ val prng42 = new Random(42)
prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble()
}
val prn43_3 = {
- val prng43 = new java.util.Random(43)
+ val prng43 = new Random(43)
prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble()
}
assert(randoms(5) === prn42_3 * 10)
@@ -228,11 +225,8 @@ class RDDSuite extends FunSuite with LocalSparkContext {
sc = new SparkContext("local", "test")
val ints = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2)
val sample = ints.filterWith(
- (index: Int, seed: Int) => {
- val prng = new Random(index + seed)
- (_ => prng.nextInt(3))},
- 42)
- {(random: Int, t: Int) => random == 0}.
+ (index: Int) => new Random(index + 42))
+ {(prng: Random, t: Int) => prng.nextInt(3) == 0}.
collect()
val checkSample = {
val prng42 = new Random(42)