diff options
author | Matei Zaharia <matei@eecs.berkeley.edu> | 2013-02-16 10:07:42 -0800 |
---|---|---|
committer | Matei Zaharia <matei@eecs.berkeley.edu> | 2013-02-16 10:07:42 -0800 |
commit | 9d979fb630ae4e91b2ef0a2fefdbe220f99212bf (patch) | |
tree | 7c85562678d88adb20ece72577c4a50bdba3548f | |
parent | beb7ab870858541d736033cc3c6fad4dad657aa3 (diff) | |
parent | 43288732942a29e7c7c42de66eec6246ea27a13b (diff) | |
download | spark-9d979fb630ae4e91b2ef0a2fefdbe220f99212bf.tar.gz spark-9d979fb630ae4e91b2ef0a2fefdbe220f99212bf.tar.bz2 spark-9d979fb630ae4e91b2ef0a2fefdbe220f99212bf.zip |
Merge pull request #469 from stephenh/samepartitionercombine
If combineByKey is using the same partitioner, skip the shuffle.
-rw-r--r-- | core/src/main/scala/spark/PairRDDFunctions.scala |  4
-rw-r--r-- | core/src/test/scala/spark/ShuffleSuite.scala     | 23
2 files changed, 26 insertions(+), 1 deletion(-)
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index cc3cca2571..112beb2320 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -62,7 +62,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
     }
     val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
-    if (mapSideCombine) {
+    if (self.partitioner == Some(partitioner)) {
+      self.mapPartitions(aggregator.combineValuesByKey(_), true)
+    } else if (mapSideCombine) {
       val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey(_), true)
       val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner)
       partitioned.mapPartitions(aggregator.combineCombinersByKey(_), true)
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index 3493b9511f..50f2b294bf 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -1,6 +1,7 @@
 package spark
 
 import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashSet
 
 import org.scalatest.FunSuite
 import org.scalatest.matchers.ShouldMatchers
@@ -98,6 +99,28 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
     val sums = pairs.reduceByKey(_+_, 10).collect()
     assert(sums.toSet === Set((1, 7), (2, 1)))
   }
+
+  test("reduceByKey with partitioner") {
+    sc = new SparkContext("local", "test")
+    val p = new Partitioner() {
+      def numPartitions = 2
+      def getPartition(key: Any) = key.asInstanceOf[Int]
+    }
+    val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p)
+    val sums = pairs.reduceByKey(_+_)
+    assert(sums.collect().toSet === Set((1, 4), (0, 1)))
+    assert(sums.partitioner === Some(p))
+    // count the dependencies to make sure there is only 1 ShuffledRDD
+    val deps = new HashSet[RDD[_]]()
+    def visit(r: RDD[_]) {
+      for (dep <- r.dependencies) {
+        deps += dep.rdd
+        visit(dep.rdd)
+      }
+    }
+    visit(sums)
+    assert(deps.size === 2) // ShuffledRDD, ParallelCollection
+  }
 
   test("join") {
     sc = new SparkContext("local", "test")