Diffstat (limited to 'core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala')
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala  |  88
1 file changed, 65 insertions, 23 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 4ba4696fef..a73714abca 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -23,8 +23,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext}
import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
-import org.apache.spark.util.AppendOnlyMap
-
+import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap}
private[spark] sealed trait CoGroupSplitDep extends Serializable
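The key new import is ExternalAppendOnlyMap, an append-only map that can spill its sorted contents to disk when memory fills up. As a standalone sketch of its three-function contract (hypothetical word-count types; the shape mirrors the createExternalMap hunk further down, and since the class is private[spark] this only compiles inside Spark's own packages):

    import org.apache.spark.util.collection.ExternalAppendOnlyMap

    // Hypothetical word count with K = String, V = Int, C = Int; assumes a live
    // SparkEnv (e.g. inside a task) for the map's defaulted serializer/disk manager.
    val counts = new ExternalAppendOnlyMap[String, Int, Int](
      (v: Int) => v,                   // createCombiner: first value seen for a key
      (c: Int, v: Int) => c + v,       // mergeValue: fold a value into the combiner
      (c1: Int, c2: Int) => c1 + c2)   // mergeCombiners: merge across spilled maps
    counts.insert("spark", 1)
    counts.insert("spark", 1)
    counts.iterator.foreach(println)   // ("spark",2), possibly read back from disk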
@@ -44,14 +43,12 @@ private[spark] case class NarrowCoGroupSplitDep(
private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
-private[spark]
-class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep])
+private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep])
extends Partition with Serializable {
override val index: Int = idx
override def hashCode(): Int = idx
}
-
/**
* An RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a
* tuple with the list of values for that key.
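That doc comment is the contract compute must preserve under both code paths below. A hedged illustration with hypothetical data (`sc` is an existing SparkContext; the concrete runtime collection type is an implementation detail):

    import org.apache.spark.SparkContext._  // pair-RDD operations such as cogroup

    val a = sc.parallelize(Seq((1, "x"), (1, "y"), (2, "z")))
    val b = sc.parallelize(Seq((1, 10), (3, 30)))
    a.cogroup(b).collect().foreach(println)
    // (1,(ArrayBuffer(x, y),ArrayBuffer(10)))   <- one Seq of values per parent RDD
    // (2,(ArrayBuffer(z),ArrayBuffer()))
    // (3,(ArrayBuffer(),ArrayBuffer(30)))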
@@ -62,6 +59,14 @@ class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep])
class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: Partitioner)
extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) {
+ // For example, `(k, a) cogroup (k, b)` produces k -> Seq(ArrayBuffer of as, ArrayBuffer of bs).
+ // Each ArrayBuffer is represented as a CoGroup, and the resulting Seq as a CoGroupCombiner.
+ // CoGroupValue is the intermediate state of each value before being merged in compute.
+ private type CoGroup = ArrayBuffer[Any]
+ private type CoGroupValue = (Any, Int) // Int is dependency number
+ private type CoGroupCombiner = Seq[CoGroup]
+
+ private val sparkConf = SparkEnv.get.conf
private var serializerClass: String = null
def setSerializer(cls: String): CoGroupedRDD[K] = {
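Read together, the three aliases say: each dependency gets one CoGroup buffer, a combiner is the per-key row of those buffers, and a CoGroupValue tags a raw value with the dependency it came from. A REPL-style sketch of that bookkeeping (a standalone rewrite, not patch code):

    import scala.collection.mutable.ArrayBuffer

    type CoGroup = ArrayBuffer[Any]
    type CoGroupValue = (Any, Int)      // (value, dependency number)
    type CoGroupCombiner = Seq[CoGroup]

    val numRdds = 2
    val combiner: CoGroupCombiner = Array.fill(numRdds)(new CoGroup)
    val value: CoGroupValue = ("a", 0)  // "a" arrived from dependency 0
    combiner(value._2) += value._1
    // combiner is now Seq(ArrayBuffer(a), ArrayBuffer())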
@@ -100,37 +105,74 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
override val partitioner = Some(part)
- override def compute(s: Partition, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = {
+ override def compute(s: Partition, context: TaskContext): Iterator[(K, CoGroupCombiner)] = {
+ val externalSorting = sparkConf.getBoolean("spark.shuffle.externalSorting", true)
val split = s.asInstanceOf[CoGroupPartition]
val numRdds = split.deps.size
- // e.g. for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs)
- val map = new AppendOnlyMap[K, Seq[ArrayBuffer[Any]]]
- val update: (Boolean, Seq[ArrayBuffer[Any]]) => Seq[ArrayBuffer[Any]] = (hadVal, oldVal) => {
- if (hadVal) oldVal else Array.fill(numRdds)(new ArrayBuffer[Any])
- }
-
- val getSeq = (k: K) => {
- map.changeValue(k, update)
- }
-
- val ser = SparkEnv.get.serializerManager.get(serializerClass, SparkEnv.get.conf)
+ // A list of (rdd iterator, dependency number) pairs
+ val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)]
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
// Read them from the parent
- rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]].foreach { kv =>
- getSeq(kv._1)(depNum) += kv._2
- }
+ val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]]
+ rddIterators += ((it, depNum))
}
case ShuffleCoGroupSplitDep(shuffleId) => {
// Read map outputs of shuffle
val fetcher = SparkEnv.get.shuffleFetcher
- fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser).foreach {
- kv => getSeq(kv._1)(depNum) += kv._2
+ val ser = SparkEnv.get.serializerManager.get(serializerClass, sparkConf)
+ val it = fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser)
+ rddIterators += ((it, depNum))
+ }
+ }
+
+ if (!externalSorting) {
+ val map = new AppendOnlyMap[K, CoGroupCombiner]
+ val update: (Boolean, CoGroupCombiner) => CoGroupCombiner = (hadVal, oldVal) => {
+ if (hadVal) oldVal else Array.fill(numRdds)(new CoGroup)
+ }
+ val getCombiner: K => CoGroupCombiner = key => {
+ map.changeValue(key, update)
+ }
+ rddIterators.foreach { case (it, depNum) =>
+ while (it.hasNext) {
+ val kv = it.next()
+ getCombiner(kv._1)(depNum) += kv._2
}
}
+ new InterruptibleIterator(context, map.iterator)
+ } else {
+ val map = createExternalMap(numRdds)
+ rddIterators.foreach { case (it, depNum) =>
+ while (it.hasNext) {
+ val kv = it.next()
+ map.insert(kv._1, new CoGroupValue(kv._2, depNum))
+ }
+ }
+ new InterruptibleIterator(context, map.iterator)
+ }
+ }
+
+ private def createExternalMap(numRdds: Int)
+ : ExternalAppendOnlyMap[K, CoGroupValue, CoGroupCombiner] = {
+
+ val createCombiner: (CoGroupValue => CoGroupCombiner) = value => {
+ val newCombiner = Array.fill(numRdds)(new CoGroup)
+ value match { case (v, depNum) => newCombiner(depNum) += v }
+ newCombiner
}
- new InterruptibleIterator(context, map.iterator)
+ val mergeValue: (CoGroupCombiner, CoGroupValue) => CoGroupCombiner =
+ (combiner, value) => {
+ value match { case (v, depNum) => combiner(depNum) += v }
+ combiner
+ }
+ val mergeCombiners: (CoGroupCombiner, CoGroupCombiner) => CoGroupCombiner =
+ (combiner1, combiner2) => {
+ combiner1.zip(combiner2).map { case (v1, v2) => v1 ++ v2 }
+ }
+ new ExternalAppendOnlyMap[K, CoGroupValue, CoGroupCombiner](
+ createCombiner, mergeValue, mergeCombiners)
}
override def clearDependencies() {
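The external path is on by default; spark.shuffle.externalSorting, read at the top of compute, switches back to the purely in-memory AppendOnlyMap. A hypothetical driver-side toggle, assuming a standard SparkConf setup (the config key and its default of true come from this patch):

    import org.apache.spark.{SparkConf, SparkContext}

    val conf = new SparkConf()
      .setAppName("CoGroupExample")  // hypothetical app name
      .set("spark.shuffle.externalSorting", "false")  // force the in-memory path
    val sc = new SparkContext(conf)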