Diffstat (limited to 'core/src/main/scala/spark/rdd/CoGroupedRDD.scala')
-rw-r--r--  core/src/main/scala/spark/rdd/CoGroupedRDD.scala  108
1 file changed, 108 insertions, 0 deletions
diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
new file mode 100644
index 0000000000..8fa0749184
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -0,0 +1,108 @@
+package spark.rdd
+
+import java.net.URL
+import java.io.EOFException
+import java.io.ObjectInputStream
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+
+import spark.Aggregator
+import spark.Dependency
+import spark.Logging
+import spark.OneToOneDependency
+import spark.Partitioner
+import spark.RDD
+import spark.ShuffleDependency
+import spark.SparkEnv
+import spark.Split
+
+private[spark] sealed trait CoGroupSplitDep extends Serializable
+private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], split: Split) extends CoGroupSplitDep
+private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
+
+private[spark]
+class CoGroupSplit(idx: Int, val deps: Seq[CoGroupSplitDep]) extends Split with Serializable {
+  override val index: Int = idx
+  override def hashCode(): Int = idx
+}
+
+private[spark] class CoGroupAggregator
+  extends Aggregator[Any, Any, ArrayBuffer[Any]](
+    { x => ArrayBuffer(x) },
+    { (b, x) => b += x },
+    { (b1, b2) => b1 ++ b2 })
+  with Serializable
+
+class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
+  extends RDD[(K, Seq[Seq[_]])](rdds.head.context) with Logging {
+
+  val aggr = new CoGroupAggregator
+
+  @transient
+  override val dependencies = {
+    val deps = new ArrayBuffer[Dependency[_]]
+    for ((rdd, index) <- rdds.zipWithIndex) {
+      if (rdd.partitioner == Some(part)) {
+        logInfo("Adding one-to-one dependency with " + rdd)
+        deps += new OneToOneDependency(rdd)
+      } else {
+        logInfo("Adding shuffle dependency with " + rdd)
+        deps += new ShuffleDependency[Any, Any, ArrayBuffer[Any]](
+          context.newShuffleId, rdd, aggr, part)
+      }
+    }
+    deps.toList
+  }
+
+  @transient
+  val splits_ : Array[Split] = {
+    val firstRdd = rdds.head
+    val array = new Array[Split](part.numPartitions)
+    for (i <- 0 until array.size) {
+      array(i) = new CoGroupSplit(i, rdds.zipWithIndex.map { case (r, j) =>
+        dependencies(j) match {
+          case s: ShuffleDependency[_, _, _] =>
+            new ShuffleCoGroupSplitDep(s.shuffleId): CoGroupSplitDep
+          case _ =>
+            new NarrowCoGroupSplitDep(r, r.splits(i)): CoGroupSplitDep
+        }
+      }.toList)
+    }
+    array
+  }
+
+  override def splits = splits_
+
+  override val partitioner = Some(part)
+
+  override def preferredLocations(s: Split) = Nil
+
+  override def compute(s: Split): Iterator[(K, Seq[Seq[_]])] = {
+    val split = s.asInstanceOf[CoGroupSplit]
+    val numRdds = split.deps.size
+    val map = new HashMap[K, Seq[ArrayBuffer[Any]]]
+    def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
+      map.getOrElseUpdate(k, Array.fill(numRdds)(new ArrayBuffer[Any]))
+    }
+    for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
+      case NarrowCoGroupSplitDep(rdd, itsSplit) => {
+        // Read them from the parent
+        for ((k, v) <- rdd.iterator(itsSplit)) {
+          getSeq(k.asInstanceOf[K])(depNum) += v
+        }
+      }
+      case ShuffleCoGroupSplitDep(shuffleId) => {
+        // Read map outputs of shuffle
+        def mergePair(k: K, vs: Seq[Any]) {
+          val mySeq = getSeq(k)
+          for (v <- vs)
+            mySeq(depNum) += v
+        }
+        val fetcher = SparkEnv.get.shuffleFetcher
+        fetcher.fetch[K, Seq[Any]](shuffleId, split.index, mergePair)
+      }
+    }
+    map.iterator
+  }
+}
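
For context, a minimal usage sketch, not part of this commit: in the Spark API of this era, the cogroup operation on pair RDDs builds a CoGroupedRDD over its parents. The local SparkContext setup, data, and object name below are illustrative assumptions, not code from this repository.

    import spark.SparkContext
    import spark.SparkContext._  // implicit conversion to pair-RDD operations

    object CoGroupExample {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "cogroup-example")
        val ages   = sc.parallelize(Seq(("alice", 30), ("bob", 25)))
        val cities = sc.parallelize(Seq(("alice", "Oslo"), ("bob", "Lima")))
        // cogroup groups both RDDs by key; each key maps to one Seq of
        // values per parent, in parent order.
        val grouped = ages.cogroup(cities)  // RDD[(String, (Seq[Int], Seq[String]))]
        grouped.collect().foreach(println)  // e.g. (alice,(ArrayBuffer(30),ArrayBuffer(Oslo)))
      }
    }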
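
The CoGroupAggregator is built from three functions over ArrayBuffers. A standalone sketch of their semantics in plain Scala, with made-up values; no Spark is needed to run it:

    import scala.collection.mutable.ArrayBuffer

    object AggregatorSketch {
      def main(args: Array[String]) {
        val createCombiner = (x: Any) => ArrayBuffer(x)               // first value seen for a key
        val mergeValue     = (b: ArrayBuffer[Any], x: Any) => b += x  // later values on the same side
        val mergeCombiners = (b1: ArrayBuffer[Any], b2: ArrayBuffer[Any]) => b1 ++ b2  // across partitions

        // Fold values 1 and 2 into one combiner, then merge with a combiner holding 3:
        val merged = mergeCombiners(mergeValue(createCombiner(1), 2), createCombiner(3))
        println(merged)  // ArrayBuffer(1, 2, 3)
      }
    }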
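
compute keeps one ArrayBuffer per parent RDD for every key and appends each incoming record to the buffer at its dependency's index. A plain-Scala sketch of that getSeq pattern, with hypothetical inputs standing in for the records of two parents:

    import scala.collection.mutable.{ArrayBuffer, HashMap}

    object ComputeSketch {
      def main(args: Array[String]) {
        val numRdds = 2  // two parent RDDs, as in a binary cogroup
        val map = new HashMap[String, Seq[ArrayBuffer[Any]]]
        // One buffer per parent for each key, created on first access:
        def getSeq(k: String) =
          map.getOrElseUpdate(k, Array.fill(numRdds)(new ArrayBuffer[Any]))

        // Records arriving from parent 0 and parent 1, respectively:
        for ((k, v) <- Seq(("alice", 30), ("bob", 25)))         getSeq(k)(0) += v
        for ((k, v) <- Seq(("alice", "Oslo"), ("bob", "Lima"))) getSeq(k)(1) += v

        println(map("alice"))  // one buffer per parent: [30] and [Oslo]
      }
    }

Because both the narrow and the shuffle cases funnel through getSeq, a parent that is already partitioned by part skips the shuffle entirely (via the one-to-one dependency) while its records still land in the same per-key buffers.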