authorStephen Haberman <stephen@exigencecorp.com>2013-03-13 17:17:34 -0500
committerStephen Haberman <stephen@exigencecorp.com>2013-03-13 17:17:34 -0500
commit63fe22558791e6a511eb1f48efb88e2afdf77659 (patch)
treefd0e482ff43d984cb280acd0647e9398aac6934b
parentcbf8f0d4dda41ffd45855eab8401fda9b64168cd (diff)
downloadspark-63fe22558791e6a511eb1f48efb88e2afdf77659.tar.gz
spark-63fe22558791e6a511eb1f48efb88e2afdf77659.tar.bz2
spark-63fe22558791e6a511eb1f48efb88e2afdf77659.zip
Simplify SubtractedRDD in preparation for subtractByKey.
-rw-r--r--  core/src/main/scala/spark/PairRDDFunctions.scala   |  2
-rw-r--r--  core/src/main/scala/spark/RDD.scala                 | 24
-rw-r--r--  core/src/main/scala/spark/rdd/SubtractedRDD.scala   | 78
3 files changed, 58 insertions(+), 46 deletions(-)
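For orientation, `subtract` keeps the elements of one RDD that do not appear in another; this commit reworks its internals so a keyed variant can follow. A minimal usage sketch (assuming an existing SparkContext named `sc`; not part of this commit):

// Sketch of subtract semantics; `sc` is an assumed SparkContext.
val a = sc.parallelize(Seq(1, 2, 3, 4))
val b = sc.parallelize(Seq(3, 4, 5))
val diff = a.subtract(b)               // keeps only 1 and 2
println(diff.collect().mkString(", "))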
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index e7408e4352..1bd1741a71 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -639,6 +639,8 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
}
}, true)
}
+
+ // def subtractByKey(other: RDD[K]): RDD[(K,V)] = subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.size)))
}
private[spark]
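The commented-out line above records the planned API. Assuming it later lands with that signature, usage would look roughly like this (hypothetical; the method is not implemented in this commit):

// Hypothetical usage once subtractByKey(other: RDD[K]) exists on pair RDDs.
val pairs   = sc.parallelize(Seq(("a", 1), ("b", 2), ("c", 3)))
val badKeys = sc.parallelize(Seq("b", "c"))
val kept    = pairs.subtractByKey(badKeys)   // expected: ("a", 1)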
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 584efa8adf..3451136fd4 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -408,8 +408,24 @@ abstract class RDD[T: ClassManifest](
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
* RDD will be <= us.
*/
- def subtract(other: RDD[T]): RDD[T] =
- subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.size)))
+ def subtract(other: RDD[T]): RDD[T] = {
+ // If we do have a partitioner, our T is really (K, V), and we'll need to
+ // unwrap the (T, null) that subtract does to get back to the K
+ val rdd = subtract(other, partitioner match {
+ case None => new HashPartitioner(partitions.size)
+ case Some(p) => new Partitioner() {
+ override def numPartitions = p.numPartitions
+ override def getPartition(k: Any) = p.getPartition(k.asInstanceOf[(Any, _)]._1)
+ }
+ })
+ // Hacky, but if we did have a partitioner, we can keep using it
+ new RDD[T](rdd) {
+ override def getPartitions = rdd.partitions
+ override def getDependencies = rdd.dependencies
+ override def compute(split: Partition, context: TaskContext) = rdd.compute(split, context)
+ override val partitioner = RDD.this.partitioner
+ }
+ }
/**
* Return an RDD with the elements from `this` that are not in `other`.
@@ -420,7 +436,9 @@ abstract class RDD[T: ClassManifest](
/**
* Return an RDD with the elements from `this` that are not in `other`.
*/
- def subtract(other: RDD[T], p: Partitioner): RDD[T] = new SubtractedRDD[T](this, other, p)
+ def subtract(other: RDD[T], p: Partitioner): RDD[T] = {
+ new SubtractedRDD[T, Any](this.map((_, null)), other.map((_, null)), p).keys
+ }
/**
* Reduces the elements of this RDD using the specified commutative and associative binary operator.
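In the hunks above, `subtract` now reuses the key/value SubtractedRDD by wrapping each element as an `(element, null)` pair and taking `.keys` of the result; the no-argument overload additionally adapts any existing partitioner to those wrapped elements. A stripped-down sketch of that adapter (the helper name `keyOfPairPartitioner` is illustrative, not part of the commit):

// Illustrative helper, not in the commit: adapt an existing partitioner.
// When `this` already has a partitioner, its elements are really (K, V)
// pairs partitioned by K, so unwrap the pair and partition by its key.
def keyOfPairPartitioner(p: Partitioner): Partitioner = new Partitioner() {
  override def numPartitions = p.numPartitions
  override def getPartition(key: Any) =
    p.getPartition(key.asInstanceOf[(Any, _)]._1)
}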
diff --git a/core/src/main/scala/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
index 43ec90cac5..1bc84f7e1e 100644
--- a/core/src/main/scala/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
@@ -1,7 +1,8 @@
package spark.rdd
-import java.util.{HashSet => JHashSet}
+import java.util.{HashMap => JHashMap}
import scala.collection.JavaConversions._
+import scala.collection.mutable.ArrayBuffer
import spark.RDD
import spark.Partitioner
import spark.Dependency
@@ -27,39 +28,20 @@ import spark.OneToOneDependency
* you can use `rdd1`'s partitioner/partition size and not worry about running
* out of memory because of the size of `rdd2`.
*/
-private[spark] class SubtractedRDD[T: ClassManifest](
- @transient var rdd1: RDD[T],
- @transient var rdd2: RDD[T],
- part: Partitioner) extends RDD[T](rdd1.context, Nil) {
+private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest](
+ @transient var rdd1: RDD[(K, V)],
+ @transient var rdd2: RDD[(K, V)],
+ part: Partitioner) extends RDD[(K, V)](rdd1.context, Nil) {
override def getDependencies: Seq[Dependency[_]] = {
Seq(rdd1, rdd2).map { rdd =>
- if (rdd.partitioner == Some(part)) {
- logInfo("Adding one-to-one dependency with " + rdd)
- new OneToOneDependency(rdd)
- } else {
- logInfo("Adding shuffle dependency with " + rdd)
- val mapSideCombinedRDD = rdd.mapPartitions(i => {
- val set = new JHashSet[T]()
- while (i.hasNext) {
- set.add(i.next)
- }
- set.iterator
- }, true)
- // ShuffleDependency requires a tuple (k, v), which it will partition by k.
- // We need this to partition to map to the same place as the k for
- // OneToOneDependency, which means:
- // - for already-tupled RDD[(A, B)], into getPartition(a)
- // - for non-tupled RDD[C], into getPartition(c)
- val part2 = new Partitioner() {
- def numPartitions = part.numPartitions
- def getPartition(key: Any) = key match {
- case (k, v) => part.getPartition(k)
- case k => part.getPartition(k)
- }
- }
- new ShuffleDependency(mapSideCombinedRDD.map((_, null)), part2)
- }
+ if (rdd.partitioner == Some(part)) {
+ logInfo("Adding one-to-one dependency with " + rdd)
+ new OneToOneDependency(rdd)
+ } else {
+ logInfo("Adding shuffle dependency with " + rdd)
+ new ShuffleDependency(rdd, part)
+ }
}
}
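The simplified getDependencies above boils down to one decision per parent: a parent already partitioned with the target partitioner can be read through a narrow one-to-one dependency, while anything else gets shuffled. Expressed as a standalone helper (illustrative only, using the same spark.* types the file already imports):

// Illustrative restatement of the dependency choice in getDependencies().
def dependencyFor[K: ClassManifest, V: ClassManifest](
    rdd: RDD[(K, V)], part: Partitioner): Dependency[_] =
  if (rdd.partitioner == Some(part)) new OneToOneDependency(rdd)  // co-partitioned: no shuffle
  else new ShuffleDependency(rdd, part)                           // repartition by key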
@@ -81,22 +63,32 @@ private[spark] class SubtractedRDD[T: ClassManifest](
override val partitioner = Some(part)
- override def compute(p: Partition, context: TaskContext): Iterator[T] = {
+ override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = {
val partition = p.asInstanceOf[CoGroupPartition]
- val set = new JHashSet[T]
- def integrate(dep: CoGroupSplitDep, op: T => Unit) = dep match {
+ val map = new JHashMap[K, ArrayBuffer[V]]
+ def getSeq(k: K): ArrayBuffer[V] = {
+ val seq = map.get(k)
+ if (seq != null) {
+ seq
+ } else {
+ val seq = new ArrayBuffer[V]()
+ map.put(k, seq)
+ seq
+ }
+ }
+ def integrate(dep: CoGroupSplitDep, op: ((K, V)) => Unit) = dep match {
case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
- for (k <- rdd.iterator(itsSplit, context))
- op(k.asInstanceOf[T])
+ for (t <- rdd.iterator(itsSplit, context))
+ op(t.asInstanceOf[(K, V)])
case ShuffleCoGroupSplitDep(shuffleId) =>
- for ((k, _) <- SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index, context.taskMetrics))
- op(k.asInstanceOf[T])
+ for (t <- SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index, context.taskMetrics))
+ op(t.asInstanceOf[(K, V)])
}
- // the first dep is rdd1; add all keys to the set
- integrate(partition.deps(0), set.add)
- // the second dep is rdd2; remove all of its keys from the set
- integrate(partition.deps(1), set.remove)
- set.iterator
+ // the first dep is rdd1; add all values to the map
+ integrate(partition.deps(0), t => getSeq(t._1) += t._2)
+ // the second dep is rdd2; remove all of its keys
+ integrate(partition.deps(1), t => map.remove(t._1) )
+ map.iterator.map { t => t._2.iterator.map { (t._1, _) } }.flatten
}
override def clearDependencies() {
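In plain terms, the reworked compute builds a key-to-values map from the first parent's partition, drops every key that also appears in the second parent's partition, and flattens what survives back into (key, value) pairs. A self-contained per-partition sketch over ordinary iterators (plain Scala, independent of the RDD machinery):

import java.util.{HashMap => JHashMap}
import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer

// Per-partition subtract-by-key over plain iterators, mirroring compute().
def subtractPartition[K, V](left: Iterator[(K, V)], right: Iterator[(K, V)]): Iterator[(K, V)] = {
  val map = new JHashMap[K, ArrayBuffer[V]]
  // Collect every value from the left side, grouped by key.
  for ((k, v) <- left) {
    val seq = if (map.containsKey(k)) map.get(k) else {
      val s = new ArrayBuffer[V](); map.put(k, s); s
    }
    seq += v
  }
  // Any key that also appears on the right side is dropped entirely.
  for ((k, _) <- right) map.remove(k)
  // Flatten the surviving groups back into (key, value) pairs.
  map.iterator.flatMap { case (k, vs) => vs.iterator.map((k, _)) }
}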