aboutsummaryrefslogtreecommitdiff
path: root/src/test/spark/ShuffleSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/test/spark/ShuffleSuite.scala')
-rw-r--r--src/test/spark/ShuffleSuite.scala130
1 files changed, 130 insertions, 0 deletions
diff --git a/src/test/spark/ShuffleSuite.scala b/src/test/spark/ShuffleSuite.scala
new file mode 100644
index 0000000000..a5773614e8
--- /dev/null
+++ b/src/test/spark/ShuffleSuite.scala
@@ -0,0 +1,130 @@
+package spark
+
+import org.scalatest.FunSuite
+import org.scalatest.prop.Checkers
+import org.scalacheck.Arbitrary._
+import org.scalacheck.Gen
+import org.scalacheck.Prop._
+
+import SparkContext._
+
+class ShuffleSuite extends FunSuite {
+ test("groupByKey") {
+ val sc = new SparkContext("local", "test")
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
+ val groups = pairs.groupByKey().collect()
+ assert(groups.size === 2)
+ val valuesFor1 = groups.find(_._1 == 1).get._2
+ assert(valuesFor1.toList.sorted === List(1, 2, 3))
+ val valuesFor2 = groups.find(_._1 == 2).get._2
+ assert(valuesFor2.toList.sorted === List(1))
+ }
+
+ test("groupByKey with duplicates") {
+ val sc = new SparkContext("local", "test")
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val groups = pairs.groupByKey().collect()
+ assert(groups.size === 2)
+ val valuesFor1 = groups.find(_._1 == 1).get._2
+ assert(valuesFor1.toList.sorted === List(1, 1, 2, 3))
+ val valuesFor2 = groups.find(_._1 == 2).get._2
+ assert(valuesFor2.toList.sorted === List(1))
+ }
+
+ test("groupByKey with negative key hash codes") {
+ val sc = new SparkContext("local", "test")
+ val pairs = sc.parallelize(Array((-1, 1), (-1, 2), (-1, 3), (2, 1)))
+ val groups = pairs.groupByKey().collect()
+ assert(groups.size === 2)
+ val valuesForMinus1 = groups.find(_._1 == -1).get._2
+ assert(valuesForMinus1.toList.sorted === List(1, 2, 3))
+ val valuesFor2 = groups.find(_._1 == 2).get._2
+ assert(valuesFor2.toList.sorted === List(1))
+ }
+
+ test("groupByKey with many output partitions") {
+ val sc = new SparkContext("local", "test")
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
+ val groups = pairs.groupByKey(10).collect()
+ assert(groups.size === 2)
+ val valuesFor1 = groups.find(_._1 == 1).get._2
+ assert(valuesFor1.toList.sorted === List(1, 2, 3))
+ val valuesFor2 = groups.find(_._1 == 2).get._2
+ assert(valuesFor2.toList.sorted === List(1))
+ }
+
+ test("reduceByKey") {
+ val sc = new SparkContext("local", "test")
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val sums = pairs.reduceByKey(_+_).collect()
+ assert(sums.toSet === Set((1, 7), (2, 1)))
+ }
+
+ test("reduceByKey with collectAsMap") {
+ val sc = new SparkContext("local", "test")
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val sums = pairs.reduceByKey(_+_).collectAsMap()
+ assert(sums.size === 2)
+ assert(sums(1) === 7)
+ assert(sums(2) === 1)
+ }
+
+ test("reduceByKey with many output partitons") {
+ val sc = new SparkContext("local", "test")
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val sums = pairs.reduceByKey(_+_, 10).collect()
+ assert(sums.toSet === Set((1, 7), (2, 1)))
+ }
+
+ test("join") {
+ val sc = new SparkContext("local", "test")
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val joined = rdd1.join(rdd2).collect()
+ assert(joined.size === 4)
+ assert(joined.toSet === Set(
+ (1, (1, 'x')),
+ (1, (2, 'x')),
+ (2, (1, 'y')),
+ (2, (1, 'z'))
+ ))
+ }
+
+ test("join all-to-all") {
+ val sc = new SparkContext("local", "test")
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (1, 3)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (1, 'y')))
+ val joined = rdd1.join(rdd2).collect()
+ assert(joined.size === 6)
+ assert(joined.toSet === Set(
+ (1, (1, 'x')),
+ (1, (1, 'y')),
+ (1, (2, 'x')),
+ (1, (2, 'y')),
+ (1, (3, 'x')),
+ (1, (3, 'y'))
+ ))
+ }
+
+ test("join with no matches") {
+ val sc = new SparkContext("local", "test")
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w')))
+ val joined = rdd1.join(rdd2).collect()
+ assert(joined.size === 0)
+ }
+
+ test("join with many output partitions") {
+ val sc = new SparkContext("local", "test")
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val joined = rdd1.join(rdd2, 10).collect()
+ assert(joined.size === 4)
+ assert(joined.toSet === Set(
+ (1, (1, 'x')),
+ (1, (2, 'x')),
+ (2, (1, 'y')),
+ (2, (1, 'z'))
+ ))
+ }
+}