From 43288732942a29e7c7c42de66eec6246ea27a13b Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Sat, 16 Feb 2013 01:16:40 -0600 Subject: Add assertion about dependencies. --- core/src/main/scala/spark/PairRDDFunctions.scala | 2 +- core/src/test/scala/spark/ShuffleSuite.scala | 16 +++++++++++++--- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 4c41519330..112beb2320 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -62,7 +62,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( } val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners) - if (Option(partitioner) == self.partitioner) { + if (self.partitioner == Some(partitioner)) { self.mapPartitions(aggregator.combineValuesByKey(_), true) } else if (mapSideCombine) { val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey(_), true) diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index d6efa3db43..50f2b294bf 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -1,6 +1,7 @@ package spark import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashSet import org.scalatest.FunSuite import org.scalatest.matchers.ShouldMatchers @@ -105,11 +106,20 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { def numPartitions = 2 def getPartition(key: Any) = key.asInstanceOf[Int] } - val pairs = rddToPairRDDFunctions(sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1)))).partitionBy(p) - val sums = pairs.reduceByKey(p, _+_) - println(sums.toDebugString) + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p) + val sums = pairs.reduceByKey(_+_) assert(sums.collect().toSet === Set((1, 4), (0, 1))) assert(sums.partitioner === Some(p)) + // count the dependencies to make sure there is only 1 ShuffledRDD + val deps = new HashSet[RDD[_]]() + def visit(r: RDD[_]) { + for (dep <- r.dependencies) { + deps += dep.rdd + visit(dep.rdd) + } + } + visit(sums) + assert(deps.size === 2) // ShuffledRDD, ParallelCollection } test("join") { -- cgit v1.2.3