Diffstat (limited to 'core/src/main/scala/spark/rdd/CoGroupedRDD.scala')
-rw-r--r--  core/src/main/scala/spark/rdd/CoGroupedRDD.scala  | 108
1 file changed, 108 insertions(+), 0 deletions(-)
diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
new file mode 100644
index 0000000000..8fa0749184
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -0,0 +1,108 @@
+package spark.rdd
+
+import java.net.URL
+import java.io.EOFException
+import java.io.ObjectInputStream
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+
+import spark.Aggregator
+import spark.Dependency
+import spark.Logging
+import spark.OneToOneDependency
+import spark.Partitioner
+import spark.RDD
+import spark.ShuffleDependency
+import spark.SparkEnv
+import spark.Split
+
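+// Describes how a CoGroupSplit depends on one parent RDD: either narrowly on a single
+// parent split, or on the map outputs of the shuffle identified by shuffleId.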
+private[spark] sealed trait CoGroupSplitDep extends Serializable
+private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], split: Split) extends CoGroupSplitDep
+private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
+
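+// A split (partition) of a CoGroupedRDD, carrying one CoGroupSplitDep per parent RDD.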
+private[spark]
+class CoGroupSplit(idx: Int, val deps: Seq[CoGroupSplitDep]) extends Split with Serializable {
+  override val index: Int = idx
+  override def hashCode(): Int = idx
+}
+
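+// Aggregator used on the shuffle path: wraps the first value for a key in an ArrayBuffer,
+// appends later values to it, and concatenates buffers when merging partial results.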
+private[spark] class CoGroupAggregator
+  extends Aggregator[Any, Any, ArrayBuffer[Any]](
+    { x => ArrayBuffer(x) },
+    { (b, x) => b += x },
+    { (b1, b2) => b1 ++ b2 })
+  with Serializable
+
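+/**
+ * An RDD that cogroups its parent RDDs. For every key k present in any parent, the result
+ * contains one pair (k, seqs) where seqs(i) holds the values for k from the i-th parent.
+ * Parents that are already partitioned by `part` are read through a one-to-one dependency;
+ * all other parents are repartitioned by key through a shuffle dependency.
+ */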
+class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
+  extends RDD[(K, Seq[Seq[_]])](rdds.head.context) with Logging {
+
+  val aggr = new CoGroupAggregator
+
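+  // One dependency per parent RDD: one-to-one if the parent is already partitioned by
+  // `part`, otherwise a shuffle dependency that repartitions it by key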
+  @transient
+  override val dependencies = {
+    val deps = new ArrayBuffer[Dependency[_]]
+    for ((rdd, index) <- rdds.zipWithIndex) {
+      if (rdd.partitioner == Some(part)) {
+        logInfo("Adding one-to-one dependency with " + rdd)
+        deps += new OneToOneDependency(rdd)
+      } else {
+        logInfo("Adding shuffle dependency with " + rdd)
+        deps += new ShuffleDependency[Any, Any, ArrayBuffer[Any]](
+          context.newShuffleId, rdd, aggr, part)
+      }
+    }
+    deps.toList
+  }
+
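+  // One CoGroupSplit per output partition; for each parent it records either the parent's
+  // corresponding split (narrow case) or the shuffle ID to fetch map outputs from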
+  @transient
+  val splits_ : Array[Split] = {
+    val firstRdd = rdds.head
+    val array = new Array[Split](part.numPartitions)
+    for (i <- 0 until array.size) {
+      array(i) = new CoGroupSplit(i, rdds.zipWithIndex.map { case (r, j) =>
+        dependencies(j) match {
+          case s: ShuffleDependency[_, _, _] =>
+            new ShuffleCoGroupSplitDep(s.shuffleId): CoGroupSplitDep
+          case _ =>
+            new NarrowCoGroupSplitDep(r, r.splits(i)): CoGroupSplitDep
+        }
+      }.toList)
+    }
+    array
+  }
+
+  override def splits = splits_
+
+  override val partitioner = Some(part)
+
+  override def preferredLocations(s: Split) = Nil
+
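+  // Computes one partition: for every key, builds a Seq of numRdds buffers, filling the
+  // i-th buffer with that key's values from the i-th parent (read directly from narrow
+  // parents or fetched from the shuffle)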
+  override def compute(s: Split): Iterator[(K, Seq[Seq[_]])] = {
+    val split = s.asInstanceOf[CoGroupSplit]
+    val numRdds = split.deps.size
+    val map = new HashMap[K, Seq[ArrayBuffer[Any]]]
+    def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
+      map.getOrElseUpdate(k, Array.fill(numRdds)(new ArrayBuffer[Any]))
+    }
+    for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
+      case NarrowCoGroupSplitDep(rdd, itsSplit) => {
+        // Read them from the parent
+        for ((k, v) <- rdd.iterator(itsSplit)) {
+          getSeq(k.asInstanceOf[K])(depNum) += v
+        }
+      }
+      case ShuffleCoGroupSplitDep(shuffleId) => {
+        // Read map outputs of shuffle
+        def mergePair(k: K, vs: Seq[Any]) {
+          val mySeq = getSeq(k)
+          for (v <- vs)
+            mySeq(depNum) += v
+        }
+        val fetcher = SparkEnv.get.shuffleFetcher
+        fetcher.fetch[K, Seq[Any]](shuffleId, split.index, mergePair)
+      }
+    }
+    map.iterator
+  }
+}