author    Stephen Haberman <stephen@exigencecorp.com>  2013-02-16 01:16:40 -0600
committer Stephen Haberman <stephen@exigencecorp.com>  2013-02-16 01:16:40 -0600
commit    43288732942a29e7c7c42de66eec6246ea27a13b (patch)
tree      7c85562678d88adb20ece72577c4a50bdba3548f
parent    c34b8ad2c59697b3e1f5034074e5de0d3b32b8f9 (diff)
Add assertion about dependencies.
-rw-r--r--  core/src/main/scala/spark/PairRDDFunctions.scala |  2 +-
-rw-r--r--  core/src/test/scala/spark/ShuffleSuite.scala     | 16 +++++++++++++---
2 files changed, 14 insertions(+), 4 deletions(-)
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index 4c41519330..112beb2320 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -62,7 +62,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
     }
     val aggregator =
       new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
-    if (Option(partitioner) == self.partitioner) {
+    if (self.partitioner == Some(partitioner)) {
       self.mapPartitions(aggregator.combineValuesByKey(_), true)
     } else if (mapSideCombine) {
       val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey(_), true)
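
Note on the change above: the two conditions differ only when partitioner is null. Option(null) collapses to None, so the old check would equal an unpartitioned RDD's None and take the no-shuffle branch; Some(null) is a defined value that never equals None, so the rewritten check falls through to the shuffle branches instead. A standalone sketch of the distinction (plain Scala, not the Spark code itself):

    // Option(x) maps null to None; Some(x) wraps the value as given.
    val existing: Option[String] = None  // stands in for self.partitioner on an unpartitioned RDD
    val passed: String = null            // stands in for a null partitioner argument

    // Old check: Option(null) == None is true, so the no-shuffle
    // branch would run even though no real partitioner matches.
    println(Option(passed) == existing)  // true

    // New check: Some(null) is defined and never equals None, so the
    // comparison fails and the shuffle branches run instead.
    println(existing == Some(passed))    // false
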
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index d6efa3db43..50f2b294bf 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -1,6 +1,7 @@
 package spark
 
 import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashSet
 
 import org.scalatest.FunSuite
 import org.scalatest.matchers.ShouldMatchers
@@ -105,11 +106,20 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
       def numPartitions = 2
       def getPartition(key: Any) = key.asInstanceOf[Int]
     }
-    val pairs = rddToPairRDDFunctions(sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1)))).partitionBy(p)
-    val sums = pairs.reduceByKey(p, _+_)
-    println(sums.toDebugString)
+    val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p)
+    val sums = pairs.reduceByKey(_+_)
     assert(sums.collect().toSet === Set((1, 4), (0, 1)))
     assert(sums.partitioner === Some(p))
+    // count the dependencies to make sure there is only 1 ShuffledRDD
+    val deps = new HashSet[RDD[_]]()
+    def visit(r: RDD[_]) {
+      for (dep <- r.dependencies) {
+        deps += dep.rdd
+        visit(dep.rdd)
+      }
+    }
+    visit(sums)
+    assert(deps.size === 2) // ShuffledRDD, ParallelCollection
   }
 
   test("join") {