diff options
author | Matei Zaharia <matei@eecs.berkeley.edu> | 2012-08-03 16:37:35 -0400 |
---|---|---|
committer | Matei Zaharia <matei@eecs.berkeley.edu> | 2012-08-03 16:40:45 -0400 |
commit | 6601a6212b65fcd40a0158a6f1df28ae958bbd9e (patch) | |
tree | 8966446841255a79a1bfc4def3c26aa831e7daa5 /core/src/test/scala | |
parent | 1170de37572674cc9a27d4bc2f3eb200475c9088 (diff) | |
download | spark-6601a6212b65fcd40a0158a6f1df28ae958bbd9e.tar.gz spark-6601a6212b65fcd40a0158a6f1df28ae958bbd9e.tar.bz2 spark-6601a6212b65fcd40a0158a6f1df28ae958bbd9e.zip |
Added a unit test for cross-partition balancing in sort, and changes to
RangePartitioner to make it pass. It turns out that the first partition
was always kind of small due to how we picked partition boundaries.
Diffstat (limited to 'core/src/test/scala')
-rw-r--r-- | core/src/test/scala/spark/SortingSuite.scala | 90 |
1 file changed, 61 insertions, 29 deletions
diff --git a/core/src/test/scala/spark/SortingSuite.scala b/core/src/test/scala/spark/SortingSuite.scala index d2dd514edb..a6fdd8a218 100644 --- a/core/src/test/scala/spark/SortingSuite.scala +++ b/core/src/test/scala/spark/SortingSuite.scala @@ -2,54 +2,86 @@ package spark import org.scalatest.FunSuite import org.scalatest.BeforeAndAfter +import org.scalatest.matchers.ShouldMatchers import SparkContext._ -class SortingSuite extends FunSuite with BeforeAndAfter { +class SortingSuite extends FunSuite with BeforeAndAfter with ShouldMatchers with Logging { var sc: SparkContext = _ after { - if(sc != null) { + if (sc != null) { sc.stop() } } test("sortByKey") { - sc = new SparkContext("local", "test") - val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0))) - assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0))) + sc = new SparkContext("local", "test") + val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0))) + assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0))) } - test("sortLargeArray") { - sc = new SparkContext("local", "test") - val rand = new scala.util.Random() - val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } - val pairs = sc.parallelize(pairArr) - assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1)) + test("large array") { + sc = new SparkContext("local", "test") + val rand = new scala.util.Random() + val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } + val pairs = sc.parallelize(pairArr) + assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1)) } - test("sortDescending") { - sc = new SparkContext("local", "test") - val rand = new scala.util.Random() - val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } - val pairs = sc.parallelize(pairArr) - assert(pairs.sortByKey(false).collect() === pairArr.sortWith((x, y) => x._1 > y._1)) + test("sort descending") { + sc = new SparkContext("local", "test") + val rand = new scala.util.Random() + 
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } + val pairs = sc.parallelize(pairArr) + assert(pairs.sortByKey(false).collect() === pairArr.sortWith((x, y) => x._1 > y._1)) } - test("morePartitionsThanElements") { - sc = new SparkContext("local", "test") - val rand = new scala.util.Random() - val pairArr = Array.fill(10) { (rand.nextInt(), rand.nextInt()) } - val pairs = sc.parallelize(pairArr, 30) - assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1)) + test("more partitions than elements") { + sc = new SparkContext("local", "test") + val rand = new scala.util.Random() + val pairArr = Array.fill(10) { (rand.nextInt(), rand.nextInt()) } + val pairs = sc.parallelize(pairArr, 30) + assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1)) } - test("emptyRDD") { - sc = new SparkContext("local", "test") - val rand = new scala.util.Random() - val pairArr = new Array[(Int, Int)](0) - val pairs = sc.parallelize(pairArr) - assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1)) + test("empty RDD") { + sc = new SparkContext("local", "test") + val pairArr = new Array[(Int, Int)](0) + val pairs = sc.parallelize(pairArr) + assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1)) + } + + test("partition balancing") { + sc = new SparkContext("local", "test") + val pairArr = (1 to 1000).map(x => (x, x)).toArray + val sorted = sc.parallelize(pairArr, 4).sortByKey() + assert(sorted.collect() === pairArr.sortBy(_._1)) + val partitions = sorted.collectPartitions() + logInfo("partition lengths: " + partitions.map(_.length).mkString(", ")) + partitions(0).length should be > 150 + partitions(1).length should be > 150 + partitions(2).length should be > 150 + partitions(3).length should be > 150 + partitions(0).last should be < partitions(1).head + partitions(1).last should be < partitions(2).head + partitions(2).last should be < partitions(3).head + } + + test("partition balancing for descending sort") { + sc = new SparkContext("local", "test") + 
val pairArr = (1 to 1000).map(x => (x, x)).toArray + val sorted = sc.parallelize(pairArr, 4).sortByKey(false) + assert(sorted.collect() === pairArr.sortBy(_._1).reverse) + val partitions = sorted.collectPartitions() + logInfo("partition lengths: " + partitions.map(_.length).mkString(", ")) + partitions(0).length should be > 150 + partitions(1).length should be > 150 + partitions(2).length should be > 150 + partitions(3).length should be > 150 + partitions(0).last should be > partitions(1).head + partitions(1).last should be > partitions(2).head + partitions(2).last should be > partitions(3).head } } |