diff options
Diffstat (limited to 'core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala')
-rw-r--r-- | core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala | 88 |
1 files changed, 65 insertions, 23 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 4ba4696fef..a73714abca 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -23,8 +23,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext} import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency} -import org.apache.spark.util.AppendOnlyMap - +import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap} private[spark] sealed trait CoGroupSplitDep extends Serializable @@ -44,14 +43,12 @@ private[spark] case class NarrowCoGroupSplitDep( private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep -private[spark] -class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep]) +private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep]) extends Partition with Serializable { override val index: Int = idx override def hashCode(): Int = idx } - /** * A RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a * tuple with the list of values for that key. @@ -62,6 +59,14 @@ class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep]) class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: Partitioner) extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) { + // For example, `(k, a) cogroup (k, b)` produces k -> Seq(ArrayBuffer as, ArrayBuffer bs). + // Each ArrayBuffer is represented as a CoGroup, and the resulting Seq as a CoGroupCombiner. + // CoGroupValue is the intermediate state of each value before being merged in compute. + private type CoGroup = ArrayBuffer[Any] + private type CoGroupValue = (Any, Int) // Int is dependency number + private type CoGroupCombiner = Seq[CoGroup] + + private val sparkConf = SparkEnv.get.conf private var serializerClass: String = null def setSerializer(cls: String): CoGroupedRDD[K] = { @@ -100,37 +105,74 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: override val partitioner = Some(part) - override def compute(s: Partition, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = { + override def compute(s: Partition, context: TaskContext): Iterator[(K, CoGroupCombiner)] = { + val externalSorting = sparkConf.getBoolean("spark.shuffle.externalSorting", true) val split = s.asInstanceOf[CoGroupPartition] val numRdds = split.deps.size - // e.g. for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs) - val map = new AppendOnlyMap[K, Seq[ArrayBuffer[Any]]] - val update: (Boolean, Seq[ArrayBuffer[Any]]) => Seq[ArrayBuffer[Any]] = (hadVal, oldVal) => { - if (hadVal) oldVal else Array.fill(numRdds)(new ArrayBuffer[Any]) - } - - val getSeq = (k: K) => { - map.changeValue(k, update) - } - - val ser = SparkEnv.get.serializerManager.get(serializerClass, SparkEnv.get.conf) + // A list of (rdd iterator, dependency number) pairs + val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)] for ((dep, depNum) <- split.deps.zipWithIndex) dep match { case NarrowCoGroupSplitDep(rdd, _, itsSplit) => { // Read them from the parent - rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]].foreach { kv => - getSeq(kv._1)(depNum) += kv._2 - } + val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]] + rddIterators += ((it, depNum)) } case ShuffleCoGroupSplitDep(shuffleId) => { // Read map outputs of shuffle val fetcher = SparkEnv.get.shuffleFetcher - fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser).foreach { - kv => getSeq(kv._1)(depNum) += kv._2 + val ser = SparkEnv.get.serializerManager.get(serializerClass, sparkConf) + val it = fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser) + rddIterators += ((it, depNum)) + } + } + + if (!externalSorting) { + val map = new AppendOnlyMap[K, CoGroupCombiner] + val update: (Boolean, CoGroupCombiner) => CoGroupCombiner = (hadVal, oldVal) => { + if (hadVal) oldVal else Array.fill(numRdds)(new CoGroup) + } + val getCombiner: K => CoGroupCombiner = key => { + map.changeValue(key, update) + } + rddIterators.foreach { case (it, depNum) => + while (it.hasNext) { + val kv = it.next() + getCombiner(kv._1)(depNum) += kv._2 } } + new InterruptibleIterator(context, map.iterator) + } else { + val map = createExternalMap(numRdds) + rddIterators.foreach { case (it, depNum) => + while (it.hasNext) { + val kv = it.next() + map.insert(kv._1, new CoGroupValue(kv._2, depNum)) + } + } + new InterruptibleIterator(context, map.iterator) + } + } + + private def createExternalMap(numRdds: Int) + : ExternalAppendOnlyMap[K, CoGroupValue, CoGroupCombiner] = { + + val createCombiner: (CoGroupValue => CoGroupCombiner) = value => { + val newCombiner = Array.fill(numRdds)(new CoGroup) + value match { case (v, depNum) => newCombiner(depNum) += v } + newCombiner } - new InterruptibleIterator(context, map.iterator) + val mergeValue: (CoGroupCombiner, CoGroupValue) => CoGroupCombiner = + (combiner, value) => { + value match { case (v, depNum) => combiner(depNum) += v } + combiner + } + val mergeCombiners: (CoGroupCombiner, CoGroupCombiner) => CoGroupCombiner = + (combiner1, combiner2) => { + combiner1.zip(combiner2).map { case (v1, v2) => v1 ++ v2 } + } + new ExternalAppendOnlyMap[K, CoGroupValue, CoGroupCombiner]( + createCombiner, mergeValue, mergeCombiners) } override def clearDependencies() { |