path: root/core/src/main/scala/spark/ShuffledRDD.scala
Diffstat (limited to 'core/src/main/scala/spark/ShuffledRDD.scala')
-rw-r--r--  core/src/main/scala/spark/ShuffledRDD.scala  72
1 file changed, 66 insertions, 6 deletions
diff --git a/core/src/main/scala/spark/ShuffledRDD.scala b/core/src/main/scala/spark/ShuffledRDD.scala
index 3616d8e47e..a7346060b3 100644
--- a/core/src/main/scala/spark/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/ShuffledRDD.scala
@@ -1,29 +1,89 @@
package spark
+import scala.collection.mutable.ArrayBuffer
import java.util.{HashMap => JHashMap}
+
class ShuffledRDDSplit(val idx: Int) extends Split {
override val index = idx
override def hashCode(): Int = idx
}
-class ShuffledRDD[K, V, C](
+
+/**
+ * The resulting RDD from a shuffle (e.g. repartitioning of data).
+ */
+abstract class ShuffledRDD[K, V, C](
@transient parent: RDD[(K, V)],
aggregator: Aggregator[K, V, C],
- part : Partitioner)
+ part : Partitioner)
extends RDD[(K, C)](parent.context) {
- //override val partitioner = Some(part)
+
override val partitioner = Some(part)
-
+
@transient
val splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i))
override def splits = splits_
-
+
override def preferredLocations(split: Split) = Nil
-
+
val dep = new ShuffleDependency(context.newShuffleId, parent, aggregator, part)
override val dependencies = List(dep)
+}
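For context on the `aggregator` parameter above: it bundles the combiner functions applied during the shuffle. Below is a minimal standalone sketch of what such a bundle might look like for a count-by-key job. The field names (createCombiner, mergeValue, mergeCombiners, mapSideCombine) are assumptions inferred from the four-argument `Aggregator[K, V, V](null, null, null, false)` call further down; the case class here is an illustrative stand-in, not the actual spark.Aggregator.

// Illustrative stand-in only; the field names are assumptions, not the real spark.Aggregator.
case class ExampleAggregator[K, V, C](
    createCombiner: V => C,        // build a combiner from the first value seen for a key
    mergeValue: (C, V) => C,       // fold another value into an existing combiner
    mergeCombiners: (C, C) => C,   // merge combiners coming from different map tasks
    mapSideCombine: Boolean)

// Example: counting occurrences per key.
val countByKey = ExampleAggregator[String, Int, Int](
  createCombiner = v => v,
  mergeValue = (c, v) => c + v,
  mergeCombiners = (c1, c2) => c1 + c2,
  mapSideCombine = true)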
+
+
+/**
+ * Repartition a key-value pair RDD.
+ */
+class RepartitionShuffledRDD[K, V](
+ @transient parent: RDD[(K, V)],
+ part : Partitioner)
+ extends ShuffledRDD[K, V, V](
+ parent,
+ Aggregator[K, V, V](null, null, null, false),
+ part) {
+
+ override def compute(split: Split): Iterator[(K, V)] = {
+ val buf = new ArrayBuffer[(K, V)]
+ val fetcher = SparkEnv.get.shuffleFetcher
+ def addTupleToBuffer(k: K, v: V) = { buf += Tuple2(k, v) }
+ fetcher.fetch[K, V](dep.shuffleId, split.index, addTupleToBuffer)
+ buf.iterator
+ }
+}
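To make the compute path above concrete: the shuffle fetcher delivers one (key, value) pair at a time to a callback, and compute simply appends each pair to a local buffer and hands back its iterator. Here is a minimal standalone sketch of that pattern with an in-memory stand-in for the fetcher; MockFetcher is an illustration only, not the real ShuffleFetcher.

import scala.collection.mutable.ArrayBuffer

// Illustration only: a fake fetcher that replays an in-memory sequence of pairs
// through a per-tuple callback, mimicking the shape of ShuffleFetcher.fetch.
class MockFetcher(data: Seq[(String, Int)]) {
  def fetch(func: (String, Int) => Unit): Unit =
    data.foreach { case (k, v) => func(k, v) }
}

val fetcher = new MockFetcher(Seq(("a", 1), ("b", 2), ("a", 3)))
val buf = new ArrayBuffer[(String, Int)]
fetcher.fetch((k, v) => buf += ((k, v)))  // same append-to-buffer callback as compute()
println(buf.iterator.toList)              // List((a,1), (b,2), (a,3))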
+
+
+/**
+ * A sort-based shuffle (that doesn't apply aggregation). It does so by first
+ * repartitioning the RDD by range, and then sorting within each range.
+ */
+class ShuffledSortedRDD[K <% Ordered[K]: ClassManifest, V](
+ @transient parent: RDD[(K, V)],
+ ascending: Boolean)
+ extends RepartitionShuffledRDD[K, V](
+ parent,
+ new RangePartitioner(parent.splits.size, parent, ascending)) {
+
+ override def compute(split: Split): Iterator[(K, V)] = {
+ // By separating this from RepartitionShuffledRDD, we avoid a
+ // buf.iterator.toArray call and thus avoid building up the buffer twice.
+ val buf = new ArrayBuffer[(K, V)]
+ def addTupleToBuffer(k: K, v: V) = { buf += Tuple2(k, v) }
+ SparkEnv.get.shuffleFetcher.fetch[K, V](dep.shuffleId, split.index, addTupleToBuffer)
+ buf.sortWith((x, y) => if (ascending) x._1 < y._1 else x._1 > y._1).iterator
+ }
+}
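As a rough standalone illustration of the comment above (repartition by key range first, then sort each partition locally): splitting the data into ordered key ranges and sorting within each range with the same comparator yields a globally sorted result once the ranges are concatenated. Everything below is plain Scala collections, not the Spark API.

val ascending = true
val data = Seq(("d", 4), ("a", 1), ("c", 3), ("b", 2))

// Stand-in for a RangePartitioner with two ranges: keys below "c" and keys from "c" up.
val (lowRange, highRange) = data.partition(_._1 < "c")

// Sort within each range using the same comparator ShuffledSortedRDD uses.
val sortWithinRange = (part: Seq[(String, Int)]) =>
  part.sortWith((x, y) => if (ascending) x._1 < y._1 else x._1 > y._1)

// The ranges are themselves ordered, so concatenating them gives a total order.
println(sortWithinRange(lowRange) ++ sortWithinRange(highRange))
// List((a,1), (b,2), (c,3), (d,4))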
+
+
+/**
+ * The resulting RDD from a shuffle, with (hash-based) aggregation applied.
+ */
+class ShuffledAggregatedRDD[K, V, C](
+ @transient parent: RDD[(K, V)],
+ aggregator: Aggregator[K, V, C],
+ part : Partitioner)
+ extends ShuffledRDD[K, V, C](parent, aggregator, part) {
override def compute(split: Split): Iterator[(K, C)] = {
val combiners = new JHashMap[K, C]