author    Ron Hu <ron.hu@huawei.com>    2017-02-23 20:18:21 -0800
committer    Wenchen Fan <wenchen@databricks.com>    2017-02-23 20:18:21 -0800
commit    d7e43b613aeb6d85b8522d7c16c6e0807843c964
tree    7c5e32d967d8a080151fbdfa25933fd00c7111a1 /sql/catalyst
parent    2f69e3f60f27d4598f001a5454abc21f739120a6
[SPARK-17075][SQL] implemented filter estimation
## What changes were proposed in this pull request?

We traverse the filter predicate and evaluate its logical expressions to compute the selectivity of a FILTER operator.

## How was this patch tested?

We added a new test suite that exercises the various logical operators.

Author: Ron Hu <ron.hu@huawei.com>

Closes #16395 from ron8hu/filterSelectivity.
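At its core, the estimation is a recursive traversal that maps each predicate to a selectivity in [0, 1] and combines the results at AND/OR/NOT nodes. The following standalone Scala sketch mirrors those combination rules (Expr and its subclasses here are hypothetical stand-ins, not Spark's expression classes; the independence assumption is the patch's own):

    sealed trait Expr
    case class And(l: Expr, r: Expr) extends Expr
    case class Or(l: Expr, r: Expr) extends Expr
    case class Not(c: Expr) extends Expr
    case class Leaf(estimate: Option[Double]) extends Expr // one supported or unsupported condition

    def selectivity(e: Expr): Option[Double] = e match {
      case And(l, r) =>
        // AND multiplies the two percentages (independence assumption);
        // an unsupported side simply drops out of the product.
        (selectivity(l), selectivity(r)) match {
          case (Some(p1), Some(p2)) => Some(p1 * p2)
          case (Some(p1), None)     => Some(p1)
          case (None, Some(p2))     => Some(p2)
          case (None, None)         => None
        }
      case Or(l, r) =>
        // OR uses inclusion-exclusion, capped at 100%; an unsupported side
        // forces the conservative 100% estimate.
        (selectivity(l), selectivity(r)) match {
          case (Some(p1), Some(p2)) => Some(math.min(1.0, p1 + p2 - p1 * p2))
          case (None, None)         => None
          case _                    => Some(1.0)
        }
      case Not(c)  => selectivity(c).map(p => 1.0 - p)
      case Leaf(p) => p
    }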
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala10
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala511
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala16
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala403
4 files changed, 939 insertions, 1 deletion
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index ce1c55dc08..ccebae3cc2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTypes}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{AggregateEstimation, EstimationUtils, JoinEstimation, ProjectEstimation}
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation._
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -129,6 +129,14 @@ case class Filter(condition: Expression, child: LogicalPlan)
.filterNot(SubqueryExpression.hasCorrelatedSubquery)
child.constraints.union(predicates.toSet)
}
+
+ override def computeStats(conf: CatalystConf): Statistics = {
+ if (conf.cboEnabled) {
+ FilterEstimation(this, conf).estimate.getOrElse(super.computeStats(conf))
+ } else {
+ super.computeStats(conf)
+ }
+ }
}
abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
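The override above only takes effect when cost-based optimization is enabled; otherwise Filter keeps the default statistics inherited from UnaryNode. A minimal way to exercise it from the SQL layer might look like this (a usage sketch, assuming a Spark session of this era and an analyzed table t with column c):

    // Enable CBO (off by default) and collect column-level statistics,
    // so FilterEstimation has ColumnStat entries to work with.
    spark.sql("SET spark.sql.cbo.enabled=true")
    spark.sql("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS c")

    // Filter.computeStats now consults FilterEstimation for this predicate.
    val optimized = spark.table("t").where("c > 40 AND c <= 50")
      .queryExecution.optimizedPlan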
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
new file mode 100644
index 0000000000..fcc607a610
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
@@ -0,0 +1,511 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
+
+import java.sql.{Date, Timestamp}
+
+import scala.collection.immutable.{HashSet, Map}
+import scala.collection.mutable
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.CatalystConf
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.types._
+
+case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging {
+
+ /**
+ * We use a mutable colStats because we need to update the corresponding ColumnStat
+ * for a column after we apply a predicate condition. For example, column c has
+ * [min, max] value as [0, 100]. In a range condition such as (c > 40 AND c <= 50),
+ * we need to set the column's [min, max] value to [40, 100] after we evaluate the
+ * first condition c > 40. We need to set the column's [min, max] value to [40, 50]
+ * after we evaluate the second condition c <= 50.
+ */
+ private var mutableColStats: mutable.Map[ExprId, ColumnStat] = mutable.Map.empty
+
+ /**
+ * Returns an option of Statistics for a Filter logical plan node.
+ * For a given compound expression condition, this method computes filter selectivity
+ * (or the percentage of rows meeting the filter condition), which
+ * is used to compute row count, size in bytes, and the updated statistics after a given
+ * predicate is applied.
+ *
+ * @return Option[Statistics]. It returns None when no statistics are collected.
+ */
+ def estimate: Option[Statistics] = {
+ // We first copy child node's statistics and then modify it based on filter selectivity.
+ val stats: Statistics = plan.child.stats(catalystConf)
+ if (stats.rowCount.isEmpty) return None
+
+ // save a mutable copy of colStats so that we can later change it recursively
+ mutableColStats = mutable.Map(stats.attributeStats.map(kv => (kv._1.exprId, kv._2)).toSeq: _*)
+
+ // estimate selectivity of this filter predicate
+ val filterSelectivity: Double = calculateFilterSelectivity(plan.condition) match {
+ case Some(percent) => percent
+ // for an unsupported condition, fall back to a conservative estimate of 100%
+ case None => 1.0
+ }
+
+ // attributeStats has mapping Attribute-to-ColumnStat.
+ // mutableColStats has mapping ExprId-to-ColumnStat.
+ // We use an ExprId-to-Attribute map to facilitate the mapping Attribute-to-ColumnStat
+ val expridToAttrMap: Map[ExprId, Attribute] =
+ stats.attributeStats.map(kv => (kv._1.exprId, kv._1))
+ // copy mutableColStats contents to an immutable AttributeMap.
+ val mutableAttributeStats: mutable.Map[Attribute, ColumnStat] =
+ mutableColStats.map(kv => expridToAttrMap(kv._1) -> kv._2)
+ val newColStats = AttributeMap(mutableAttributeStats.toSeq)
+
+ val filteredRowCount: BigInt =
+ EstimationUtils.ceil(BigDecimal(stats.rowCount.get) * filterSelectivity)
+ val filteredSizeInBytes =
+ EstimationUtils.getOutputSize(plan.output, filteredRowCount, newColStats)
+
+ Some(stats.copy(sizeInBytes = filteredSizeInBytes, rowCount = Some(filteredRowCount),
+ attributeStats = newColStats))
+ }
+
+ /**
+ * Returns a percentage of rows meeting a compound condition in Filter node.
+ * A compound condition is decomposed into multiple single conditions linked with AND, OR, NOT.
+ * For logical AND conditions, we need to update the stats after each condition estimation
+ * so that the stats are more accurate for subsequent estimation. This is needed for
+ * range conditions such as (c > 40 AND c <= 50).
+ * For logical OR conditions, we do not update the stats after a condition estimation.
+ *
+ * @param condition the compound logical expression
+ * @param update a boolean flag to specify if we need to update ColumnStat of a column
+ * for subsequent conditions
+ * @return an Option[Double] holding the percentage of rows meeting the given condition.
+ *         It returns None if the condition is not supported.
+ */
+ def calculateFilterSelectivity(condition: Expression, update: Boolean = true): Option[Double] = {
+
+ condition match {
+ case And(cond1, cond2) =>
+ (calculateFilterSelectivity(cond1, update), calculateFilterSelectivity(cond2, update))
+ match {
+ case (Some(p1), Some(p2)) => Some(p1 * p2)
+ case (Some(p1), None) => Some(p1)
+ case (None, Some(p2)) => Some(p2)
+ case (None, None) => None
+ }
+
+ case Or(cond1, cond2) =>
+ // For ease of debugging, we compute percent1 and percent2 in 2 statements.
+ val percent1 = calculateFilterSelectivity(cond1, update = false)
+ val percent2 = calculateFilterSelectivity(cond2, update = false)
+ (percent1, percent2) match {
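+ // Inclusion-exclusion under the independence assumption:
+ // P(c1 OR c2) = P(c1) + P(c2) - P(c1) * P(c2)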
+ case (Some(p1), Some(p2)) => Some(math.min(1.0, p1 + p2 - (p1 * p2)))
+ case (Some(p1), None) => Some(1.0)
+ case (None, Some(p2)) => Some(1.0)
+ case (None, None) => None
+ }
+
+ case Not(cond) => calculateFilterSelectivity(cond, update = false) match {
+ case Some(percent) => Some(1.0 - percent)
+ // for an unsupported condition, we cannot negate the estimate, so return None
+ case None => None
+ }
+
+ case _ =>
+ calculateSingleCondition(condition, update)
+ }
+ }
+
+ /**
+ * Returns a percentage of rows meeting a single condition in Filter node.
+ * Currently we only support binary predicates where one side is a column,
+ * and the other is a literal.
+ *
+ * @param condition a single logical expression
+ * @param update a boolean flag to specify if we need to update ColumnStat of a column
+ * for subsequent conditions
+ * @return Option[Double] value to show the percentage of rows meeting a given condition.
+ * It returns None if the condition is not supported.
+ */
+ def calculateSingleCondition(condition: Expression, update: Boolean): Option[Double] = {
+ condition match {
+ // For the evaluateBinary method, we assume the literal is on the right side of the
+ // operator, so we swap the operands if it is not.
+
+ // EqualTo does not care about the order
+ case op @ EqualTo(ar: AttributeReference, l: Literal) =>
+ evaluateBinary(op, ar, l, update)
+ case op @ EqualTo(l: Literal, ar: AttributeReference) =>
+ evaluateBinary(op, ar, l, update)
+
+ case op @ LessThan(ar: AttributeReference, l: Literal) =>
+ evaluateBinary(op, ar, l, update)
+ case op @ LessThan(l: Literal, ar: AttributeReference) =>
+ evaluateBinary(GreaterThan(ar, l), ar, l, update)
+
+ case op @ LessThanOrEqual(ar: AttributeReference, l: Literal) =>
+ evaluateBinary(op, ar, l, update)
+ case op @ LessThanOrEqual(l: Literal, ar: AttributeReference) =>
+ evaluateBinary(GreaterThanOrEqual(ar, l), ar, l, update)
+
+ case op @ GreaterThan(ar: AttributeReference, l: Literal) =>
+ evaluateBinary(op, ar, l, update)
+ case op @ GreaterThan(l: Literal, ar: AttributeReference) =>
+ evaluateBinary(LessThan(ar, l), ar, l, update)
+
+ case op @ GreaterThanOrEqual(ar: AttributeReference, l: Literal) =>
+ evaluateBinary(op, ar, l, update)
+ case op @ GreaterThanOrEqual(l: Literal, ar: AttributeReference) =>
+ evaluateBinary(LessThanOrEqual(ar, l), ar, l, update)
+
+ case In(ar: AttributeReference, expList)
+ if expList.forall(e => e.isInstanceOf[Literal]) =>
+ // The expression [In (value, seq[Literal])] will be replaced with the optimized version
+ // [InSet (value, HashSet[Literal])] by the Optimizer, but only when list.size > 10.
+ // Here we convert In into InSet anyway, because they share the same processing logic.
+ val hSet = expList.map(e => e.eval())
+ evaluateInSet(ar, HashSet() ++ hSet, update)
+
+ case InSet(ar: AttributeReference, set) =>
+ evaluateInSet(ar, set, update)
+
+ case IsNull(ar: AttributeReference) =>
+ evaluateIsNull(ar, isNull = true, update)
+
+ case IsNotNull(ar: AttributeReference) =>
+ evaluateIsNull(ar, isNull = false, update)
+
+ case _ =>
+ // TODO: it's difficult to support string operators without advanced statistics.
+ // Hence, these string operators Like(_, _) | Contains(_, _) | StartsWith(_, _)
+ // | EndsWith(_, _) are not supported yet
+ logDebug("[CBO] Unsupported filter condition: " + condition)
+ None
+ }
+ }
+
+ /**
+ * Returns a percentage of rows meeting "IS NULL" or "IS NOT NULL" condition.
+ *
+ * @param attrRef an AttributeReference (or a column)
+ * @param isNull set to true for an "IS NULL" condition; set to false for "IS NOT NULL"
+ * @param update a boolean flag to specify if we need to update ColumnStat of a given column
+ * for subsequent conditions
+ * @return an optional double value showing the percentage of rows meeting the given condition.
+ * It returns None if no statistics are collected for the given column.
+ */
+ def evaluateIsNull(
+ attrRef: AttributeReference,
+ isNull: Boolean,
+ update: Boolean)
+ : Option[Double] = {
+ if (!mutableColStats.contains(attrRef.exprId)) {
+ logDebug("[CBO] No statistics for " + attrRef)
+ return None
+ }
+ val aColStat = mutableColStats(attrRef.exprId)
+ val rowCountValue = plan.child.stats(catalystConf).rowCount.get
+ val nullPercent: BigDecimal =
+ if (rowCountValue == 0) 0.0
+ else BigDecimal(aColStat.nullCount) / BigDecimal(rowCountValue)
+
+ if (update) {
+ val newStats =
+ if (isNull) aColStat.copy(distinctCount = 0, min = None, max = None)
+ else aColStat.copy(nullCount = 0)
+
+ mutableColStats += (attrRef.exprId -> newStats)
+ }
+
+ val percent =
+ if (isNull) {
+ nullPercent.toDouble
+ }
+ else {
+ // IS NOT NULL (column)
+ 1.0 - nullPercent.toDouble
+ }
+
+ Some(percent)
+ }
+
+ /**
+ * Returns a percentage of rows meeting a binary comparison expression.
+ *
+ * @param op a binary comparison operator such as =, <, <=, >, >=
+ * @param attrRef an AttributeReference (or a column)
+ * @param literal a literal value (or constant)
+ * @param update a boolean flag to specify if we need to update ColumnStat of a given column
+ * for subsequent conditions
+ * @return an optional double value showing the percentage of rows meeting the given condition.
+ * It returns None if no statistics exist for the given column or the literal value is invalid.
+ */
+ def evaluateBinary(
+ op: BinaryComparison,
+ attrRef: AttributeReference,
+ literal: Literal,
+ update: Boolean)
+ : Option[Double] = {
+ if (!mutableColStats.contains(attrRef.exprId)) {
+ logDebug("[CBO] No statistics for " + attrRef)
+ return None
+ }
+
+ op match {
+ case EqualTo(l, r) => evaluateEqualTo(attrRef, literal, update)
+ case _ =>
+ attrRef.dataType match {
+ case _: NumericType | DateType | TimestampType =>
+ evaluateBinaryForNumeric(op, attrRef, literal, update)
+ case StringType | BinaryType =>
+ // TODO: It is difficult to support other binary comparisons for String/Binary
+ // type without min/max and advanced statistics like histogram.
+ logDebug("[CBO] No range comparison statistics for String/Binary type " + attrRef)
+ None
+ }
+ }
+ }
+
+ /**
+ * For a SQL data type, its internal representation may differ from its external type.
+ * For DateType, the internal type is Int while the external type is java.sql.Date.
+ * The min/max values in ColumnStat are saved in their corresponding external types.
+ *
+ * @param attrDataType the column data type
+ * @param litValue the literal value
+ * @return an Option containing the value converted to its external type; None for String/Binary types
+ */
+ def convertBoundValue(attrDataType: DataType, litValue: Any): Option[Any] = {
+ attrDataType match {
+ case DateType =>
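+ // DateType is stored internally as the number of days since the epoch (an Int).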
+ Some(DateTimeUtils.toJavaDate(litValue.toString.toInt))
+ case TimestampType =>
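+ // TimestampType is stored internally as microseconds since the epoch (a Long).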
+ Some(DateTimeUtils.toJavaTimestamp(litValue.toString.toLong))
+ case StringType | BinaryType =>
+ None
+ case _ =>
+ Some(litValue)
+ }
+ }
+
+ /**
+ * Returns a percentage of rows meeting an equality (=) expression.
+ * This method evaluates the equality predicate for all data types.
+ *
+ * @param attrRef an AttributeReference (or a column)
+ * @param literal a literal value (or constant)
+ * @param update a boolean flag to specify if we need to update ColumnStat of a given column
+ * for subsequent conditions
+ * @return an optional double value to show the percentage of rows meeting a given condition
+ */
+ def evaluateEqualTo(
+ attrRef: AttributeReference,
+ literal: Literal,
+ update: Boolean)
+ : Option[Double] = {
+
+ val aColStat = mutableColStats(attrRef.exprId)
+ val ndv = aColStat.distinctCount
+
+ // decide if the value is in [min, max] of the column.
+ // We currently don't store min/max for binary/string type.
+ // Hence, we assume it is in boundary for binary/string type.
+ val statsRange = Range(aColStat.min, aColStat.max, attrRef.dataType)
+ val inBoundary: Boolean = Range.rangeContainsLiteral(statsRange, literal)
+
+ if (inBoundary) {
+
+ if (update) {
+ // We update ColumnStat structure after apply this equality predicate.
+ // Set distinctCount to 1. Set nullCount to 0.
+ // Need to save new min/max using the external type value of the literal
+ val newValue = convertBoundValue(attrRef.dataType, literal.value)
+ val newStats = aColStat.copy(distinctCount = 1, min = newValue,
+ max = newValue, nullCount = 0)
+ mutableColStats += (attrRef.exprId -> newStats)
+ }
+
+ Some(1.0 / ndv.toDouble)
+ } else {
+ Some(0.0)
+ }
+
+ }
+
+ /**
+ * Returns a percentage of rows meeting "IN" operator expression.
+ * This method evaluates the IN predicate for all data types.
+ *
+ * @param attrRef an AttributeReference (or a column)
+ * @param hSet a set of literal values
+ * @param update a boolean flag to specify if we need to update ColumnStat of a given column
+ * for subsequent conditions
+ * @return an optional double value showing the percentage of rows meeting the given condition.
+ * It returns None if no statistics exist for the given column.
+ */
+ def evaluateInSet(
+ attrRef: AttributeReference,
+ hSet: Set[Any],
+ update: Boolean)
+ : Option[Double] = {
+ if (!mutableColStats.contains(attrRef.exprId)) {
+ logDebug("[CBO] No statistics for " + attrRef)
+ return None
+ }
+
+ val aColStat = mutableColStats(attrRef.exprId)
+ val ndv = aColStat.distinctCount
+ val aType = attrRef.dataType
+ var newNdv: Long = 0
+
+ // use [min, max] to filter the original hSet
+ aType match {
+ case _: NumericType | DateType | TimestampType =>
+ val statsRange =
+ Range(aColStat.min, aColStat.max, aType).asInstanceOf[NumericRange]
+
+ // To compare hSet values against the column's range, we map them to BigDecimal.
+ // The min and max of the filtered set can then be found via BigDecimal ordering.
+ val hSetBigdec = hSet.map(e => BigDecimal(e.toString))
+ val validQuerySet = hSetBigdec.filter(e => e >= statsRange.min && e <= statsRange.max)
+ // We use hSetBigdecToAnyMap to help us find the original hSet value.
+ val hSetBigdecToAnyMap: Map[BigDecimal, Any] =
+ hSet.map(e => BigDecimal(e.toString) -> e).toMap
+
+ if (validQuerySet.isEmpty) {
+ return Some(0.0)
+ }
+
+ // Need to save new min/max using the external type value of the literal
+ val newMax = convertBoundValue(attrRef.dataType, hSetBigdecToAnyMap(validQuerySet.max))
+ val newMin = convertBoundValue(attrRef.dataType, hSetBigdecToAnyMap(validQuerySet.min))
+
+ // newNdv should not be greater than the old ndv. For example, if a column has only
+ // 2 distinct values, 1 and 6, then for the predicate column IN (1, 2, 3, 4, 5),
+ // validQuerySet.size is 5 but the new ndv is still capped at 2.
+ newNdv = math.min(validQuerySet.size.toLong, ndv.longValue())
+ if (update) {
+ val newStats = aColStat.copy(distinctCount = newNdv, min = newMin,
+ max = newMax, nullCount = 0)
+ mutableColStats += (attrRef.exprId -> newStats)
+ }
+
+ // We assume the whole set is valid, since there is no min/max information for String/Binary types
+ case StringType | BinaryType =>
+ newNdv = math.min(hSet.size.toLong, ndv.longValue())
+ if (update) {
+ val newStats = aColStat.copy(distinctCount = newNdv, nullCount = 0)
+ mutableColStats += (attrRef.exprId -> newStats)
+ }
+ }
+
+ // return the filter selectivity. Without advanced statistics such as histograms,
+ // we have to assume uniform distribution.
+ Some(math.min(1.0, newNdv.toDouble / ndv.toDouble))
+ }
+
+ /**
+ * Returns a percentage of rows meeting a binary comparison expression.
+ * This method evaluates the expression for numeric columns only (Date and Timestamp
+ * columns are treated as numeric internally).
+ *
+ * @param op a binary comparison operator such as <, <=, >, >=
+ * @param attrRef an AttributeReference (or a column)
+ * @param literal a literal value (or constant)
+ * @param update a boolean flag to specify if we need to update ColumnStat of a given column
+ * for subsequent conditions
+ * @return an optional double value to show the percentage of rows meeting a given condition
+ */
+ def evaluateBinaryForNumeric(
+ op: BinaryComparison,
+ attrRef: AttributeReference,
+ literal: Literal,
+ update: Boolean)
+ : Option[Double] = {
+
+ var percent = 1.0
+ val aColStat = mutableColStats(attrRef.exprId)
+ val ndv = aColStat.distinctCount
+ val statsRange =
+ Range(aColStat.min, aColStat.max, attrRef.dataType).asInstanceOf[NumericRange]
+
+ // determine the degree of overlap between the predicate range and the column's range
+ val literalValueBD = BigDecimal(literal.value.toString)
+ val (noOverlap: Boolean, completeOverlap: Boolean) = op match {
+ case _: LessThan =>
+ (literalValueBD <= statsRange.min, literalValueBD > statsRange.max)
+ case _: LessThanOrEqual =>
+ (literalValueBD < statsRange.min, literalValueBD >= statsRange.max)
+ case _: GreaterThan =>
+ (literalValueBD >= statsRange.max, literalValueBD < statsRange.min)
+ case _: GreaterThanOrEqual =>
+ (literalValueBD > statsRange.max, literalValueBD <= statsRange.min)
+ }
+
+ if (noOverlap) {
+ percent = 0.0
+ } else if (completeOverlap) {
+ percent = 1.0
+ } else {
+ // this is the partial-overlap case
+ var newMax = aColStat.max
+ var newMin = aColStat.min
+ var newNdv = ndv
+ val literalToDouble = literalValueBD.toDouble
+ val maxToDouble = BigDecimal(statsRange.max).toDouble
+ val minToDouble = BigDecimal(statsRange.min).toDouble
+
+ // Without advanced statistics like histogram, we assume uniform data distribution.
+ // We just prorate the adjusted range over the initial range to compute filter selectivity.
+ // For ease of computation, we convert all relevant numeric values to Double.
+ percent = op match {
+ case _: LessThan =>
+ (literalToDouble - minToDouble) / (maxToDouble - minToDouble)
+ case _: LessThanOrEqual =>
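+ // If the literal equals the column minimum, only the min value itself can match,
+ // i.e. roughly 1 of ndv distinct values.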
+ if (literalValueBD == BigDecimal(statsRange.min)) 1.0 / ndv.toDouble
+ else (literalToDouble - minToDouble) / (maxToDouble - minToDouble)
+ case _: GreaterThan =>
+ (maxToDouble - literalToDouble) / (maxToDouble - minToDouble)
+ case _: GreaterThanOrEqual =>
+ if (literalValueBD == BigDecimal(statsRange.max)) 1.0 / ndv.toDouble
+ else (maxToDouble - literalToDouble) / (maxToDouble - minToDouble)
+ }
+
+ // Need to save new min/max using the external type value of the literal
+ val newValue = convertBoundValue(attrRef.dataType, literal.value)
+
+ if (update) {
+ op match {
+ case _: GreaterThan => newMin = newValue
+ case _: GreaterThanOrEqual => newMin = newValue
+ case _: LessThan => newMax = newValue
+ case _: LessThanOrEqual => newMax = newValue
+ }
+
+ newNdv = math.max(math.round(ndv.toDouble * percent), 1)
+ val newStats = aColStat.copy(distinctCount = newNdv, min = newMin,
+ max = newMax, nullCount = 0)
+
+ mutableColStats += (attrRef.exprId -> newStats)
+ }
+ }
+
+ Some(percent)
+ }
+
+}
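To make the proration and the stats update concrete, here is a standalone worked example (plain arithmetic following the formulas above, not the Spark API) for the range condition from the class comment, (c > 40 AND c <= 50) on a column with [min, max] = [0, 100]:

    // Column range [0, 100]; predicate (c > 40 AND c <= 50); uniform distribution assumed.
    val (min0, max0) = (0.0, 100.0)
    // Step 1: c > 40 prorates the upper part of the range, then narrows min to 40.
    val p1 = (max0 - 40.0) / (max0 - min0)     // 0.6
    val (min1, max1) = (40.0, max0)
    // Step 2: c <= 50 is evaluated against the *updated* range [40, 100].
    val p2 = (50.0 - min1) / (max1 - min1)     // ~0.167
    // AND multiplies the two estimates: overall selectivity ~0.1, i.e. the
    // intuitive 10% of the original [0, 100] range.
    val overall = p1 * p2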
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala
index 5aa6b9353b..4557114532 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
import java.math.{BigDecimal => JDecimal}
import java.sql.{Date, Timestamp}
+import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types.{BooleanType, DateType, TimestampType, _}
@@ -57,6 +58,20 @@ object Range {
n1.min.compareTo(n2.max) <= 0 && n1.max.compareTo(n2.min) >= 0
}
+ def rangeContainsLiteral(r: Range, lit: Literal): Boolean = r match {
+ case _: DefaultRange => true
+ case _: NullRange => false
+ case n: NumericRange =>
+ val literalValue = if (lit.dataType.isInstanceOf[BooleanType]) {
+ if (lit.value.asInstanceOf[Boolean]) new JDecimal(1) else new JDecimal(0)
+ } else {
+ assert(lit.dataType.isInstanceOf[NumericType] || lit.dataType.isInstanceOf[DateType] ||
+ lit.dataType.isInstanceOf[TimestampType])
+ new JDecimal(lit.value.toString)
+ }
+ n.min.compareTo(literalValue) <= 0 && n.max.compareTo(literalValue) >= 0
+ }
+
/**
* Intersected results of two ranges. This is only for two overlapped ranges.
* The outputs are the intersected min/max values.
@@ -113,4 +128,5 @@ object Range {
DateTimeUtils.toJavaTimestamp(n.max.longValue()))
}
}
+
}
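A quick illustration of the new helper (a sketch with hypothetical values; Range is constructed positionally, as in FilterEstimation above):

    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.types.IntegerType

    // A numeric column with min = 1 and max = 10.
    val r = Range(Some(1), Some(10), IntegerType)
    Range.rangeContainsLiteral(r, Literal(5))    // true  -> in-boundary, selectivity 1/ndv
    Range.rangeContainsLiteral(r, Literal(42))   // false -> out of range, selectivity 0.0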
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
new file mode 100644
index 0000000000..f139c9e28c
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
@@ -0,0 +1,403 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.statsEstimation
+
+import java.sql.{Date, Timestamp}
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.types._
+
+/**
+ * In this test suite, we test predicates containing the following operators:
+ * =, <, <=, >, >=, AND, OR, IS NULL, IS NOT NULL, IN, NOT IN
+ */
+class FilterEstimationSuite extends StatsEstimationTestBase {
+
+ // Suppose our test table has 10 rows and 6 columns.
+ // First column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
+ // Hence, distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4
+ val arInt = AttributeReference("cint", IntegerType)()
+ val childColStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4)
+
+ // Second column cdate has 10 values from 2017-01-01 through 2017-01-10.
+ val dMin = Date.valueOf("2017-01-01")
+ val dMax = Date.valueOf("2017-01-10")
+ val arDate = AttributeReference("cdate", DateType)()
+ val childColStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax),
+ nullCount = 0, avgLen = 4, maxLen = 4)
+
+ // Third column ctimestamp has 10 values from "2017-01-01 01:00:00" through
+ // "2017-01-01 10:00:00" for 10 distinct timestamps (or hours).
+ val tsMin = Timestamp.valueOf("2017-01-01 01:00:00")
+ val tsMax = Timestamp.valueOf("2017-01-01 10:00:00")
+ val arTimestamp = AttributeReference("ctimestamp", TimestampType)()
+ val childColStatTimestamp = ColumnStat(distinctCount = 10, min = Some(tsMin), max = Some(tsMax),
+ nullCount = 0, avgLen = 8, maxLen = 8)
+
+ // Fourth column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20.
+ val decMin = new java.math.BigDecimal("0.200000000000000000")
+ val decMax = new java.math.BigDecimal("0.800000000000000000")
+ val arDecimal = AttributeReference("cdecimal", DecimalType(18, 18))()
+ val childColStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax),
+ nullCount = 0, avgLen = 8, maxLen = 8)
+
+ // Fifth column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0
+ val arDouble = AttributeReference("cdouble", DoubleType)()
+ val childColStatDouble = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0),
+ nullCount = 0, avgLen = 8, maxLen = 8)
+
+ // Sixth column cstring has 10 String values:
+ // "A0", "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9"
+ val arString = AttributeReference("cstring", StringType)()
+ val childColStatString = ColumnStat(distinctCount = 10, min = None, max = None,
+ nullCount = 0, avgLen = 2, maxLen = 2)
+
+ test("cint = 2") {
+ validateEstimatedStats(
+ arInt,
+ Filter(EqualTo(arInt, Literal(2)), childStatsTestPlan(Seq(arInt), 10L)),
+ ColumnStat(distinctCount = 1, min = Some(2), max = Some(2),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ Some(1L)
+ )
+ }
+
+ test("cint = 0") {
+ // This is an out-of-range case since 0 is outside the range [min, max]
+ validateEstimatedStats(
+ arInt,
+ Filter(EqualTo(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)),
+ ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ Some(0L)
+ )
+ }
+
+ test("cint < 3") {
+ validateEstimatedStats(
+ arInt,
+ Filter(LessThan(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)),
+ ColumnStat(distinctCount = 2, min = Some(1), max = Some(3),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ Some(3L)
+ )
+ }
+
+ test("cint < 0") {
+ // This is a corner case since literal 0 is smaller than min.
+ validateEstimatedStats(
+ arInt,
+ Filter(LessThan(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)),
+ ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ Some(0L)
+ )
+ }
+
+ test("cint <= 3") {
+ validateEstimatedStats(
+ arInt,
+ Filter(LessThanOrEqual(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)),
+ ColumnStat(distinctCount = 2, min = Some(1), max = Some(3),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ Some(3L)
+ )
+ }
+
+ test("cint > 6") {
+ validateEstimatedStats(
+ arInt,
+ Filter(GreaterThan(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)),
+ ColumnStat(distinctCount = 4, min = Some(6), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ Some(5L)
+ )
+ }
+
+ test("cint > 10") {
+ // This is a corner case since max value is 10.
+ validateEstimatedStats(
+ arInt,
+ Filter(GreaterThan(arInt, Literal(10)), childStatsTestPlan(Seq(arInt), 10L)),
+ ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ Some(0L)
+ )
+ }
+
+ test("cint >= 6") {
+ validateEstimatedStats(
+ arInt,
+ Filter(GreaterThanOrEqual(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)),
+ ColumnStat(distinctCount = 4, min = Some(6), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ Some(5L)
+ )
+ }
+
+ test("cint IS NULL") {
+ validateEstimatedStats(
+ arInt,
+ Filter(IsNull(arInt), childStatsTestPlan(Seq(arInt), 10L)),
+ ColumnStat(distinctCount = 0, min = None, max = None,
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ Some(0L)
+ )
+ }
+
+ test("cint IS NOT NULL") {
+ validateEstimatedStats(
+ arInt,
+ Filter(IsNotNull(arInt), childStatsTestPlan(Seq(arInt), 10L)),
+ ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ Some(10L)
+ )
+ }
+
+ test("cint > 3 AND cint <= 6") {
+ val condition = And(GreaterThan(arInt, Literal(3)), LessThanOrEqual(arInt, Literal(6)))
+ validateEstimatedStats(
+ arInt,
+ Filter(condition, childStatsTestPlan(Seq(arInt), 10L)),
+ ColumnStat(distinctCount = 3, min = Some(3), max = Some(6),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ Some(4L)
+ )
+ }
+
+ test("cint = 3 OR cint = 6") {
+ val condition = Or(EqualTo(arInt, Literal(3)), EqualTo(arInt, Literal(6)))
+ validateEstimatedStats(
+ arInt,
+ Filter(condition, childStatsTestPlan(Seq(arInt), 10L)),
+ ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ Some(2L)
+ )
+ }
+
+ test("cint IN (3, 4, 5)") {
+ validateEstimatedStats(
+ arInt,
+ Filter(InSet(arInt, Set(3, 4, 5)), childStatsTestPlan(Seq(arInt), 10L)),
+ ColumnStat(distinctCount = 3, min = Some(3), max = Some(5),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ Some(3L)
+ )
+ }
+
+ test("cint NOT IN (3, 4, 5)") {
+ validateEstimatedStats(
+ arInt,
+ Filter(Not(InSet(arInt, Set(3, 4, 5))), childStatsTestPlan(Seq(arInt), 10L)),
+ ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ Some(7L)
+ )
+ }
+
+ test("cdate = cast('2017-01-02' AS DATE)") {
+ val d20170102 = Date.valueOf("2017-01-02")
+ validateEstimatedStats(
+ arDate,
+ Filter(EqualTo(arDate, Literal(d20170102)),
+ childStatsTestPlan(Seq(arDate), 10L)),
+ ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ Some(1L)
+ )
+ }
+
+ test("cdate < cast('2017-01-03' AS DATE)") {
+ val d20170103 = Date.valueOf("2017-01-03")
+ validateEstimatedStats(
+ arDate,
+ Filter(LessThan(arDate, Literal(d20170103)),
+ childStatsTestPlan(Seq(arDate), 10L)),
+ ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ Some(3L)
+ )
+ }
+
+ test("""cdate IN ( cast('2017-01-03' AS DATE),
+ cast('2017-01-04' AS DATE), cast('2017-01-05' AS DATE) )""") {
+ val d20170103 = Date.valueOf("2017-01-03")
+ val d20170104 = Date.valueOf("2017-01-04")
+ val d20170105 = Date.valueOf("2017-01-05")
+ validateEstimatedStats(
+ arDate,
+ Filter(In(arDate, Seq(Literal(d20170103), Literal(d20170104), Literal(d20170105))),
+ childStatsTestPlan(Seq(arDate), 10L)),
+ ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ Some(3L)
+ )
+ }
+
+ test("ctimestamp = cast('2017-01-01 02:00:00' AS TIMESTAMP)") {
+ val ts2017010102 = Timestamp.valueOf("2017-01-01 02:00:00")
+ validateEstimatedStats(
+ arTimestamp,
+ Filter(EqualTo(arTimestamp, Literal(ts2017010102)),
+ childStatsTestPlan(Seq(arTimestamp), 10L)),
+ ColumnStat(distinctCount = 1, min = Some(ts2017010102), max = Some(ts2017010102),
+ nullCount = 0, avgLen = 8, maxLen = 8),
+ Some(1L)
+ )
+ }
+
+ test("ctimestamp < cast('2017-01-01 03:00:00' AS TIMESTAMP)") {
+ val ts2017010103 = Timestamp.valueOf("2017-01-01 03:00:00")
+ validateEstimatedStats(
+ arTimestamp,
+ Filter(LessThan(arTimestamp, Literal(ts2017010103)),
+ childStatsTestPlan(Seq(arTimestamp), 10L)),
+ ColumnStat(distinctCount = 2, min = Some(tsMin), max = Some(ts2017010103),
+ nullCount = 0, avgLen = 8, maxLen = 8),
+ Some(3L)
+ )
+ }
+
+ test("cdecimal = 0.400000000000000000") {
+ val dec_0_40 = new java.math.BigDecimal("0.400000000000000000")
+ validateEstimatedStats(
+ arDecimal,
+ Filter(EqualTo(arDecimal, Literal(dec_0_40)),
+ childStatsTestPlan(Seq(arDecimal), 4L)),
+ ColumnStat(distinctCount = 1, min = Some(dec_0_40), max = Some(dec_0_40),
+ nullCount = 0, avgLen = 8, maxLen = 8),
+ Some(1L)
+ )
+ }
+
+ test("cdecimal < 0.60 ") {
+ val dec_0_60 = new java.math.BigDecimal("0.600000000000000000")
+ validateEstimatedStats(
+ arDecimal,
+ Filter(LessThan(arDecimal, Literal(dec_0_60, DecimalType(12, 2))),
+ childStatsTestPlan(Seq(arDecimal), 4L)),
+ ColumnStat(distinctCount = 3, min = Some(decMin), max = Some(dec_0_60),
+ nullCount = 0, avgLen = 8, maxLen = 8),
+ Some(3L)
+ )
+ }
+
+ test("cdouble < 3.0") {
+ validateEstimatedStats(
+ arDouble,
+ Filter(LessThan(arDouble, Literal(3.0)), childStatsTestPlan(Seq(arDouble), 10L)),
+ ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(3.0),
+ nullCount = 0, avgLen = 8, maxLen = 8),
+ Some(3L)
+ )
+ }
+
+ test("cstring = 'A2'") {
+ validateEstimatedStats(
+ arString,
+ Filter(EqualTo(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)),
+ ColumnStat(distinctCount = 1, min = None, max = None,
+ nullCount = 0, avgLen = 2, maxLen = 2),
+ Some(1L)
+ )
+ }
+
+ // There are no min/max statistics for String type, so we conservatively
+ // estimate that all 10 rows are returned.
+ test("cstring < 'A2'") {
+ validateEstimatedStats(
+ arString,
+ Filter(LessThan(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)),
+ ColumnStat(distinctCount = 10, min = None, max = None,
+ nullCount = 0, avgLen = 2, maxLen = 2),
+ Some(10L)
+ )
+ }
+
+ // This is a corner test case. We want to test whether we can handle the case where
+ // the number of valid values in the IN clause is greater than the number of distinct
+ // values for a given column. For example, the column has only 2 distinct values,
+ // 1 and 6, and the predicate is: column IN (1, 2, 3, 4, 5).
+ test("cint IN (1, 2, 3, 4, 5)") {
+ val cornerChildColStatInt = ColumnStat(distinctCount = 2, min = Some(1), max = Some(6),
+ nullCount = 0, avgLen = 4, maxLen = 4)
+ val cornerChildStatsTestplan = StatsTestPlan(
+ outputList = Seq(arInt),
+ rowCount = 2L,
+ attributeStats = AttributeMap(Seq(arInt -> cornerChildColStatInt))
+ )
+ validateEstimatedStats(
+ arInt,
+ Filter(InSet(arInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan),
+ ColumnStat(distinctCount = 2, min = Some(1), max = Some(5),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ Some(2L)
+ )
+ }
+
+ private def childStatsTestPlan(outList: Seq[Attribute], tableRowCount: BigInt): StatsTestPlan = {
+ StatsTestPlan(
+ outputList = outList,
+ rowCount = tableRowCount,
+ attributeStats = AttributeMap(Seq(
+ arInt -> childColStatInt,
+ arDate -> childColStatDate,
+ arTimestamp -> childColStatTimestamp,
+ arDecimal -> childColStatDecimal,
+ arDouble -> childColStatDouble,
+ arString -> childColStatString
+ ))
+ )
+ }
+
+ private def validateEstimatedStats(
+ ar: AttributeReference,
+ filterNode: Filter,
+ expectedColStats: ColumnStat,
+ rowCount: Option[BigInt] = None)
+ : Unit = {
+
+ val expectedRowCount: BigInt = rowCount.getOrElse(0L)
+ val expectedAttrStats = toAttributeMap(Seq(ar.name -> expectedColStats), filterNode)
+ val expectedSizeInBytes = getOutputSize(filterNode.output, expectedRowCount, expectedAttrStats)
+
+ val filteredStats = filterNode.stats(conf)
+ assert(filteredStats.sizeInBytes == expectedSizeInBytes)
+ assert(filteredStats.rowCount == rowCount)
+ ar.dataType match {
+ case DecimalType() =>
+ // Due to the internal transformation of DecimalType within the engine, the new min/max
+ // in ColumnStat may have a different representation even when it contains the right
+ // values. We convert both to java.math.BigDecimal values so that we can compare the
+ // entire objects.
+ val generatedColumnStats = filteredStats.attributeStats(ar)
+ val newMax = new java.math.BigDecimal(generatedColumnStats.max.getOrElse(0).toString)
+ val newMin = new java.math.BigDecimal(generatedColumnStats.min.getOrElse(0).toString)
+ val outputColStats = generatedColumnStats.copy(min = Some(newMin), max = Some(newMax))
+ assert(outputColStats == expectedColStats)
+ case _ =>
+ // For all other SQL types, we compare the entire object directly.
+ assert(filteredStats.attributeStats(ar) == expectedColStats)
+ }
+ }
+
+}
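New cases follow the same pattern: build a Filter over childStatsTestPlan, then state the expected ColumnStat and row count. As an illustration, a hypothetical additional boundary test (not part of this patch) consistent with the estimation formulas would be:

    // cint <= 6 on values 1..10: selectivity = (6 - 1) / (10 - 1) ~ 0.56,
    // so 6 of the 10 rows are expected and max narrows to 6.
    test("cint <= 6") {
      validateEstimatedStats(
        arInt,
        Filter(LessThanOrEqual(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)),
        ColumnStat(distinctCount = 6, min = Some(1), max = Some(6),
          nullCount = 0, avgLen = 4, maxLen = 4),
        Some(6L)
      )
    }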