diff options
Diffstat (limited to 'core/src/main/scala/spark/rdd/CoGroupedRDD.scala')
-rw-r--r-- | core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 59 |
1 files changed, 26 insertions, 33 deletions
diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index 8fafd27bb6..5200fb6b65 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -5,7 +5,7 @@ import java.util.{HashMap => JHashMap} import scala.collection.JavaConversions import scala.collection.mutable.ArrayBuffer -import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Split, TaskContext} +import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Partition, TaskContext} import spark.{Dependency, OneToOneDependency, ShuffleDependency} @@ -14,13 +14,13 @@ private[spark] sealed trait CoGroupSplitDep extends Serializable private[spark] case class NarrowCoGroupSplitDep( rdd: RDD[_], splitIndex: Int, - var split: Split + var split: Partition ) extends CoGroupSplitDep { @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { // Update the reference to parent split at the time of task serialization - split = rdd.splits(splitIndex) + split = rdd.partitions(splitIndex) oos.defaultWriteObject() } } @@ -28,7 +28,7 @@ private[spark] case class NarrowCoGroupSplitDep( private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep private[spark] -class CoGroupSplit(idx: Int, val deps: Seq[CoGroupSplitDep]) extends Split with Serializable { +class CoGroupPartition(idx: Int, val deps: Seq[CoGroupSplitDep]) extends Partition with Serializable { override val index: Int = idx override def hashCode(): Int = idx } @@ -40,50 +40,47 @@ private[spark] class CoGroupAggregator { (b1, b2) => b1 ++ b2 }) with Serializable -class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) - extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) with Logging { +class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(K, _)]], part: Partitioner) + extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) { - val aggr = new CoGroupAggregator + private val aggr = new CoGroupAggregator - @transient var deps_ = { - val deps = new ArrayBuffer[Dependency[_]] - for ((rdd, index) <- rdds.zipWithIndex) { + override def getDependencies: Seq[Dependency[_]] = { + rdds.map { rdd => if (rdd.partitioner == Some(part)) { logInfo("Adding one-to-one dependency with " + rdd) - deps += new OneToOneDependency(rdd) + new OneToOneDependency(rdd) } else { logInfo("Adding shuffle dependency with " + rdd) val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true) - deps += new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part) + new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part) } } - deps.toList } - override def getDependencies = deps_ - - @transient var splits_ : Array[Split] = { - val array = new Array[Split](part.numPartitions) + override def getPartitions: Array[Partition] = { + val array = new Array[Partition](part.numPartitions) for (i <- 0 until array.size) { - array(i) = new CoGroupSplit(i, rdds.zipWithIndex.map { case (r, j) => + // Each CoGroupPartition will have a dependency per contributing RDD + array(i) = new CoGroupPartition(i, rdds.zipWithIndex.map { case (rdd, j) => + // Assume each RDD contributed a single dependency, and get it dependencies(j) match { case s: ShuffleDependency[_, _] => - new ShuffleCoGroupSplitDep(s.shuffleId): CoGroupSplitDep + new ShuffleCoGroupSplitDep(s.shuffleId) case _ => - new NarrowCoGroupSplitDep(r, i, r.splits(i)): CoGroupSplitDep + new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)) } }.toList) } array } - override def getSplits = splits_ - override val partitioner = Some(part) - override def compute(s: Split, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = { - val split = s.asInstanceOf[CoGroupSplit] + override def compute(s: Partition, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = { + val split = s.asInstanceOf[CoGroupPartition] val numRdds = split.deps.size + // e.g. for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs) val map = new JHashMap[K, Seq[ArrayBuffer[Any]]] def getSeq(k: K): Seq[ArrayBuffer[Any]] = { val seq = map.get(k) @@ -96,7 +93,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) } } for ((dep, depNum) <- split.deps.zipWithIndex) dep match { - case NarrowCoGroupSplitDep(rdd, itsSplitIndex, itsSplit) => { + case NarrowCoGroupSplitDep(rdd, _, itsSplit) => { // Read them from the parent for ((k, v) <- rdd.iterator(itsSplit, context)) { getSeq(k.asInstanceOf[K])(depNum) += v @@ -104,21 +101,17 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) } case ShuffleCoGroupSplitDep(shuffleId) => { // Read map outputs of shuffle - def mergePair(pair: (K, Seq[Any])) { - val mySeq = getSeq(pair._1) - for (v <- pair._2) - mySeq(depNum) += v - } val fetcher = SparkEnv.get.shuffleFetcher - fetcher.fetch[K, Seq[Any]](shuffleId, split.index).foreach(mergePair) + for ((k, vs) <- fetcher.fetch[K, Seq[Any]](shuffleId, split.index)) { + getSeq(k)(depNum) ++= vs + } } } JavaConversions.mapAsScalaMap(map).iterator } override def clearDependencies() { - deps_ = null - splits_ = null + super.clearDependencies() rdds = null } } |