Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala   | 43
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala  | 30
2 files changed, 44 insertions(+), 29 deletions(-)
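
In the CoGroupedRDD.scala hunks below, the rdd and splitIndex fields of NarrowCoGroupSplitDep are marked @transient so that serializing a CoGroupPartition does not pull the parent RDD into the task closure a second time. A minimal, standalone sketch of that effect, using plain Java serialization and hypothetical names (not Spark code):

import java.io.{ByteArrayOutputStream, ObjectOutputStream}

// Hypothetical stand-in for NarrowCoGroupSplitDep: the @transient field is
// skipped by Java serialization, so the serialized form stays small even
// when `parent` points at something large (e.g. an RDD lineage).
case class SplitDepSketch(@transient parent: AnyRef, splitIndex: Int) extends Serializable

object TransientSketch {
  private def serializedSize(obj: AnyRef): Int = {
    val buf = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(buf)
    out.writeObject(obj)
    out.close()
    buf.size()
  }

  def main(args: Array[String]): Unit = {
    val heavyParent = new Array[Byte](1 << 20) // stands in for a large parent object
    val dep = SplitDepSketch(heavyParent, splitIndex = 3)
    // The 1 MB payload is transient, so only splitIndex is written out.
    println(s"serialized size: ${serializedSize(dep)} bytes")
  }
}
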
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 7021a339e8..658e8c8b89 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -29,15 +29,16 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap, CompactBuffer}
import org.apache.spark.util.Utils
import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle.ShuffleHandle
-
-private[spark] sealed trait CoGroupSplitDep extends Serializable
+/** The references to rdd and splitIndex are transient because redundant information is stored
+ * in the CoGroupedRDD object. Because CoGroupedRDD is serialized separately from
+ * CoGroupPartition, if rdd and splitIndex aren't transient, they'll be included twice in the
+ * task closure. */
private[spark] case class NarrowCoGroupSplitDep(
- rdd: RDD[_],
- splitIndex: Int,
+ @transient rdd: RDD[_],
+ @transient splitIndex: Int,
var split: Partition
- ) extends CoGroupSplitDep {
+ ) extends Serializable {
@throws(classOf[IOException])
private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
@@ -47,9 +48,16 @@ private[spark] case class NarrowCoGroupSplitDep(
}
}
-private[spark] case class ShuffleCoGroupSplitDep(handle: ShuffleHandle) extends CoGroupSplitDep
-
-private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep])
+/**
+ * Stores information about the narrow dependencies used by a CoGroupedRDD.
+ *
+ * @param narrowDeps maps to the dependencies variable in the parent RDD: for each one to one
+ * dependency in dependencies, narrowDeps has a NarrowCoGroupSplitDep (describing
+ * the partition for that dependency) at the corresponding index. The size of
+ * narrowDeps should always be equal to the number of parents.
+ */
+private[spark] class CoGroupPartition(
+ idx: Int, val narrowDeps: Array[Option[NarrowCoGroupSplitDep]])
extends Partition with Serializable {
override val index: Int = idx
override def hashCode(): Int = idx
@@ -105,9 +113,9 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
// Assume each RDD contributed a single dependency, and get it
dependencies(j) match {
case s: ShuffleDependency[_, _, _] =>
- new ShuffleCoGroupSplitDep(s.shuffleHandle)
+ None
case _ =>
- new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))
+ Some(new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)))
}
}.toArray)
}
@@ -120,20 +128,21 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
val sparkConf = SparkEnv.get.conf
val externalSorting = sparkConf.getBoolean("spark.shuffle.spill", true)
val split = s.asInstanceOf[CoGroupPartition]
- val numRdds = split.deps.length
+ val numRdds = dependencies.length
// A list of (rdd iterator, dependency number) pairs
val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)]
- for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
- case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
+ for ((dep, depNum) <- dependencies.zipWithIndex) dep match {
+ case oneToOneDependency: OneToOneDependency[Product2[K, Any]] =>
+ val dependencyPartition = split.narrowDeps(depNum).get.split
// Read them from the parent
- val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]]
+ val it = oneToOneDependency.rdd.iterator(dependencyPartition, context)
rddIterators += ((it, depNum))
- case ShuffleCoGroupSplitDep(handle) =>
+ case shuffleDependency: ShuffleDependency[_, _, _] =>
// Read map outputs of shuffle
val it = SparkEnv.get.shuffleManager
- .getReader(handle, split.index, split.index + 1, context)
+ .getReader(shuffleDependency.shuffleHandle, split.index, split.index + 1, context)
.read()
rddIterators += ((it, depNum))
}
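
For context, CoGroupedRDD is the RDD that backs cogroup on pair RDDs, so the CoGroupPartition built above is what each cogroup task deserializes and iterates over in compute. A minimal usage sketch, assuming a local master (not part of the patch):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._ // pair-RDD implicits on older Spark versions

object CoGroupSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("cogroup-sketch"))
    val left  = sc.parallelize(Seq(("a", 1), ("b", 2)))
    val right = sc.parallelize(Seq(("a", "x"), ("c", "y")))

    // cogroup constructs a CoGroupedRDD; narrow parents are read via the
    // one-to-one dependency's iterator, shuffled parents via the shuffle handle.
    left.cogroup(right).collect().foreach { case (k, (ints, strs)) =>
      println(s"$k -> ${ints.toList} / ${strs.toList}")
    }
    sc.stop()
  }
}
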
diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
index e9d745588e..633aeba3bb 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
@@ -81,9 +81,9 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
array(i) = new CoGroupPartition(i, Seq(rdd1, rdd2).zipWithIndex.map { case (rdd, j) =>
dependencies(j) match {
case s: ShuffleDependency[_, _, _] =>
- new ShuffleCoGroupSplitDep(s.shuffleHandle)
+ None
case _ =>
- new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))
+ Some(new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)))
}
}.toArray)
}
@@ -105,20 +105,26 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
seq
}
}
- def integrate(dep: CoGroupSplitDep, op: Product2[K, V] => Unit): Unit = dep match {
- case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
- rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op)
+ def integrate(depNum: Int, op: Product2[K, V] => Unit) = {
+ dependencies(depNum) match {
+ case oneToOneDependency: OneToOneDependency[_] =>
+ val dependencyPartition = partition.narrowDeps(depNum).get.split
+ oneToOneDependency.rdd.iterator(dependencyPartition, context)
+ .asInstanceOf[Iterator[Product2[K, V]]].foreach(op)
- case ShuffleCoGroupSplitDep(handle) =>
- val iter = SparkEnv.get.shuffleManager
- .getReader(handle, partition.index, partition.index + 1, context)
- .read()
- iter.foreach(op)
+ case shuffleDependency: ShuffleDependency[_, _, _] =>
+ val iter = SparkEnv.get.shuffleManager
+ .getReader(
+ shuffleDependency.shuffleHandle, partition.index, partition.index + 1, context)
+ .read()
+ iter.foreach(op)
+ }
}
+
// the first dep is rdd1; add all values to the map
- integrate(partition.deps(0), t => getSeq(t._1) += t._2)
+ integrate(0, t => getSeq(t._1) += t._2)
// the second dep is rdd2; remove all of its keys
- integrate(partition.deps(1), t => map.remove(t._1))
+ integrate(1, t => map.remove(t._1))
map.iterator.map { t => t._2.iterator.map { (t._1, _) } }.flatten
}
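
SubtractedRDD is the implementation behind subtractByKey, and the integrate(0, ...) / integrate(1, ...) calls above correspond to its two parents: the first fills the map, the second removes matching keys. A minimal usage sketch, assuming a local master (not part of the patch):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._ // pair-RDD implicits on older Spark versions

object SubtractByKeySketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("subtract-sketch"))
    val left  = sc.parallelize(Seq(("a", 1), ("b", 2), ("c", 3)))
    val right = sc.parallelize(Seq(("b", 0)))

    // dep 0 (left) populates the map, dep 1 (right) removes its keys,
    // leaving ("a", 1) and ("c", 3).
    left.subtractByKey(right).collect().foreach(println)
    sc.stop()
  }
}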