author    Matei Zaharia <matei@eecs.berkeley.edu>  2012-09-26 19:18:47 -0700
committer Matei Zaharia <matei@eecs.berkeley.edu>  2012-09-26 19:18:47 -0700
commit    1ef4f0fbd27e54803f14fed1df541fb341daced8 (patch)
tree      ed3c67a59419f111461b147b4aa81072d81685d4
parent    874a9fd407943c7102395cfc64762dfd0ecf9b00 (diff)
Allow controlling number of splits in sortByKey.
-rw-r--r--  core/src/main/scala/spark/PairRDDFunctions.scala     |  4
-rw-r--r--  core/src/main/scala/spark/ShuffledRDD.scala          |  9
-rw-r--r--  core/src/main/scala/spark/deploy/client/Client.scala |  1
-rw-r--r--  core/src/test/scala/spark/SortingSuite.scala         | 48
4 files changed, 50 insertions(+), 12 deletions(-)
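
For context, the patch adds an optional numSplits parameter to sortByKey so callers can pick the number of output partitions; omitting it keeps the old behavior (as many output splits as the input has). A minimal usage sketch against the new signature (the SparkContext setup, object name, and the usual import spark.SparkContext._ implicits are illustrative, not part of this patch):

    import spark.SparkContext
    import spark.SparkContext._  // implicit conversion to OrderedRDDFunctions

    object SortByKeyExample {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "sortByKey example")
        val pairs = sc.parallelize(Array((3, "c"), (1, "a"), (2, "b"), (0, "d")), 2)

        // Default: ascending, numSplits = input split count (old behavior).
        val sortedDefault = pairs.sortByKey()

        // New with this patch: choose the number of output splits explicitly.
        val sortedOne  = pairs.sortByKey(true, 1)    // ascending, one output split
        val sortedMany = pairs.sortByKey(false, 20)  // descending, 20 output splits

        println(sortedDefault.collect().mkString(", "))  // (0,d), (1,a), (2,b), (3,c)
        println(sortedOne.splits.size)                   // 1
        println(sortedMany.splits.size)                  // 20
        sc.stop()
      }
    }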
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index aa1d00c63c..4752bf8d9f 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -435,8 +435,8 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
extends Logging
with Serializable {
- def sortByKey(ascending: Boolean = true): RDD[(K,V)] = {
- new ShuffledSortedRDD(self, ascending)
+ def sortByKey(ascending: Boolean = true, numSplits: Int = self.splits.size): RDD[(K,V)] = {
+ new ShuffledSortedRDD(self, ascending, numSplits)
}
}
diff --git a/core/src/main/scala/spark/ShuffledRDD.scala b/core/src/main/scala/spark/ShuffledRDD.scala
index be75890a40..7c11925f86 100644
--- a/core/src/main/scala/spark/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/ShuffledRDD.scala
@@ -16,7 +16,7 @@ class ShuffledRDDSplit(val idx: Int) extends Split {
abstract class ShuffledRDD[K, V, C](
@transient parent: RDD[(K, V)],
aggregator: Aggregator[K, V, C],
- part : Partitioner)
+ part: Partitioner)
extends RDD[(K, C)](parent.context) {
override val partitioner = Some(part)
@@ -38,7 +38,7 @@ abstract class ShuffledRDD[K, V, C](
*/
class RepartitionShuffledRDD[K, V](
@transient parent: RDD[(K, V)],
- part : Partitioner)
+ part: Partitioner)
extends ShuffledRDD[K, V, V](
parent,
Aggregator[K, V, V](null, null, null, false),
@@ -60,10 +60,11 @@ class RepartitionShuffledRDD[K, V](
*/
class ShuffledSortedRDD[K <% Ordered[K]: ClassManifest, V](
@transient parent: RDD[(K, V)],
- ascending: Boolean)
+ ascending: Boolean,
+ numSplits: Int)
extends RepartitionShuffledRDD[K, V](
parent,
- new RangePartitioner(parent.splits.size, parent, ascending)) {
+ new RangePartitioner(numSplits, parent, ascending)) {
override def compute(split: Split): Iterator[(K, V)] = {
// By separating this from RepartitionShuffledRDD, we avoided a
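
The substantive change in this file is that ShuffledSortedRDD now hands the caller's numSplits to the RangePartitioner instead of parent.splits.size, so the partitioner's bucket count (and hence the sorted RDD's split count) is decoupled from the input. For intuition about why the bucket count is all that needs to change, here is a toy range partitioner (a simplified standalone sketch, not Spark's RangePartitioner, which derives its bounds from the parent RDD's keys):

    // Toy illustration of range partitioning: `bounds` holds numSplits - 1
    // sorted upper bounds; each key routes to the first range whose bound is
    // >= the key, with the mapping reversed when sorting descending.
    object RangePartitionSketch {
      class SimpleRangePartitioner[K <% Ordered[K]](bounds: Seq[K], ascending: Boolean) {
        def numPartitions: Int = bounds.length + 1
        def getPartition(key: K): Int = {
          var i = 0
          while (i < bounds.length && key > bounds(i)) i += 1
          if (ascending) i else bounds.length - i
        }
      }

      def main(args: Array[String]) {
        // Three ranges over Int keys: (-inf, 10], (10, 20], (20, +inf)
        val p = new SimpleRangePartitioner(Seq(10, 20), ascending = true)
        assert(p.getPartition(5) == 0)
        assert(p.getPartition(15) == 1)
        assert(p.getPartition(99) == 2)
      }
    }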
diff --git a/core/src/main/scala/spark/deploy/client/Client.scala b/core/src/main/scala/spark/deploy/client/Client.scala
index c7fa8a3874..a2f88fc5e5 100644
--- a/core/src/main/scala/spark/deploy/client/Client.scala
+++ b/core/src/main/scala/spark/deploy/client/Client.scala
@@ -42,7 +42,6 @@ class Client(
val akkaUrl = "akka://spark@%s:%s/user/Master".format(masterHost, masterPort)
try {
master = context.actorFor(akkaUrl)
- //master ! RegisterWorker(ip, port, cores, memory)
master ! RegisterJob(jobDescription)
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
context.watch(master) // Doesn't work with remote actors, but useful for testing
diff --git a/core/src/test/scala/spark/SortingSuite.scala b/core/src/test/scala/spark/SortingSuite.scala
index 188a9b564e..c87595ecb3 100644
--- a/core/src/test/scala/spark/SortingSuite.scala
+++ b/core/src/test/scala/spark/SortingSuite.scala
@@ -17,7 +17,7 @@ class SortingSuite extends FunSuite with BeforeAndAfter with ShouldMatchers with
test("sortByKey") {
sc = new SparkContext("local", "test")
- val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)))
+ val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)), 2)
assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0)))
}
@@ -25,18 +25,56 @@ class SortingSuite extends FunSuite with BeforeAndAfter with ShouldMatchers with
sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
- val pairs = sc.parallelize(pairArr)
- assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
+ val pairs = sc.parallelize(pairArr, 2)
+ val sorted = pairs.sortByKey()
+ assert(sorted.splits.size === 2)
+ assert(sorted.collect() === pairArr.sortBy(_._1))
}
+ test("large array with one split") {
+ sc = new SparkContext("local", "test")
+ val rand = new scala.util.Random()
+ val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
+ val pairs = sc.parallelize(pairArr, 2)
+ val sorted = pairs.sortByKey(true, 1)
+ assert(sorted.splits.size === 1)
+ assert(sorted.collect() === pairArr.sortBy(_._1))
+ }
+
+ test("large array with many splits") {
+ sc = new SparkContext("local", "test")
+ val rand = new scala.util.Random()
+ val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
+ val pairs = sc.parallelize(pairArr, 2)
+ val sorted = pairs.sortByKey(true, 20)
+ assert(sorted.splits.size === 20)
+ assert(sorted.collect() === pairArr.sortBy(_._1))
+ }
+
test("sort descending") {
sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
- val pairs = sc.parallelize(pairArr)
+ val pairs = sc.parallelize(pairArr, 2)
assert(pairs.sortByKey(false).collect() === pairArr.sortWith((x, y) => x._1 > y._1))
}
+ test("sort descending with one split") {
+ sc = new SparkContext("local", "test")
+ val rand = new scala.util.Random()
+ val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
+ val pairs = sc.parallelize(pairArr, 1)
+ assert(pairs.sortByKey(false, 1).collect() === pairArr.sortWith((x, y) => x._1 > y._1))
+ }
+
+ test("sort descending with many splits") {
+ sc = new SparkContext("local", "test")
+ val rand = new scala.util.Random()
+ val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
+ val pairs = sc.parallelize(pairArr, 2)
+ assert(pairs.sortByKey(false, 20).collect() === pairArr.sortWith((x, y) => x._1 > y._1))
+ }
+
test("more partitions than elements") {
sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
@@ -48,7 +86,7 @@ class SortingSuite extends FunSuite with BeforeAndAfter with ShouldMatchers with
test("empty RDD") {
sc = new SparkContext("local", "test")
val pairArr = new Array[(Int, Int)](0)
- val pairs = sc.parallelize(pairArr)
+ val pairs = sc.parallelize(pairArr, 2)
assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
}