From c34b8ad2c59697b3e1f5034074e5de0d3b32b8f9 Mon Sep 17 00:00:00 2001
From: Stephen Haberman
Date: Sat, 16 Feb 2013 00:54:03 -0600
Subject: Avoid a shuffle if combineByKey is passed the same partitioner.

---
 core/src/main/scala/spark/PairRDDFunctions.scala |  4 +++-
 core/src/test/scala/spark/ShuffleSuite.scala     | 13 +++++++++++++
 2 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index cc3cca2571..4c41519330 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -62,7 +62,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
       }
     }
     val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
-    if (mapSideCombine) {
+    if (Option(partitioner) == self.partitioner) {
+      self.mapPartitions(aggregator.combineValuesByKey(_), true)
+    } else if (mapSideCombine) {
       val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey(_), true)
       val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner)
       partitioned.mapPartitions(aggregator.combineCombinersByKey(_), true)
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index 3493b9511f..d6efa3db43 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -98,6 +98,19 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
     val sums = pairs.reduceByKey(_+_, 10).collect()
     assert(sums.toSet === Set((1, 7), (2, 1)))
   }
+
+  test("reduceByKey with partitioner") {
+    sc = new SparkContext("local", "test")
+    val p = new Partitioner() {
+      def numPartitions = 2
+      def getPartition(key: Any) = key.asInstanceOf[Int]
+    }
+    val pairs = rddToPairRDDFunctions(sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1)))).partitionBy(p)
+    val sums = pairs.reduceByKey(p, _+_)
+    println(sums.toDebugString)
+    assert(sums.collect().toSet === Set((1, 4), (0, 1)))
+    assert(sums.partitioner === Some(p))
+  }

   test("join") {
     sc = new SparkContext("local", "test")
--
cgit v1.2.3
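
For reference, a minimal sketch of what this optimization buys a caller, written against the old `spark.*` API shown in the diff; the sample data, app name, and object name are illustrative assumptions, not part of the patch:

```scala
import spark.{HashPartitioner, SparkContext}
import spark.SparkContext._  // implicit rddToPairRDDFunctions for (K, V) RDDs

object NoShuffleExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "no-shuffle-example")
    val p = new HashPartitioner(4)

    // partitionBy(p) shuffles once and records p in the RDD's partitioner.
    val byKey = sc.parallelize(Seq((1, 1), (1, 2), (2, 3))).partitionBy(p)

    // With this patch, passing the same partitioner again makes
    // Option(partitioner) == self.partitioner true inside combineByKey,
    // so reduceByKey combines within each existing partition and skips
    // the second shuffle entirely.
    val sums = byKey.reduceByKey(p, _ + _)

    println(sums.toDebugString)       // no ShuffledRDD stage for this step
    println(sums.collect().toSet)     // Set((1, 3), (2, 3))
    sc.stop()
  }
}
```

Note that the guard compares `Option(partitioner) == self.partitioner` rather than checking reference identity, so it also fires for a distinct but equal partitioner, e.g. two `HashPartitioner`s with the same partition count, since `HashPartitioner` implements value equality.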