diff options
author | Michael Armbrust <michael@databricks.com> | 2014-07-08 00:41:46 -0700 |
---|---|---|
committer | Reynold Xin <rxin@apache.org> | 2014-07-08 00:41:46 -0700 |
commit | 5a4063645dd7bb4cd8bda890785235729804ab09 (patch) | |
tree | 8c9188d429c103658c196a10f5ac273d69463728 /sql | |
parent | 3cd5029be709307415f911236472a685e406e763 (diff) | |
download | spark-5a4063645dd7bb4cd8bda890785235729804ab09.tar.gz spark-5a4063645dd7bb4cd8bda890785235729804ab09.tar.bz2 spark-5a4063645dd7bb4cd8bda890785235729804ab09.zip |
[SPARK-2391][SQL] Custom take() for LIMIT queries.
Using Spark's take can result in an entire in-memory partition being shipped in order to retrieve a single row.
Author: Michael Armbrust <michael@databricks.com>
Closes #1318 from marmbrus/takeLimit and squashes the following commits:
77289a5 [Michael Armbrust] Update scala doc
32f0674 [Michael Armbrust] Custom take implementation for LIMIT queries.
Diffstat (limited to 'sql')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala | 51 |
1 file changed, 47 insertions, 4 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index e8816f0b3c..97abd636ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import scala.collection.mutable.ArrayBuffer import scala.reflect.runtime.universe.TypeTag import org.apache.spark.annotation.DeveloperApi @@ -83,9 +84,9 @@ case class Union(children: Seq[SparkPlan])(@transient sqlContext: SQLContext) ex * :: DeveloperApi :: * Take the first limit elements. Note that the implementation is different depending on whether * this is a terminal operator or not. If it is terminal and is invoked using executeCollect, - * this operator uses Spark's take method on the Spark driver. If it is not terminal or is - * invoked using execute, we first take the limit on each partition, and then repartition all the - * data to a single partition to compute the global limit. + * this operator uses something similar to Spark's take method on the Spark driver. If it is not + * terminal or is invoked using execute, we first take the limit on each partition, and then + * repartition all the data to a single partition to compute the global limit. */ @DeveloperApi case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext) @@ -97,7 +98,49 @@ case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext override def output = child.output - override def executeCollect() = child.execute().map(_.copy()).take(limit) + /** + * A custom implementation modeled after the take function on RDDs but which never runs any job + * locally. This is to avoid shipping an entire partition of data in order to retrieve only a few + * rows. 
+ */ + override def executeCollect(): Array[Row] = { + if (limit == 0) { + return new Array[Row](0) + } + + val childRDD = child.execute().map(_.copy()) + + val buf = new ArrayBuffer[Row] + val totalParts = childRDD.partitions.length + var partsScanned = 0 + while (buf.size < limit && partsScanned < totalParts) { + // The number of partitions to try in this iteration. It is ok for this number to be + // greater than totalParts because we actually cap it at totalParts in runJob. + var numPartsToTry = 1 + if (partsScanned > 0) { + // If we didn't find any rows after the first iteration, just try all partitions next. + // Otherwise, interpolate the number of partitions we need to try, but overestimate it + // by 50%. + if (buf.size == 0) { + numPartsToTry = totalParts - 1 + } else { + numPartsToTry = (1.5 * limit * partsScanned / buf.size).toInt + } + } + numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions + + val left = limit - buf.size + val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) + val sc = sqlContext.sparkContext + val res = + sc.runJob(childRDD, (it: Iterator[Row]) => it.take(left).toArray, p, allowLocal = false) + + res.foreach(buf ++= _.take(limit - buf.size)) + partsScanned += numPartsToTry + } + + buf.toArray + } override def execute() = { val rdd = child.execute().mapPartitions { iter => |