author    Josh Rosen <joshrosen@eecs.berkeley.edu>  2012-10-07 00:36:04 -0700
committer Josh Rosen <joshrosen@eecs.berkeley.edu>  2012-10-07 00:36:04 -0700
commit    e10308f5a0a17627317306dfaf19aa20b46490fd (patch)
tree      371ea8de2b1480d135082639a59b407754fc9485 /core/src/main
parent    4f72066a9ab02308f733bf248b1ca003abcc0874 (diff)
Make ShuffleDependency.aggregator explicitly optional.
It was confusing to use new Aggregator[K, V, V](null, null, null, false) to represent the absence of an aggregator.
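The contrast is easy to see in isolation. Below is a minimal, self-contained sketch (not Spark's real classes; the Aggregator field names are assumptions mirroring the four-argument call above):

// A minimal sketch of the contrast, not Spark's real classes. The four
// fields mirror the Aggregator[K, V, V](null, null, null, false) call above;
// the field names are assumptions for illustration.
case class Aggregator[K, V, C](
    createCombiner: V => C,
    mergeValue: (C, V) => C,
    mergeCombiners: (C, C) => C,
    mapSideCombine: Boolean)

object OptionVsSentinel {
  def main(args: Array[String]): Unit = {
    // Before: a sentinel whose function fields are null. Nothing in the type
    // warns a caller that calling createCombiner will throw a NullPointerException.
    val sentinel = Aggregator[Int, Int, Int](null, null, null, false)

    // After: absence is encoded in the type, so every caller is forced to
    // handle the None case before touching the functions.
    val absent: Option[Aggregator[Int, Int, Int]] = None
    val present = Some(Aggregator[Int, Int, Int](v => v, _ + _, _ + _, mapSideCombine = true))

    println(absent.exists(_.mapSideCombine))  // false
    println(present.exists(_.mapSideCombine)) // true
  }
}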
Diffstat (limited to 'core/src/main')
 core/src/main/scala/spark/Dependency.scala               | 6 +++++-
 core/src/main/scala/spark/rdd/CoGroupedRDD.scala         | 2 +-
 core/src/main/scala/spark/rdd/ShuffledRDD.scala          | 6 +++---
 core/src/main/scala/spark/scheduler/ShuffleMapTask.scala | 4 ++--
 4 files changed, 11 insertions(+), 7 deletions(-)
diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala
index f2b7aa33ec..19a51dd5b8 100644
--- a/core/src/main/scala/spark/Dependency.scala
+++ b/core/src/main/scala/spark/Dependency.scala
@@ -20,11 +20,15 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
 /**
  * Represents a dependency on the output of a shuffle stage.
+ * @param shuffleId the shuffle id
+ * @param rdd the parent RDD
+ * @param aggregator optional aggregator; this allows for map-side combining
+ * @param partitioner partitioner used to partition the shuffle output
  */
 class ShuffleDependency[K, V, C](
     val shuffleId: Int,
     @transient rdd: RDD[(K, V)],
-    val aggregator: Aggregator[K, V, C],
+    val aggregator: Option[Aggregator[K, V, C]],
     val partitioner: Partitioner)
   extends Dependency(rdd)
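With the new signature, call sites pass Some(aggregator) to request map-side combining and None for a plain shuffle, as the CoGroupedRDD and ShuffledRDD hunks below show. A self-contained sketch of the pattern, with hypothetical stand-ins for RDD, Partitioner, Dependency, and Aggregator so it compiles on its own:

// Hypothetical stand-ins so the ShuffleDependency shape from this patch
// compiles in isolation; only the aggregator parameter matters here.
class RDD[T]
class Partitioner
class Dependency[T](rdd: RDD[T])
case class Aggregator[K, V, C](
    createCombiner: V => C,
    mergeValue: (C, V) => C,
    mergeCombiners: (C, C) => C,
    mapSideCombine: Boolean)

class ShuffleDependency[K, V, C](
    val shuffleId: Int,
    @transient rdd: RDD[(K, V)],
    val aggregator: Option[Aggregator[K, V, C]],
    val partitioner: Partitioner)
  extends Dependency(rdd)

object CallSites {
  val rdd  = new RDD[(String, Int)]
  val part = new Partitioner

  // Combining shuffle: wrap a real aggregator in Some(...).
  val withCombining = new ShuffleDependency[String, Int, Int](
    0, rdd, Some(Aggregator[String, Int, Int](v => v, _ + _, _ + _, true)), part)

  // Plain repartitioning shuffle: the absence of an aggregator is explicit.
  val withoutCombining = new ShuffleDependency[String, Int, Int](1, rdd, None, part)
}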
diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
index 8fa0749184..f1defbe492 100644
--- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -49,7 +49,7 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
       } else {
         logInfo("Adding shuffle dependency with " + rdd)
         deps += new ShuffleDependency[Any, Any, ArrayBuffer[Any]](
-          context.newShuffleId, rdd, aggr, part)
+          context.newShuffleId, rdd, Some(aggr), part)
       }
     }
     deps.toList
diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
index 769ccf8caa..7577909b83 100644
--- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
@@ -22,7 +22,7 @@ private[spark] class ShuffledRDDSplit(val idx: Int) extends Split {
  */
 abstract class ShuffledRDD[K, V, C](
     @transient parent: RDD[(K, V)],
-    aggregator: Aggregator[K, V, C],
+    aggregator: Option[Aggregator[K, V, C]],
     part: Partitioner)
   extends RDD[(K, C)](parent.context) {
@@ -48,7 +48,7 @@ class RepartitionShuffledRDD[K, V](
     part: Partitioner)
   extends ShuffledRDD[K, V, V](
     parent,
-    Aggregator[K, V, V](null, null, null, false),
+    None,
     part) {

   override def compute(split: Split): Iterator[(K, V)] = {
@@ -95,7 +95,7 @@ class ShuffledAggregatedRDD[K, V, C](
     @transient parent: RDD[(K, V)],
     aggregator: Aggregator[K, V, C],
     part : Partitioner)
-  extends ShuffledRDD[K, V, C](parent, aggregator, part) {
+  extends ShuffledRDD[K, V, C](parent, Some(aggregator), part) {

   override def compute(split: Split): Iterator[(K, C)] = {
     val combiners = new JHashMap[K, C]
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index 3e5ba10fd9..86796d3677 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -111,11 +111,11 @@ private[spark] class ShuffleMapTask(
   override def run(attemptId: Long): MapStatus = {
     val numOutputSplits = dep.partitioner.numPartitions
-    val aggregator = dep.aggregator.asInstanceOf[Aggregator[Any, Any, Any]]
     val partitioner = dep.partitioner

     val bucketIterators =
-      if (aggregator.mapSideCombine) {
+      if (dep.aggregator.isDefined && dep.aggregator.get.mapSideCombine) {
+        val aggregator = dep.aggregator.get.asInstanceOf[Aggregator[Any, Any, Any]]
         // Apply combiners (map-side aggregation) to the map output.
         val buckets = Array.tabulate(numOutputSplits)(_ => new JHashMap[Any, Any])
         for (elem <- rdd.iterator(split)) {
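A side note on the new guard in this last hunk: dep.aggregator.isDefined && dep.aggregator.get.mapSideCombine is correct, and the same test can also be written as dep.aggregator.exists(_.mapSideCombine). A tiny runnable check of the equivalence (the one-field Aggregator is a hypothetical stand-in):

object GuardSketch {
  // Hypothetical one-field stand-in; the guard only reads mapSideCombine.
  case class Aggregator(mapSideCombine: Boolean)

  def main(args: Array[String]): Unit = {
    val cases = Seq(None, Some(Aggregator(false)), Some(Aggregator(true)))
    for (agg <- cases) {
      val patched   = agg.isDefined && agg.get.mapSideCombine // as in the diff
      val idiomatic = agg.exists(_.mapSideCombine)            // equivalent
      assert(patched == idiomatic)
    }
    println("both guards agree on all three cases")
  }
}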