aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
diff options
context:
space:
mode:
Diffstat (limited to 'core/src/main/scala/spark/rdd/CoGroupedRDD.scala')
-rw-r--r-- core/src/main/scala/spark/rdd/CoGroupedRDD.scala 30
1 files changed, 12 insertions, 18 deletions
diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
index 50bec9e63b..de0d9fad88 100644
--- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -3,21 +3,15 @@ package spark.rdd
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
-import spark.Aggregator
-import spark.Dependency
-import spark.Logging
-import spark.OneToOneDependency
-import spark.Partitioner
-import spark.RDD
-import spark.ShuffleDependency
-import spark.SparkEnv
-import spark.Split
+import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Split, TaskContext}
+import spark.{Dependency, OneToOneDependency, ShuffleDependency}
+
private[spark] sealed trait CoGroupSplitDep extends Serializable
private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], split: Split) extends CoGroupSplitDep
private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
-private[spark]
+private[spark]
class CoGroupSplit(idx: Int, val deps: Seq[CoGroupSplitDep]) extends Split with Serializable {
override val index: Int = idx
override def hashCode(): Int = idx
@@ -32,9 +26,9 @@ private[spark] class CoGroupAggregator
class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
extends RDD[(K, Seq[Seq[_]])](rdds.head.context) with Logging {
-
+
val aggr = new CoGroupAggregator
-
+
@transient
override val dependencies = {
val deps = new ArrayBuffer[Dependency[_]]
@@ -50,7 +44,7 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
}
deps.toList
}
-
+
@transient
val splits_ : Array[Split] = {
val firstRdd = rdds.head
@@ -69,12 +63,12 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
}
override def splits = splits_
-
+
override val partitioner = Some(part)
-
+
override def preferredLocations(s: Split) = Nil
-
- override def compute(s: Split): Iterator[(K, Seq[Seq[_]])] = {
+
+ override def compute(s: Split, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = {
val split = s.asInstanceOf[CoGroupSplit]
val numRdds = split.deps.size
val map = new HashMap[K, Seq[ArrayBuffer[Any]]]
@@ -84,7 +78,7 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
case NarrowCoGroupSplitDep(rdd, itsSplit) => {
// Read them from the parent
- for ((k, v) <- rdd.iterator(itsSplit)) {
+ for ((k, v) <- rdd.iterator(itsSplit, context)) {
getSeq(k.asInstanceOf[K])(depNum) += v
}
}