author     Matei Zaharia <matei@eecs.berkeley.edu>    2012-08-03 16:37:35 -0400
committer  Matei Zaharia <matei@eecs.berkeley.edu>    2012-08-03 16:40:45 -0400
commit     6601a6212b65fcd40a0158a6f1df28ae958bbd9e (patch)
tree       8966446841255a79a1bfc4def3c26aa831e7daa5 /core/src/test/scala
parent     1170de37572674cc9a27d4bc2f3eb200475c9088 (diff)
Added a unit test for cross-partition balancing in sort, and changes to
RangePartitioner to make it pass. It turns out that the first partition was always kind of small due to how we picked partition boundaries.
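
The RangePartitioner change itself is not included in this diff (which is limited to core/src/test/scala), but the behavior the new tests assert can be sketched in isolation. The following is a minimal, hypothetical illustration, not Spark's actual partitioner code: RangeBoundsSketch, pickBounds, and partitionFor are invented names, and the index formula is just one reasonable way to split a sorted sample of keys into roughly equal ranges. With 1000 keys and 4 partitions it yields ~250 keys per partition, which is what the new "partition balancing" tests check (> 150 per partition, with ordered boundaries).

// Hedged sketch only; assumes keys are Ints and the sample is already sorted.
object RangeBoundsSketch {
  // Hypothetical helper: pick (partitions - 1) upper bounds from a sorted sample,
  // one at each k/partitions fraction of the sample. A skewed choice of indices
  // near the start of the sample is the kind of bug that leaves partition 0 small.
  def pickBounds(sortedSample: Array[Int], partitions: Int): Array[Int] =
    (1 until partitions).map { k =>
      sortedSample(math.min(sortedSample.length - 1, k * sortedSample.length / partitions))
    }.toArray

  // Hypothetical helper: map a key to the first partition whose bound covers it.
  def partitionFor(key: Int, bounds: Array[Int]): Int = {
    var p = 0
    while (p < bounds.length && key > bounds(p)) p += 1
    p
  }

  def main(args: Array[String]): Unit = {
    val keys = (1 to 1000).toArray                 // same key range as the new test
    val bounds = pickBounds(keys, 4)
    val sizes = keys.groupBy(k => partitionFor(k, bounds)).map { case (p, ks) => p -> ks.length }
    println(s"bounds: ${bounds.mkString(", ")}")
    println(s"partition sizes: ${sizes.toSeq.sortBy(_._1).mkString(", ")}")
    // Expected: four partitions of roughly 250 keys each, all comfortably above
    // the > 150 threshold the new SortingSuite tests use.
  }
}
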
Diffstat (limited to 'core/src/test/scala')
-rw-r--r--    core/src/test/scala/spark/SortingSuite.scala    90
1 file changed, 61 insertions(+), 29 deletions(-)
diff --git a/core/src/test/scala/spark/SortingSuite.scala b/core/src/test/scala/spark/SortingSuite.scala
index d2dd514edb..a6fdd8a218 100644
--- a/core/src/test/scala/spark/SortingSuite.scala
+++ b/core/src/test/scala/spark/SortingSuite.scala
@@ -2,54 +2,86 @@ package spark
import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.ShouldMatchers
import SparkContext._
-class SortingSuite extends FunSuite with BeforeAndAfter {
+class SortingSuite extends FunSuite with BeforeAndAfter with ShouldMatchers with Logging {
var sc: SparkContext = _
after {
- if(sc != null) {
+ if (sc != null) {
sc.stop()
}
}
test("sortByKey") {
- sc = new SparkContext("local", "test")
- val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)))
- assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0)))
+ sc = new SparkContext("local", "test")
+ val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)))
+ assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0)))
}
- test("sortLargeArray") {
- sc = new SparkContext("local", "test")
- val rand = new scala.util.Random()
- val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
- val pairs = sc.parallelize(pairArr)
- assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
+ test("large array") {
+ sc = new SparkContext("local", "test")
+ val rand = new scala.util.Random()
+ val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
+ val pairs = sc.parallelize(pairArr)
+ assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
}
- test("sortDescending") {
- sc = new SparkContext("local", "test")
- val rand = new scala.util.Random()
- val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
- val pairs = sc.parallelize(pairArr)
- assert(pairs.sortByKey(false).collect() === pairArr.sortWith((x, y) => x._1 > y._1))
+ test("sort descending") {
+ sc = new SparkContext("local", "test")
+ val rand = new scala.util.Random()
+ val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
+ val pairs = sc.parallelize(pairArr)
+ assert(pairs.sortByKey(false).collect() === pairArr.sortWith((x, y) => x._1 > y._1))
}
- test("morePartitionsThanElements") {
- sc = new SparkContext("local", "test")
- val rand = new scala.util.Random()
- val pairArr = Array.fill(10) { (rand.nextInt(), rand.nextInt()) }
- val pairs = sc.parallelize(pairArr, 30)
- assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
+ test("more partitions than elements") {
+ sc = new SparkContext("local", "test")
+ val rand = new scala.util.Random()
+ val pairArr = Array.fill(10) { (rand.nextInt(), rand.nextInt()) }
+ val pairs = sc.parallelize(pairArr, 30)
+ assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
}
- test("emptyRDD") {
- sc = new SparkContext("local", "test")
- val rand = new scala.util.Random()
- val pairArr = new Array[(Int, Int)](0)
- val pairs = sc.parallelize(pairArr)
- assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
+ test("empty RDD") {
+ sc = new SparkContext("local", "test")
+ val pairArr = new Array[(Int, Int)](0)
+ val pairs = sc.parallelize(pairArr)
+ assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
+ }
+
+ test("partition balancing") {
+ sc = new SparkContext("local", "test")
+ val pairArr = (1 to 1000).map(x => (x, x)).toArray
+ val sorted = sc.parallelize(pairArr, 4).sortByKey()
+ assert(sorted.collect() === pairArr.sortBy(_._1))
+ val partitions = sorted.collectPartitions()
+ logInfo("partition lengths: " + partitions.map(_.length).mkString(", "))
+ partitions(0).length should be > 150
+ partitions(1).length should be > 150
+ partitions(2).length should be > 150
+ partitions(3).length should be > 150
+ partitions(0).last should be < partitions(1).head
+ partitions(1).last should be < partitions(2).head
+ partitions(2).last should be < partitions(3).head
+ }
+
+ test("partition balancing for descending sort") {
+ sc = new SparkContext("local", "test")
+ val pairArr = (1 to 1000).map(x => (x, x)).toArray
+ val sorted = sc.parallelize(pairArr, 4).sortByKey(false)
+ assert(sorted.collect() === pairArr.sortBy(_._1).reverse)
+ val partitions = sorted.collectPartitions()
+ logInfo("partition lengths: " + partitions.map(_.length).mkString(", "))
+ partitions(0).length should be > 150
+ partitions(1).length should be > 150
+ partitions(2).length should be > 150
+ partitions(3).length should be > 150
+ partitions(0).last should be > partitions(1).head
+ partitions(1).last should be > partitions(2).head
+ partitions(2).last should be > partitions(3).head
}
}