Diffstat (limited to 'core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala')
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala  144
1 file changed, 144 insertions(+), 0 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
new file mode 100644
index 0000000000..0187256a8e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import java.io.{ObjectOutputStream, IOException}
+import java.util.{HashMap => JHashMap}
+
+import scala.collection.JavaConversions
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{Partition, Partitioner, SparkEnv, TaskContext}
+import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
+
+
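+// A CoGroupSplitDep records how one parent RDD's data for a given partition is obtained:
+// either directly from the parent's own partition (narrow) or from shuffle map outputs.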
+private[spark] sealed trait CoGroupSplitDep extends Serializable
+
+private[spark] case class NarrowCoGroupSplitDep(
+    rdd: RDD[_],
+    splitIndex: Int,
+    var split: Partition
+  ) extends CoGroupSplitDep {
+
+  @throws(classOf[IOException])
+  private def writeObject(oos: ObjectOutputStream) {
+    // Update the reference to parent split at the time of task serialization
+    split = rdd.partitions(splitIndex)
+    oos.defaultWriteObject()
+  }
+}
+
+private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
+
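+// A partition of the co-grouped RDD, carrying one CoGroupSplitDep per parent RDD.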
+private[spark]
+class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep])
+  extends Partition with Serializable {
+  override val index: Int = idx
+  override def hashCode(): Int = idx
+}
+
+
+/**
+ * An RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a
+ * tuple with the list of values for that key from each parent RDD.
+ *
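+ * For example (an illustrative sketch; assumes a SparkContext `sc` and
+ * org.apache.spark.HashPartitioner are in scope):
+ * {{{
+ *   val a = sc.parallelize(Seq((1, "x"), (2, "y")))
+ *   val b = sc.parallelize(Seq((1, "z"), (3, "w")))
+ *   val grouped = new CoGroupedRDD[Int](Seq(a, b), new HashPartitioner(2))
+ *   // key 1 -> Seq(Seq("x"), Seq("z")); key 2 -> Seq(Seq("y"), Seq()); key 3 -> Seq(Seq(), Seq("w"))
+ * }}}
+ *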
+ * @param rdds parent RDDs.
+ * @param part partitioner used to partition the shuffle output.
+ */
+class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: Partitioner)
+  extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) {
+
+  private var serializerClass: String = null
+
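+  // Optionally use a custom serializer class for the shuffled data; when unset, the
+  // default SparkEnv serializer is used.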
+  def setSerializer(cls: String): CoGroupedRDD[K] = {
+    serializerClass = cls
+    this
+  }
+
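+  // One dependency per parent: reuse the parent's partitions directly when it is already
+  // partitioned by `part`, otherwise introduce a shuffle.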
+  override def getDependencies: Seq[Dependency[_]] = {
+    rdds.map { rdd: RDD[_ <: Product2[K, _]] =>
+      if (rdd.partitioner == Some(part)) {
+        logDebug("Adding one-to-one dependency with " + rdd)
+        new OneToOneDependency(rdd)
+      } else {
+        logDebug("Adding shuffle dependency with " + rdd)
+        new ShuffleDependency[Any, Any](rdd, part, serializerClass)
+      }
+    }
+  }
+
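+  // One output partition per partition of `part`; each records, for every parent RDD, whether
+  // its data comes from a narrow dependency or from a shuffle.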
+  override def getPartitions: Array[Partition] = {
+    val array = new Array[Partition](part.numPartitions)
+    for (i <- 0 until array.size) {
+      // Each CoGroupPartition will have a dependency per contributing RDD
+      array(i) = new CoGroupPartition(i, rdds.zipWithIndex.map { case (rdd, j) =>
+        // Assume each RDD contributed a single dependency, and get it
+        dependencies(j) match {
+          case s: ShuffleDependency[_, _] =>
+            new ShuffleCoGroupSplitDep(s.shuffleId)
+          case _ =>
+            new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))
+        }
+      }.toArray)
+    }
+    array
+  }
+
+  override val partitioner = Some(part)
+
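+  // Build, for this partition, a map from key to one buffer of values per parent RDD, filling
+  // the buffers either by iterating a narrow parent partition or by fetching shuffle outputs.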
+  override def compute(s: Partition, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = {
+    val split = s.asInstanceOf[CoGroupPartition]
+    val numRdds = split.deps.size
+    // e.g. for `(k, a) cogroup (k, b)`, map entry is k -> Seq(ArrayBuffer of a's, ArrayBuffer of b's)
+    val map = new JHashMap[K, Seq[ArrayBuffer[Any]]]
+
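+    // Get (or lazily create) the per-parent value buffers for key k.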
+    def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
+      val seq = map.get(k)
+      if (seq != null) {
+        seq
+      } else {
+        val seq = Array.fill(numRdds)(new ArrayBuffer[Any])
+        map.put(k, seq)
+        seq
+      }
+    }
+
+    val ser = SparkEnv.get.serializerManager.get(serializerClass)
+    for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
+      case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
+        // Read them from the parent
+        rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]].foreach { kv =>
+          getSeq(kv._1)(depNum) += kv._2
+        }
+      }
+      case ShuffleCoGroupSplitDep(shuffleId) => {
+        // Read map outputs of shuffle
+        val fetcher = SparkEnv.get.shuffleFetcher
+        fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context.taskMetrics, ser).foreach {
+          kv => getSeq(kv._1)(depNum) += kv._2
+        }
+      }
+    }
+    JavaConversions.mapAsScalaMap(map).iterator
+  }
+
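+  // Drop references to the parent RDDs (e.g. after checkpointing) so they can be garbage collected.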
+  override def clearDependencies() {
+    super.clearDependencies()
+    rdds = null
+  }
+}