aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPatrick Wendell <pwendell@gmail.com>2013-09-26 08:01:04 -0700
committerPatrick Wendell <pwendell@gmail.com>2013-09-26 08:01:04 -0700
commit6566a19b38204d754c5e8f821b4276616e90abc6 (patch)
tree1826a4662827321f6ac4424751fd13ba79ee320b
parent9d34838bde49488629f6a97f097d45c19b5d600c (diff)
parent42571d30d0d518e69eecf468075e4c5a823a2ae8 (diff)
downloadspark-6566a19b38204d754c5e8f821b4276616e90abc6.tar.gz
spark-6566a19b38204d754c5e8f821b4276616e90abc6.tar.bz2
spark-6566a19b38204d754c5e8f821b4276616e90abc6.zip
Merge pull request #9 from rxin/limit
Smarter take/limit implementation.
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/RDD.scala38
-rw-r--r--core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala38
2 files changed, 66 insertions, 10 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 1082cbae3e..1893627ee2 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -753,24 +753,42 @@ abstract class RDD[T: ClassManifest](
}
/**
- * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so
- * it will be slow if a lot of partitions are required. In that case, use collect() to get the
- * whole RDD instead.
+ * Take the first num elements of the RDD. It works by first scanning one partition, and use the
+ * results from that partition to estimate the number of additional partitions needed to satisfy
+ * the limit.
*/
def take(num: Int): Array[T] = {
if (num == 0) {
return new Array[T](0)
}
+
val buf = new ArrayBuffer[T]
- var p = 0
- while (buf.size < num && p < partitions.size) {
+ val totalParts = this.partitions.length
+ var partsScanned = 0
+ while (buf.size < num && partsScanned < totalParts) {
+ // The number of partitions to try in this iteration. It is ok for this number to be
+ // greater than totalParts because we actually cap it at totalParts in runJob.
+ var numPartsToTry = 1
+ if (partsScanned > 0) {
+ // If we didn't find any rows after the first iteration, just try all partitions next.
+ // Otherwise, interpolate the number of partitions we need to try, but overestimate it
+ // by 50%.
+ if (buf.size == 0) {
+ numPartsToTry = totalParts - 1
+ } else {
+ numPartsToTry = (1.5 * num * partsScanned / buf.size).toInt
+ }
+ }
+ numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions
+
val left = num - buf.size
- val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, Array(p), true)
- buf ++= res(0)
- if (buf.size == num)
- return buf.toArray
- p += 1
+ val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+ val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p, allowLocal = true)
+
+ res.foreach(buf ++= _.take(num - buf.size))
+ partsScanned += numPartsToTry
}
+
return buf.toArray
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 016db8d57e..6d1bc5e296 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -320,6 +320,44 @@ class RDDSuite extends FunSuite with SharedSparkContext {
for (i <- 0 until sample.size) assert(sample(i) === checkSample(i))
}
+ test("take") {
+ var nums = sc.makeRDD(Range(1, 1000), 1)
+ assert(nums.take(0).size === 0)
+ assert(nums.take(1) === Array(1))
+ assert(nums.take(3) === Array(1, 2, 3))
+ assert(nums.take(500) === (1 to 500).toArray)
+ assert(nums.take(501) === (1 to 501).toArray)
+ assert(nums.take(999) === (1 to 999).toArray)
+ assert(nums.take(1000) === (1 to 999).toArray)
+
+ nums = sc.makeRDD(Range(1, 1000), 2)
+ assert(nums.take(0).size === 0)
+ assert(nums.take(1) === Array(1))
+ assert(nums.take(3) === Array(1, 2, 3))
+ assert(nums.take(500) === (1 to 500).toArray)
+ assert(nums.take(501) === (1 to 501).toArray)
+ assert(nums.take(999) === (1 to 999).toArray)
+ assert(nums.take(1000) === (1 to 999).toArray)
+
+ nums = sc.makeRDD(Range(1, 1000), 100)
+ assert(nums.take(0).size === 0)
+ assert(nums.take(1) === Array(1))
+ assert(nums.take(3) === Array(1, 2, 3))
+ assert(nums.take(500) === (1 to 500).toArray)
+ assert(nums.take(501) === (1 to 501).toArray)
+ assert(nums.take(999) === (1 to 999).toArray)
+ assert(nums.take(1000) === (1 to 999).toArray)
+
+ nums = sc.makeRDD(Range(1, 1000), 1000)
+ assert(nums.take(0).size === 0)
+ assert(nums.take(1) === Array(1))
+ assert(nums.take(3) === Array(1, 2, 3))
+ assert(nums.take(500) === (1 to 500).toArray)
+ assert(nums.take(501) === (1 to 501).toArray)
+ assert(nums.take(999) === (1 to 999).toArray)
+ assert(nums.take(1000) === (1 to 999).toArray)
+ }
+
test("top with predefined ordering") {
val nums = Array.range(1, 100000)
val ints = sc.makeRDD(scala.util.Random.shuffle(nums), 2)