author    Michael Armbrust <michael@databricks.com>  2014-07-08 00:41:46 -0700
committer Reynold Xin <rxin@apache.org>              2014-07-08 00:41:46 -0700
commit    5a4063645dd7bb4cd8bda890785235729804ab09 (patch)
tree      8c9188d429c103658c196a10f5ac273d69463728 /sql
parent    3cd5029be709307415f911236472a685e406e763 (diff)
[SPARK-2391][SQL] Custom take() for LIMIT queries.
Using Spark's take can result in an entire in-memory partition being shipped in order to retrieve a single row.

Author: Michael Armbrust <michael@databricks.com>

Closes #1318 from marmbrus/takeLimit and squashes the following commits:

77289a5 [Michael Armbrust] Update scala doc
32f0674 [Michael Armbrust] Custom take implementation for LIMIT queries.
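To put the patch in context, here is a minimal usage sketch against the Spark 1.0-era SQL API showing the terminal case this change targets: a LIMIT query whose result is collected on the driver, which ends up in Limit.executeCollect. The table name and data are made up, and the exact calls (registerAsTable, sql) are assumptions about that era's API rather than part of the patch.

    import org.apache.spark.SparkContext
    import org.apache.spark.sql.SQLContext

    // Hypothetical example data; only the LIMIT behaviour matters here.
    case class Record(key: Int, value: String)

    val sc = new SparkContext("local", "limit-example")
    val sqlContext = new SQLContext(sc)
    import sqlContext._  // implicit RDD -> SchemaRDD conversion and the sql(...) method

    sc.parallelize(1 to 1000000)
      .map(i => Record(i, s"val_$i"))
      .registerAsTable("records")

    // A terminal LIMIT: collect() goes through Limit.executeCollect, which with this
    // patch scans only as many partitions as needed instead of shipping whole partitions.
    val firstTen = sql("SELECT key, value FROM records LIMIT 10").collect()
    firstTen.foreach(println)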
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala | 51
1 file changed, 47 insertions(+), 4 deletions(-)
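Before the diff itself, a standalone sketch (illustrative only, not part of the patch) of the partition-escalation idea the new executeCollect applies: scan one partition first, and if the rows gathered so far fall short of the limit, extrapolate from the hit rate how many partitions to scan next, overestimating by 50%; if nothing was found, try all remaining partitions. The helper name and parameters below are hypothetical.

    import scala.collection.mutable.ArrayBuffer

    // Illustrative model of the escalation loop; the real implementation runs
    // Spark jobs over childRDD's partitions (see the diff below).
    def partitionsPerRound(limit: Int, totalParts: Int, rowsPerPart: Int): Seq[Int] = {
      var collected = 0
      var partsScanned = 0
      val rounds = ArrayBuffer[Int]()
      while (collected < limit && partsScanned < totalParts) {
        var numPartsToTry = 1
        if (partsScanned > 0) {
          numPartsToTry =
            if (collected == 0) totalParts - 1                  // nothing yet: try the rest
            else (1.5 * limit * partsScanned / collected).toInt // extrapolate, +50% margin
        }
        val thisRound = math.min(math.max(0, numPartsToTry), totalParts - partsScanned)
        rounds += thisRound
        collected += math.min(limit - collected, thisRound * rowsPerPart)
        partsScanned += thisRound
      }
      rounds.toSeq
    }

    // With limit = 100 over 50 partitions holding ~10 matching rows each, round 1
    // scans 1 partition (10 rows) and round 2 extrapolates to 15, which is enough.
    println(partitionsPerRound(limit = 100, totalParts = 50, rowsPerPart = 10))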
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index e8816f0b3c..97abd636ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution
+import scala.collection.mutable.ArrayBuffer
import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.annotation.DeveloperApi
@@ -83,9 +84,9 @@ case class Union(children: Seq[SparkPlan])(@transient sqlContext: SQLContext) ex
* :: DeveloperApi ::
* Take the first limit elements. Note that the implementation is different depending on whether
* this is a terminal operator or not. If it is terminal and is invoked using executeCollect,
- * this operator uses Spark's take method on the Spark driver. If it is not terminal or is
- * invoked using execute, we first take the limit on each partition, and then repartition all the
- * data to a single partition to compute the global limit.
+ * this operator uses something similar to Spark's take method on the Spark driver. If it is not
+ * terminal or is invoked using execute, we first take the limit on each partition, and then
+ * repartition all the data to a single partition to compute the global limit.
*/
@DeveloperApi
case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext)
@@ -97,7 +98,49 @@ case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext
override def output = child.output
- override def executeCollect() = child.execute().map(_.copy()).take(limit)
+ /**
+ * A custom implementation modeled after the take function on RDDs but which never runs any job
+ * locally. This is to avoid shipping an entire partition of data in order to retrieve only a few
+ * rows.
+ */
+ override def executeCollect(): Array[Row] = {
+ if (limit == 0) {
+ return new Array[Row](0)
+ }
+
+ val childRDD = child.execute().map(_.copy())
+
+ val buf = new ArrayBuffer[Row]
+ val totalParts = childRDD.partitions.length
+ var partsScanned = 0
+ while (buf.size < limit && partsScanned < totalParts) {
+ // The number of partitions to try in this iteration. It is ok for this number to be
+ // greater than totalParts because we actually cap it at totalParts in runJob.
+ var numPartsToTry = 1
+ if (partsScanned > 0) {
+ // If we didn't find any rows after the first iteration, just try all partitions next.
+ // Otherwise, interpolate the number of partitions we need to try, but overestimate it
+ // by 50%.
+ if (buf.size == 0) {
+ numPartsToTry = totalParts - 1
+ } else {
+ numPartsToTry = (1.5 * limit * partsScanned / buf.size).toInt
+ }
+ }
+ numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions
+
+ val left = limit - buf.size
+ val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+ val sc = sqlContext.sparkContext
+ val res =
+ sc.runJob(childRDD, (it: Iterator[Row]) => it.take(left).toArray, p, allowLocal = false)
+
+ res.foreach(buf ++= _.take(limit - buf.size))
+ partsScanned += numPartsToTry
+ }
+
+ buf.toArray
+ }
override def execute() = {
val rdd = child.execute().mapPartitions { iter =>