From 806d8a8e980d8ba2f4261bceb393c40bafaa2f73 Mon Sep 17 00:00:00 2001
From: Robert Kruszewski
Date: Fri, 2 Sep 2016 17:14:43 +0200
Subject: [SPARK-16984][SQL] don't try whole dataset immediately when first
 partition doesn't have…
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## What changes were proposed in this pull request?

Increase the number of partitions to try gradually between attempts, so that `take`/`limit` execution does not fall back to scanning all partitions as soon as the first partitions return no rows.

## How was this patch tested?

Empirically. This is a common-case optimization.

Author: Robert Kruszewski

Closes #14573 from robert3005/robertk/execute-take-backoff.
---
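In short, both `RDD.take` and `SparkPlan.executeTake` now share the same back-off arithmetic. The self-contained sketch below only illustrates that arithmetic outside of Spark; `TakeBackoffSketch`, `nextPartsToTry`, and the small simulation in `main` are names invented for the example, not Spark APIs.

```scala
// Illustrative sketch only: mirrors the scale-up arithmetic this patch uses,
// outside of Spark, so the retry behaviour can be seen in isolation.
object TakeBackoffSketch {

  /**
   * How many partitions to scan in the next attempt of a take(n).
   *
   * @param n             rows requested
   * @param rowsSoFar     rows collected from the partitions scanned so far
   * @param partsScanned  partitions scanned so far (>= 1 once we get here)
   * @param scaleUpFactor configured factor (spark.sql.limit.scaleUpFactor / spark.rdd.limit.scaleUpFactor)
   */
  def nextPartsToTry(n: Int, rowsSoFar: Int, partsScanned: Long, scaleUpFactor: Int): Long = {
    val factor = math.max(scaleUpFactor, 2) // the patch clamps the factor at a minimum of 2
    if (rowsSoFar == 0) {
      // No rows yet: grow geometrically instead of jumping straight to all partitions.
      partsScanned * factor
    } else {
      // Interpolate from the observed row density, overestimate by 50%,
      // and cap the growth at `factor` times what was already scanned.
      val interpolated = math.max((1.5 * n * partsScanned / rowsSoFar).toLong - partsScanned, 1L)
      math.min(interpolated, partsScanned * factor)
    }
  }

  def main(args: Array[String]): Unit = {
    // Pretend each scanned partition yields 5 matching rows and watch a take(100) converge.
    var partsScanned = 1L
    var rows = 0
    while (rows < 100) {
      val next = nextPartsToTry(100, rows, partsScanned, scaleUpFactor = 4)
      println(s"scanned=$partsScanned rows=$rows -> scan $next more partitions")
      partsScanned += next
      rows += 5 * next.toInt
    }
  }
}
```

Clamping the factor at 2 and capping the interpolated estimate keeps each retry bounded, so a sparse `take` runs a handful of geometrically growing jobs instead of one job over every remaining partition.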
 core/src/main/scala/org/apache/spark/rdd/RDD.scala   |  7 +++---
 .../org/apache/spark/sql/execution/SparkPlan.scala   | 28 ++++++++++------------
 .../org/apache/spark/sql/internal/SQLConf.scala      | 10 ++++++++
 3 files changed, 27 insertions(+), 18 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 2ee13dc4db..10b5f8291a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1296,6 +1296,7 @@ abstract class RDD[T: ClassTag](
    * an exception if called on an RDD of `Nothing` or `Null`.
    */
   def take(num: Int): Array[T] = withScope {
+    val scaleUpFactor = Math.max(conf.getInt("spark.rdd.limit.scaleUpFactor", 4), 2)
     if (num == 0) {
       new Array[T](0)
     } else {
@@ -1310,12 +1311,12 @@ abstract class RDD[T: ClassTag](
           // If we didn't find any rows after the previous iteration, quadruple and retry.
           // Otherwise, interpolate the number of partitions we need to try, but overestimate
           // it by 50%. We also cap the estimation in the end.
-          if (buf.size == 0) {
-            numPartsToTry = partsScanned * 4
+          if (buf.isEmpty) {
+            numPartsToTry = partsScanned * scaleUpFactor
           } else {
             // the left side of max is >=1 whenever partsScanned >= 2
             numPartsToTry = Math.max((1.5 * num * partsScanned / buf.size).toInt - partsScanned, 1)
-            numPartsToTry = Math.min(numPartsToTry, partsScanned * 4)
+            numPartsToTry = Math.min(numPartsToTry, partsScanned * scaleUpFactor)
           }
         }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 7f2e18586d..6a2d97c9b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -20,14 +20,13 @@ package org.apache.spark.sql.execution
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
 
 import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.{ExecutionContext, Future}
-import scala.concurrent.duration._
+import scala.concurrent.ExecutionContext
 
 import org.apache.spark.{broadcast, SparkEnv}
 import org.apache.spark.internal.Logging
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.rdd.{RDD, RDDOperationScope}
-import org.apache.spark.sql.{Row, SparkSession, SQLContext}
+import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -316,26 +315,25 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
       // greater than totalParts because we actually cap it at totalParts in runJob.
       var numPartsToTry = 1L
       if (partsScanned > 0) {
-        // If we didn't find any rows after the first iteration, just try all partitions next.
-        // Otherwise, interpolate the number of partitions we need to try, but overestimate it
-        // by 50%.
-        if (buf.size == 0) {
-          numPartsToTry = totalParts - 1
+        // If we didn't find any rows after the previous iteration, quadruple and retry.
+        // Otherwise, interpolate the number of partitions we need to try, but overestimate
+        // it by 50%. We also cap the estimation in the end.
+        val limitScaleUpFactor = Math.max(sqlContext.conf.limitScaleUpFactor, 2)
+        if (buf.isEmpty) {
+          numPartsToTry = partsScanned * limitScaleUpFactor
         } else {
-          numPartsToTry = (1.5 * n * partsScanned / buf.size).toInt
+          // the left side of max is >=1 whenever partsScanned >= 2
+          numPartsToTry = Math.max((1.5 * n * partsScanned / buf.size).toInt - partsScanned, 1)
+          numPartsToTry = Math.min(numPartsToTry, partsScanned * limitScaleUpFactor)
         }
       }
-      numPartsToTry = math.max(0, numPartsToTry)  // guard against negative num of partitions
 
-      val left = n - buf.size
       val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
       val sc = sqlContext.sparkContext
       val res = sc.runJob(childRDD,
-        (it: Iterator[Array[Byte]]) => if (it.hasNext) it.next() else Array.empty, p)
+        (it: Iterator[Array[Byte]]) => if (it.hasNext) it.next() else Array.empty[Byte], p)
 
-      res.foreach { r =>
-        decodeUnsafeRows(r.asInstanceOf[Array[Byte]]).foreach(buf.+=)
-      }
+      buf ++= res.flatMap(decodeUnsafeRows)
 
       partsScanned += p.size
     }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index d3440a2644..a54342f82e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -116,6 +116,14 @@ object SQLConf {
     .longConf
     .createWithDefault(10L * 1024 * 1024)
 
+  val LIMIT_SCALE_UP_FACTOR = SQLConfigBuilder("spark.sql.limit.scaleUpFactor")
+    .internal()
+    .doc("Minimal increase rate in number of partitions between attempts when executing a take " +
+      "on a query. Higher values lead to more partitions read. Lower values might lead to " +
+      "longer execution times as more jobs will be run")
+    .intConf
+    .createWithDefault(4)
+
   val ENABLE_FALL_BACK_TO_HDFS_FOR_STATS =
     SQLConfigBuilder("spark.sql.statistics.fallBackToHdfs")
       .doc("If the table statistics are not available from table metadata enable fall back to hdfs." +
@@ -638,6 +646,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
 
   def autoBroadcastJoinThreshold: Long = getConf(AUTO_BROADCASTJOIN_THRESHOLD)
 
+  def limitScaleUpFactor: Int = getConf(LIMIT_SCALE_UP_FACTOR)
+
   def fallBackToHdfsForStatsEnabled: Boolean = getConf(ENABLE_FALL_BACK_TO_HDFS_FOR_STATS)
 
   def preferSortMergeJoin: Boolean = getConf(PREFER_SORTMERGEJOIN)
--
cgit v1.2.3
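Usage note (a hedged sketch, assuming a build that includes this patch; the dataset, values, and application name below are made up for illustration): the two scale-up factors can be tuned per application.

```scala
import org.apache.spark.sql.SparkSession

object ScaleUpFactorDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("limit-scale-up-factor-demo")
      // SQL path: read by SparkPlan.executeTake via SQLConf.limitScaleUpFactor (default 4).
      .config("spark.sql.limit.scaleUpFactor", "8")
      // RDD path: read by RDD.take from the SparkConf (default 4, clamped to >= 2).
      .config("spark.rdd.limit.scaleUpFactor", "8")
      .getOrCreate()

    // With a very selective filter, take(5) now grows the scanned partition range
    // by at most 8x per attempt instead of immediately rescanning all 200 partitions.
    val firstFive = spark.range(0, 10000000L, 1, numPartitions = 200)
      .filter("id % 1000000 = 0")
      .take(5)
    println(firstFive.mkString(", "))

    spark.stop()
  }
}
```

A higher factor means fewer take jobs but potentially more partitions read per attempt; a lower factor does the opposite, which is the trade-off the `spark.sql.limit.scaleUpFactor` doc string describes.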