From 42571d30d0d518e69eecf468075e4c5a823a2ae8 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Fri, 20 Sep 2013 17:09:53 -0700
Subject: Smarter take/limit implementation.

---
 core/src/main/scala/org/apache/spark/rdd/RDD.scala | 38 ++++++++++++++++------
 .../test/scala/org/apache/spark/rdd/RDDSuite.scala | 38 ++++++++++++++++++++++
 2 files changed, 66 insertions(+), 10 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 1082cbae3e..1893627ee2 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -753,24 +753,42 @@ abstract class RDD[T: ClassManifest](
   }
 
   /**
-   * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so
-   * it will be slow if a lot of partitions are required. In that case, use collect() to get the
-   * whole RDD instead.
+   * Take the first num elements of the RDD. It works by first scanning one partition, and then
+   * using the results from that partition to estimate the number of additional partitions needed
+   * to satisfy the limit.
    */
   def take(num: Int): Array[T] = {
     if (num == 0) {
       return new Array[T](0)
     }
+
     val buf = new ArrayBuffer[T]
-    var p = 0
-    while (buf.size < num && p < partitions.size) {
+    val totalParts = this.partitions.length
+    var partsScanned = 0
+    while (buf.size < num && partsScanned < totalParts) {
+      // The number of partitions to try in this iteration. It is ok for this number to be
+      // greater than totalParts because we actually cap it at totalParts in runJob.
+      var numPartsToTry = 1
+      if (partsScanned > 0) {
+        // If we didn't find any rows after the first iteration, just try all partitions next.
+        // Otherwise, interpolate the number of partitions we need to try, but overestimate it
+        // by 50%.
+        if (buf.size == 0) {
+          numPartsToTry = totalParts - 1
+        } else {
+          numPartsToTry = (1.5 * num * partsScanned / buf.size).toInt
+        }
+      }
+      numPartsToTry = math.max(0, numPartsToTry)  // guard against negative num of partitions
+
       val left = num - buf.size
-      val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, Array(p), true)
-      buf ++= res(0)
-      if (buf.size == num)
-        return buf.toArray
-      p += 1
+      val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+      val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p, allowLocal = true)
+
+      res.foreach(buf ++= _.take(num - buf.size))
+      partsScanned += numPartsToTry
     }
+
     return buf.toArray
   }
 
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index c1df5e151e..63adf1cda5 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -321,6 +321,44 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     for (i <- 0 until sample.size) assert(sample(i) === checkSample(i))
   }
 
+  test("take") {
+    var nums = sc.makeRDD(Range(1, 1000), 1)
+    assert(nums.take(0).size === 0)
+    assert(nums.take(1) === Array(1))
+    assert(nums.take(3) === Array(1, 2, 3))
+    assert(nums.take(500) === (1 to 500).toArray)
+    assert(nums.take(501) === (1 to 501).toArray)
+    assert(nums.take(999) === (1 to 999).toArray)
+    assert(nums.take(1000) === (1 to 999).toArray)
+
+    nums = sc.makeRDD(Range(1, 1000), 2)
+    assert(nums.take(0).size === 0)
+    assert(nums.take(1) === Array(1))
+    assert(nums.take(3) === Array(1, 2, 3))
+    assert(nums.take(500) === (1 to 500).toArray)
+    assert(nums.take(501) === (1 to 501).toArray)
+    assert(nums.take(999) === (1 to 999).toArray)
+    assert(nums.take(1000) === (1 to 999).toArray)
+
+    nums = sc.makeRDD(Range(1, 1000), 100)
+    assert(nums.take(0).size === 0)
+    assert(nums.take(1) === Array(1))
+    assert(nums.take(3) === Array(1, 2, 3))
+    assert(nums.take(500) === (1 to 500).toArray)
+    assert(nums.take(501) === (1 to 501).toArray)
+    assert(nums.take(999) === (1 to 999).toArray)
+    assert(nums.take(1000) === (1 to 999).toArray)
+
+    nums = sc.makeRDD(Range(1, 1000), 1000)
+    assert(nums.take(0).size === 0)
+    assert(nums.take(1) === Array(1))
+    assert(nums.take(3) === Array(1, 2, 3))
+    assert(nums.take(500) === (1 to 500).toArray)
+    assert(nums.take(501) === (1 to 501).toArray)
+    assert(nums.take(999) === (1 to 999).toArray)
+    assert(nums.take(1000) === (1 to 999).toArray)
+  }
+
   test("top with predefined ordering") {
     val nums = Array.range(1, 100000)
     val ints = sc.makeRDD(scala.util.Random.shuffle(nums), 2)
-- 
cgit v1.2.3
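
[Editor's note] The core of this patch is the numPartsToTry estimate inside the
while loop. The standalone Scala sketch below is hypothetical (it is not part of
the patch, and the function name and signature are invented for illustration);
it only reproduces that estimation heuristic so the arithmetic can be traced in
isolation:

    // Sketch: estimate how many partitions the next take(num) pass should scan,
    // given how many partitions have been scanned so far and how many rows (bufSize)
    // they yielded. Mirrors the patch's heuristic, including the 50% overestimate.
    def estimatePartsToTry(num: Int, partsScanned: Int, bufSize: Int, totalParts: Int): Int = {
      var numPartsToTry = 1
      if (partsScanned > 0) {
        if (bufSize == 0) {
          // The partitions scanned so far were all empty: try every remaining partition next.
          numPartsToTry = totalParts - 1
        } else {
          // Linear interpolation: if partsScanned partitions yielded bufSize rows, then about
          // num * partsScanned / bufSize partitions should yield num rows; multiply by 1.5
          // to overestimate and reduce the chance of needing another scheduling round.
          numPartsToTry = (1.5 * num * partsScanned / bufSize).toInt
        }
      }
      math.max(0, numPartsToTry)  // guard against a negative estimate, as in the patch
    }

For example, with num = 100 and a first pass that scanned 1 partition and found
20 rows, estimatePartsToTry(100, 1, 20, 1000) returns (1.5 * 100 * 1 / 20).toInt = 7,
so the second pass scans 7 partitions in a single job rather than one at a time.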