diff options
Diffstat (limited to 'core/src/main/scala/spark/rdd/CoGroupedRDD.scala')
-rw-r--r-- | core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 30 |
1 files changed, 12 insertions, 18 deletions
diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index 50bec9e63b..de0d9fad88 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -3,21 +3,15 @@ package spark.rdd import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap -import spark.Aggregator -import spark.Dependency -import spark.Logging -import spark.OneToOneDependency -import spark.Partitioner -import spark.RDD -import spark.ShuffleDependency -import spark.SparkEnv -import spark.Split +import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Split, TaskContext} +import spark.{Dependency, OneToOneDependency, ShuffleDependency} + private[spark] sealed trait CoGroupSplitDep extends Serializable private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], split: Split) extends CoGroupSplitDep private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep -private[spark] +private[spark] class CoGroupSplit(idx: Int, val deps: Seq[CoGroupSplitDep]) extends Split with Serializable { override val index: Int = idx override def hashCode(): Int = idx @@ -32,9 +26,9 @@ private[spark] class CoGroupAggregator class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) extends RDD[(K, Seq[Seq[_]])](rdds.head.context) with Logging { - + val aggr = new CoGroupAggregator - + @transient override val dependencies = { val deps = new ArrayBuffer[Dependency[_]] @@ -50,7 +44,7 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) } deps.toList } - + @transient val splits_ : Array[Split] = { val firstRdd = rdds.head @@ -69,12 +63,12 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) } override def splits = splits_ - + override val partitioner = Some(part) - + override def preferredLocations(s: Split) = Nil - - override def compute(s: Split): Iterator[(K, Seq[Seq[_]])] = { + + override def compute(s: Split, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = { val split = s.asInstanceOf[CoGroupSplit] val numRdds = split.deps.size val map = new HashMap[K, Seq[ArrayBuffer[Any]]] @@ -84,7 +78,7 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) for ((dep, depNum) <- split.deps.zipWithIndex) dep match { case NarrowCoGroupSplitDep(rdd, itsSplit) => { // Read them from the parent - for ((k, v) <- rdd.iterator(itsSplit)) { + for ((k, v) <- rdd.iterator(itsSplit, context)) { getSeq(k.asInstanceOf[K])(depNum) += v } } |