From d7e43b613aeb6d85b8522d7c16c6e0807843c964 Mon Sep 17 00:00:00 2001
From: Ron Hu
Date: Thu, 23 Feb 2017 20:18:21 -0800
Subject: [SPARK-17075][SQL] implemented filter estimation

## What changes were proposed in this pull request?

We traverse the predicate and evaluate its logical expressions to compute the selectivity of a FILTER operator.

## How was this patch tested?

We add a new test suite to test various logical operators.

Author: Ron Hu

Closes #16395 from ron8hu/filterSelectivity.
---
 .../plans/logical/basicLogicalOperators.scala      |  10 +-
 .../logical/statsEstimation/FilterEstimation.scala | 511 +++++++++++++++++++++
 .../plans/logical/statsEstimation/Range.scala      |  16 +
 .../statsEstimation/FilterEstimationSuite.scala    | 403 ++++++++++++++++
 4 files changed, 939 insertions(+), 1 deletion(-)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
 create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index ce1c55dc08..ccebae3cc2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTypes}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{AggregateEstimation, EstimationUtils, JoinEstimation, ProjectEstimation}
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
@@ -129,6 +129,14 @@ case class Filter(condition: Expression, child: LogicalPlan)
       .filterNot(SubqueryExpression.hasCorrelatedSubquery)
     child.constraints.union(predicates.toSet)
   }
+
+  override def computeStats(conf: CatalystConf): Statistics = {
+    if (conf.cboEnabled) {
+      FilterEstimation(this, conf).estimate.getOrElse(super.computeStats(conf))
+    } else {
+      super.computeStats(conf)
+    }
+  }
 }
 
 abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
new file mode 100644
index 0000000000..fcc607a610
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
@@ -0,0 +1,511 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
+
+import java.sql.{Date, Timestamp}
+
+import scala.collection.immutable.{HashSet, Map}
+import scala.collection.mutable
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.CatalystConf
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.types._
+
+case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging {
+
+  /**
+   * We use a mutable colStats because we need to update the corresponding ColumnStat
+   * for a column after we apply a predicate condition. For example, column c has
+   * [min, max] value as [0, 100]. In a range condition such as (c > 40 AND c <= 50),
+   * we need to set the column's [min, max] value to [40, 100] after we evaluate the
+   * first condition c > 40. We need to set the column's [min, max] value to [40, 50]
+   * after we evaluate the second condition c <= 50.
+   */
+  private var mutableColStats: mutable.Map[ExprId, ColumnStat] = mutable.Map.empty
+
+  /**
+   * Returns an option of Statistics for a Filter logical plan node.
+   * For a given compound expression condition, this method computes filter selectivity
+   * (or the percentage of rows meeting the filter condition), which
+   * is used to compute row count, size in bytes, and the updated statistics after a given
+   * predicate is applied.
+   *
+   * @return Option[Statistics] When there are no statistics collected, it returns None.
+   */
+  def estimate: Option[Statistics] = {
+    // We first copy child node's statistics and then modify it based on filter selectivity.
+    val stats: Statistics = plan.child.stats(catalystConf)
+    if (stats.rowCount.isEmpty) return None
+
+    // save a mutable copy of colStats so that we can later change it recursively
+    mutableColStats = mutable.Map(stats.attributeStats.map(kv => (kv._1.exprId, kv._2)).toSeq: _*)
+
+    // estimate selectivity of this filter predicate
+    val filterSelectivity: Double = calculateFilterSelectivity(plan.condition) match {
+      case Some(percent) => percent
+      // for not-supported condition, set filter selectivity to a conservative estimate 100%
+      case None => 1.0
+    }
+
+    // attributeStats has mapping Attribute-to-ColumnStat.
+    // mutableColStats has mapping ExprId-to-ColumnStat.
+    // We use an ExprId-to-Attribute map to facilitate the mapping Attribute-to-ColumnStat
+    val expridToAttrMap: Map[ExprId, Attribute] =
+      stats.attributeStats.map(kv => (kv._1.exprId, kv._1))
+    // copy mutableColStats contents to an immutable AttributeMap.
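+    // (Each exprId key is first mapped back to its Attribute via expridToAttrMap,
+    // then the pairs are wrapped in an AttributeMap below.)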
+    val mutableAttributeStats: mutable.Map[Attribute, ColumnStat] =
+      mutableColStats.map(kv => expridToAttrMap(kv._1) -> kv._2)
+    val newColStats = AttributeMap(mutableAttributeStats.toSeq)
+
+    val filteredRowCount: BigInt =
+      EstimationUtils.ceil(BigDecimal(stats.rowCount.get) * filterSelectivity)
+    val filteredSizeInBytes =
+      EstimationUtils.getOutputSize(plan.output, filteredRowCount, newColStats)
+
+    Some(stats.copy(sizeInBytes = filteredSizeInBytes, rowCount = Some(filteredRowCount),
+      attributeStats = newColStats))
+  }
+
+  /**
+   * Returns a percentage of rows meeting a compound condition in the Filter node.
+   * A compound condition is decomposed into multiple single conditions linked with AND, OR, NOT.
+   * For logical AND conditions, we need to update stats after a condition estimation
+   * so that the stats will be more accurate for subsequent estimation. This is needed for
+   * range conditions such as (c > 40 AND c <= 50).
+   * For logical OR conditions, we do not update stats after a condition estimation.
+   *
+   * @param condition the compound logical expression
+   * @param update a boolean flag to specify if we need to update ColumnStat of a column
+   *               for subsequent conditions
+   * @return an Option[Double] value to show the percentage of rows meeting a given condition.
+   *         It returns None if the condition is not supported.
+   */
+  def calculateFilterSelectivity(condition: Expression, update: Boolean = true): Option[Double] = {
+
+    condition match {
+      case And(cond1, cond2) =>
+        (calculateFilterSelectivity(cond1, update), calculateFilterSelectivity(cond2, update))
+        match {
+          case (Some(p1), Some(p2)) => Some(p1 * p2)
+          case (Some(p1), None) => Some(p1)
+          case (None, Some(p2)) => Some(p2)
+          case (None, None) => None
+        }
+
+      case Or(cond1, cond2) =>
+        // For ease of debugging, we compute percent1 and percent2 in 2 statements.
+        val percent1 = calculateFilterSelectivity(cond1, update = false)
+        val percent2 = calculateFilterSelectivity(cond2, update = false)
+        (percent1, percent2) match {
+          case (Some(p1), Some(p2)) => Some(math.min(1.0, p1 + p2 - (p1 * p2)))
+          case (Some(p1), None) => Some(1.0)
+          case (None, Some(p2)) => Some(1.0)
+          case (None, None) => None
+        }
+
+      case Not(cond) => calculateFilterSelectivity(cond, update = false) match {
+        case Some(percent) => Some(1.0 - percent)
+        // for a not-supported condition, the selectivity of the negated condition is also
+        // unknown, so return None
+        case None => None
+      }
+
+      case _ =>
+        calculateSingleCondition(condition, update)
+    }
+  }
+
+  /**
+   * Returns a percentage of rows meeting a single condition in the Filter node.
+   * Currently we only support binary predicates where one side is a column,
+   * and the other is a literal.
+   *
+   * @param condition a single logical expression
+   * @param update a boolean flag to specify if we need to update ColumnStat of a column
+   *               for subsequent conditions
+   * @return Option[Double] value to show the percentage of rows meeting a given condition.
+   *         It returns None if the condition is not supported.
+   */
+  def calculateSingleCondition(condition: Expression, update: Boolean): Option[Double] = {
+    condition match {
+      // For the evaluateBinary method, we assume the literal is on the right side of the
+      // operator, so we swap the operands when it is not.
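+      // For example, a predicate (40 < c) is evaluated as (c > 40), i.e. as
+      // GreaterThan(ar, l), before being passed to evaluateBinary.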
+
+      // EqualTo does not care about the order
+      case op @ EqualTo(ar: AttributeReference, l: Literal) =>
+        evaluateBinary(op, ar, l, update)
+      case op @ EqualTo(l: Literal, ar: AttributeReference) =>
+        evaluateBinary(op, ar, l, update)
+
+      case op @ LessThan(ar: AttributeReference, l: Literal) =>
+        evaluateBinary(op, ar, l, update)
+      case op @ LessThan(l: Literal, ar: AttributeReference) =>
+        evaluateBinary(GreaterThan(ar, l), ar, l, update)
+
+      case op @ LessThanOrEqual(ar: AttributeReference, l: Literal) =>
+        evaluateBinary(op, ar, l, update)
+      case op @ LessThanOrEqual(l: Literal, ar: AttributeReference) =>
+        evaluateBinary(GreaterThanOrEqual(ar, l), ar, l, update)
+
+      case op @ GreaterThan(ar: AttributeReference, l: Literal) =>
+        evaluateBinary(op, ar, l, update)
+      case op @ GreaterThan(l: Literal, ar: AttributeReference) =>
+        evaluateBinary(LessThan(ar, l), ar, l, update)
+
+      case op @ GreaterThanOrEqual(ar: AttributeReference, l: Literal) =>
+        evaluateBinary(op, ar, l, update)
+      case op @ GreaterThanOrEqual(l: Literal, ar: AttributeReference) =>
+        evaluateBinary(LessThanOrEqual(ar, l), ar, l, update)
+
+      case In(ar: AttributeReference, expList)
+        if expList.forall(e => e.isInstanceOf[Literal]) =>
+        // Expression [In (value, seq[Literal])] will be replaced with the optimized version
+        // [InSet (value, HashSet[Literal])] in the Optimizer, but only for list.size > 10.
+        // Here we convert In into InSet anyway, because they share the same processing logic.
+        val hSet = expList.map(e => e.eval())
+        evaluateInSet(ar, HashSet() ++ hSet, update)
+
+      case InSet(ar: AttributeReference, set) =>
+        evaluateInSet(ar, set, update)
+
+      case IsNull(ar: AttributeReference) =>
+        evaluateIsNull(ar, isNull = true, update)
+
+      case IsNotNull(ar: AttributeReference) =>
+        evaluateIsNull(ar, isNull = false, update)
+
+      case _ =>
+        // TODO: it's difficult to support string operators without advanced statistics.
+        // Hence, the string operators Like(_, _) | Contains(_, _) | StartsWith(_, _)
+        // | EndsWith(_, _) are not supported yet.
+        logDebug("[CBO] Unsupported filter condition: " + condition)
+        None
+    }
+  }
+
+  /**
+   * Returns a percentage of rows meeting an "IS NULL" or "IS NOT NULL" condition.
+   *
+   * @param attrRef an AttributeReference (or a column)
+   * @param isNull set to true for an "IS NULL" condition; set to false for an
+   *               "IS NOT NULL" condition
+   * @param update a boolean flag to specify if we need to update ColumnStat of a given column
+   *               for subsequent conditions
+   * @return an optional double value to show the percentage of rows meeting a given condition.
+   *         It returns None if no statistics are collected for a given column.
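+   *         For example (illustrative numbers), with rowCount = 10 and nullCount = 4,
+   *         "IS NULL" yields selectivity 4/10 = 0.4 and "IS NOT NULL" yields 0.6.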
+   */
+  def evaluateIsNull(
+      attrRef: AttributeReference,
+      isNull: Boolean,
+      update: Boolean)
+    : Option[Double] = {
+    if (!mutableColStats.contains(attrRef.exprId)) {
+      logDebug("[CBO] No statistics for " + attrRef)
+      return None
+    }
+    val aColStat = mutableColStats(attrRef.exprId)
+    val rowCountValue = plan.child.stats(catalystConf).rowCount.get
+    val nullPercent: BigDecimal =
+      if (rowCountValue == 0) 0.0
+      else BigDecimal(aColStat.nullCount) / BigDecimal(rowCountValue)
+
+    if (update) {
+      val newStats =
+        if (isNull) aColStat.copy(distinctCount = 0, min = None, max = None)
+        else aColStat.copy(nullCount = 0)
+
+      mutableColStats += (attrRef.exprId -> newStats)
+    }
+
+    val percent =
+      if (isNull) {
+        nullPercent.toDouble
+      } else {
+        // ISNOTNULL(column)
+        1.0 - nullPercent.toDouble
+      }
+
+    Some(percent)
+  }
+
+  /**
+   * Returns a percentage of rows meeting a binary comparison expression.
+   *
+   * @param op a binary comparison operator such as =, <, <=, >, >=
+   * @param attrRef an AttributeReference (or a column)
+   * @param literal a literal value (or constant)
+   * @param update a boolean flag to specify if we need to update ColumnStat of a given column
+   *               for subsequent conditions
+   * @return an optional double value to show the percentage of rows meeting a given condition.
+   *         It returns None if no statistics exist for a given column, or if the literal
+   *         carries a wrong value.
+   */
+  def evaluateBinary(
+      op: BinaryComparison,
+      attrRef: AttributeReference,
+      literal: Literal,
+      update: Boolean)
+    : Option[Double] = {
+    if (!mutableColStats.contains(attrRef.exprId)) {
+      logDebug("[CBO] No statistics for " + attrRef)
+      return None
+    }
+
+    op match {
+      case EqualTo(l, r) => evaluateEqualTo(attrRef, literal, update)
+      case _ =>
+        attrRef.dataType match {
+          case _: NumericType | DateType | TimestampType =>
+            evaluateBinaryForNumeric(op, attrRef, literal, update)
+          case StringType | BinaryType =>
+            // TODO: It is difficult to support other binary comparisons for String/Binary
+            // type without min/max and advanced statistics like histogram.
+            logDebug("[CBO] No range comparison statistics for String/Binary type " + attrRef)
+            None
+        }
+    }
+  }
+
+  /**
+   * For a SQL data type, its internal data type may be different from its external type.
+   * For DateType, its internal type is Int, and its external data type is Java Date type.
+   * The min/max values in ColumnStat are saved in their corresponding external type.
+   *
+   * @param attrDataType the column data type
+   * @param litValue the literal value
+   * @return an Option[Any] holding the converted external-type value, or None for
+   *         String/Binary types
+   */
+  def convertBoundValue(attrDataType: DataType, litValue: Any): Option[Any] = {
+    attrDataType match {
+      case DateType =>
+        Some(DateTimeUtils.toJavaDate(litValue.toString.toInt))
+      case TimestampType =>
+        Some(DateTimeUtils.toJavaTimestamp(litValue.toString.toLong))
+      case StringType | BinaryType =>
+        None
+      case _ =>
+        Some(litValue)
+    }
+  }
+
+  /**
+   * Returns a percentage of rows meeting an equality (=) expression.
+   * This method evaluates the equality predicate for all data types.
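+   * Under the uniform distribution assumption, an in-range equality predicate has
+   * selectivity 1/ndv; e.g. (illustrative numbers) with ndv = 10 it returns 0.1.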
+   *
+   * @param attrRef an AttributeReference (or a column)
+   * @param literal a literal value (or constant)
+   * @param update a boolean flag to specify if we need to update ColumnStat of a given column
+   *               for subsequent conditions
+   * @return an optional double value to show the percentage of rows meeting a given condition
+   */
+  def evaluateEqualTo(
+      attrRef: AttributeReference,
+      literal: Literal,
+      update: Boolean)
+    : Option[Double] = {
+
+    val aColStat = mutableColStats(attrRef.exprId)
+    val ndv = aColStat.distinctCount
+
+    // decide if the value is in [min, max] of the column.
+    // We currently don't store min/max for binary/string type.
+    // Hence, we assume it is in boundary for binary/string type.
+    val statsRange = Range(aColStat.min, aColStat.max, attrRef.dataType)
+    val inBoundary: Boolean = Range.rangeContainsLiteral(statsRange, literal)
+
+    if (inBoundary) {
+      if (update) {
+        // We update the ColumnStat structure after applying this equality predicate.
+        // Set distinctCount to 1. Set nullCount to 0.
+        // Need to save the new min/max using the external type value of the literal
+        val newValue = convertBoundValue(attrRef.dataType, literal.value)
+        val newStats = aColStat.copy(distinctCount = 1, min = newValue,
+          max = newValue, nullCount = 0)
+        mutableColStats += (attrRef.exprId -> newStats)
+      }
+
+      Some(1.0 / ndv.toDouble)
+    } else {
+      Some(0.0)
+    }
+  }
+
+  /**
+   * Returns a percentage of rows meeting an "IN" operator expression.
+   * This method evaluates the equality predicate for all data types.
+   *
+   * @param attrRef an AttributeReference (or a column)
+   * @param hSet a set of literal values
+   * @param update a boolean flag to specify if we need to update ColumnStat of a given column
+   *               for subsequent conditions
+   * @return an optional double value to show the percentage of rows meeting a given condition.
+   *         It returns None if no statistics exist for a given column.
+   */
+  def evaluateInSet(
+      attrRef: AttributeReference,
+      hSet: Set[Any],
+      update: Boolean)
+    : Option[Double] = {
+    if (!mutableColStats.contains(attrRef.exprId)) {
+      logDebug("[CBO] No statistics for " + attrRef)
+      return None
+    }
+
+    val aColStat = mutableColStats(attrRef.exprId)
+    val ndv = aColStat.distinctCount
+    val aType = attrRef.dataType
+    var newNdv: Long = 0
+
+    // use [min, max] to filter the original hSet
+    aType match {
+      case _: NumericType | DateType | TimestampType =>
+        val statsRange =
+          Range(aColStat.min, aColStat.max, aType).asInstanceOf[NumericRange]
+
+        // To facilitate finding the min and max values in hSet, we map hSet values to BigDecimal.
+        // Using hSetBigdec, we can find the min and max values quickly in the ordered hSetBigdec.
+        val hSetBigdec = hSet.map(e => BigDecimal(e.toString))
+        val validQuerySet = hSetBigdec.filter(e => e >= statsRange.min && e <= statsRange.max)
+        // We use hSetBigdecToAnyMap to help us find the original hSet value.
+        val hSetBigdecToAnyMap: Map[BigDecimal, Any] =
+          hSet.map(e => BigDecimal(e.toString) -> e).toMap
+
+        if (validQuerySet.isEmpty) {
+          return Some(0.0)
+        }
+
+        // Need to save the new min/max using the external type value of the literal
+        val newMax = convertBoundValue(attrRef.dataType, hSetBigdecToAnyMap(validQuerySet.max))
+        val newMin = convertBoundValue(attrRef.dataType, hSetBigdecToAnyMap(validQuerySet.min))
+
+        // newNdv should not be greater than the old ndv. For example, a column has only
+        // 2 values, 1 and 6. The predicate is column IN (1, 2, 3, 4, 5); validQuerySet.size is 5.
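+        // In that case, newNdv = min(validQuerySet.size, ndv) = min(5, 2) = 2.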
+        newNdv = math.min(validQuerySet.size.toLong, ndv.longValue())
+        if (update) {
+          val newStats = aColStat.copy(distinctCount = newNdv, min = newMin,
+            max = newMax, nullCount = 0)
+          mutableColStats += (attrRef.exprId -> newStats)
+        }
+
+      // We assume the whole set is within bounds since there is no min/max information
+      // for String/Binary type
+      case StringType | BinaryType =>
+        newNdv = math.min(hSet.size.toLong, ndv.longValue())
+        if (update) {
+          val newStats = aColStat.copy(distinctCount = newNdv, nullCount = 0)
+          mutableColStats += (attrRef.exprId -> newStats)
+        }
+    }
+
+    // return the filter selectivity. Without advanced statistics such as histograms,
+    // we have to assume a uniform distribution.
+    Some(math.min(1.0, newNdv.toDouble / ndv.toDouble))
+  }
+
+  /**
+   * Returns a percentage of rows meeting a binary comparison expression.
+   * This method evaluates expressions for numeric columns only.
+   *
+   * @param op a binary comparison operator such as =, <, <=, >, >=
+   * @param attrRef an AttributeReference (or a column)
+   * @param literal a literal value (or constant)
+   * @param update a boolean flag to specify if we need to update ColumnStat of a given column
+   *               for subsequent conditions
+   * @return an optional double value to show the percentage of rows meeting a given condition
+   */
+  def evaluateBinaryForNumeric(
+      op: BinaryComparison,
+      attrRef: AttributeReference,
+      literal: Literal,
+      update: Boolean)
+    : Option[Double] = {
+
+    var percent = 1.0
+    val aColStat = mutableColStats(attrRef.exprId)
+    val ndv = aColStat.distinctCount
+    val statsRange =
+      Range(aColStat.min, aColStat.max, attrRef.dataType).asInstanceOf[NumericRange]
+
+    // determine the overlapping degree between the predicate range and the column's range
+    val literalValueBD = BigDecimal(literal.value.toString)
+    val (noOverlap: Boolean, completeOverlap: Boolean) = op match {
+      case _: LessThan =>
+        (literalValueBD <= statsRange.min, literalValueBD > statsRange.max)
+      case _: LessThanOrEqual =>
+        (literalValueBD < statsRange.min, literalValueBD >= statsRange.max)
+      case _: GreaterThan =>
+        (literalValueBD >= statsRange.max, literalValueBD < statsRange.min)
+      case _: GreaterThanOrEqual =>
+        (literalValueBD > statsRange.max, literalValueBD <= statsRange.min)
+    }
+
+    if (noOverlap) {
+      percent = 0.0
+    } else if (completeOverlap) {
+      percent = 1.0
+    } else {
+      // this is the partial overlap case
+      var newMax = aColStat.max
+      var newMin = aColStat.min
+      var newNdv = ndv
+      val literalToDouble = literalValueBD.toDouble
+      val maxToDouble = BigDecimal(statsRange.max).toDouble
+      val minToDouble = BigDecimal(statsRange.min).toDouble
+
+      // Without advanced statistics like histogram, we assume a uniform data distribution.
+      // We just prorate the adjusted range over the initial range to compute filter selectivity.
+      // For ease of computation, we convert all relevant numeric values to Double.
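+      // For example (illustrative numbers), for (c < 7) on a column with
+      // [min, max] = [1, 10]: percent = (7 - 1) / (10 - 1) ~= 0.67.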
+ percent = op match { + case _: LessThan => + (literalToDouble - minToDouble) / (maxToDouble - minToDouble) + case _: LessThanOrEqual => + if (literalValueBD == BigDecimal(statsRange.min)) 1.0 / ndv.toDouble + else (literalToDouble - minToDouble) / (maxToDouble - minToDouble) + case _: GreaterThan => + (maxToDouble - literalToDouble) / (maxToDouble - minToDouble) + case _: GreaterThanOrEqual => + if (literalValueBD == BigDecimal(statsRange.max)) 1.0 / ndv.toDouble + else (maxToDouble - literalToDouble) / (maxToDouble - minToDouble) + } + + // Need to save new min/max using the external type value of the literal + val newValue = convertBoundValue(attrRef.dataType, literal.value) + + if (update) { + op match { + case _: GreaterThan => newMin = newValue + case _: GreaterThanOrEqual => newMin = newValue + case _: LessThan => newMax = newValue + case _: LessThanOrEqual => newMax = newValue + } + + newNdv = math.max(math.round(ndv.toDouble * percent), 1) + val newStats = aColStat.copy(distinctCount = newNdv, min = newMin, + max = newMax, nullCount = 0) + + mutableColStats += (attrRef.exprId -> newStats) + } + } + + Some(percent) + } + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala index 5aa6b9353b..4557114532 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation import java.math.{BigDecimal => JDecimal} import java.sql.{Date, Timestamp} +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types.{BooleanType, DateType, TimestampType, _} @@ -57,6 +58,20 @@ object Range { n1.min.compareTo(n2.max) <= 0 && n1.max.compareTo(n2.min) >= 0 } + def rangeContainsLiteral(r: Range, lit: Literal): Boolean = r match { + case _: DefaultRange => true + case _: NullRange => false + case n: NumericRange => + val literalValue = if (lit.dataType.isInstanceOf[BooleanType]) { + if (lit.value.asInstanceOf[Boolean]) new JDecimal(1) else new JDecimal(0) + } else { + assert(lit.dataType.isInstanceOf[NumericType] || lit.dataType.isInstanceOf[DateType] || + lit.dataType.isInstanceOf[TimestampType]) + new JDecimal(lit.value.toString) + } + n.min.compareTo(literalValue) <= 0 && n.max.compareTo(literalValue) >= 0 + } + /** * Intersected results of two ranges. This is only for two overlapped ranges. * The outputs are the intersected min/max values. @@ -113,4 +128,5 @@ object Range { DateTimeUtils.toJavaTimestamp(n.max.longValue())) } } + } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala new file mode 100644 index 0000000000..f139c9e28c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -0,0 +1,403 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.statsEstimation
+
+import java.sql.{Date, Timestamp}
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.types._
+
+/**
+ * In this test suite, we test predicates containing the following operators:
+ * =, <, <=, >, >=, AND, OR, IS NULL, IS NOT NULL, IN, NOT IN
+ */
+class FilterEstimationSuite extends StatsEstimationTestBase {
+
+  // Suppose our test table has 10 rows and 6 columns.
+  // First column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
+  // Hence, distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4
+  val arInt = AttributeReference("cint", IntegerType)()
+  val childColStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+    nullCount = 0, avgLen = 4, maxLen = 4)
+
+  // Second column cdate has 10 values from 2017-01-01 through 2017-01-10.
+  val dMin = Date.valueOf("2017-01-01")
+  val dMax = Date.valueOf("2017-01-10")
+  val arDate = AttributeReference("cdate", DateType)()
+  val childColStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax),
+    nullCount = 0, avgLen = 4, maxLen = 4)
+
+  // Third column ctimestamp has 10 values from "2017-01-01 01:00:00" through
+  // "2017-01-01 10:00:00" for 10 distinct timestamps (or hours).
+  val tsMin = Timestamp.valueOf("2017-01-01 01:00:00")
+  val tsMax = Timestamp.valueOf("2017-01-01 10:00:00")
+  val arTimestamp = AttributeReference("ctimestamp", TimestampType)()
+  val childColStatTimestamp = ColumnStat(distinctCount = 10, min = Some(tsMin), max = Some(tsMax),
+    nullCount = 0, avgLen = 8, maxLen = 8)
+
+  // Fourth column cdecimal has 4 values from 0.20 through 0.80 at an increment of 0.20.
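+  // i.e. 0.20, 0.40, 0.60 and 0.80, stored as Decimal(18, 18).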
+ val decMin = new java.math.BigDecimal("0.200000000000000000") + val decMax = new java.math.BigDecimal("0.800000000000000000") + val arDecimal = AttributeReference("cdecimal", DecimalType(18, 18))() + val childColStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax), + nullCount = 0, avgLen = 8, maxLen = 8) + + // Fifth column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 + val arDouble = AttributeReference("cdouble", DoubleType)() + val childColStatDouble = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0), + nullCount = 0, avgLen = 8, maxLen = 8) + + // Sixth column cstring has 10 String values: + // "A0", "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9" + val arString = AttributeReference("cstring", StringType)() + val childColStatString = ColumnStat(distinctCount = 10, min = None, max = None, + nullCount = 0, avgLen = 2, maxLen = 2) + + test("cint = 2") { + validateEstimatedStats( + arInt, + Filter(EqualTo(arInt, Literal(2)), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(1L) + ) + } + + test("cint = 0") { + // This is an out-of-range case since 0 is outside the range [min, max] + validateEstimatedStats( + arInt, + Filter(EqualTo(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(0L) + ) + } + + test("cint < 3") { + validateEstimatedStats( + arInt, + Filter(LessThan(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(3L) + ) + } + + test("cint < 0") { + // This is a corner case since literal 0 is smaller than min. + validateEstimatedStats( + arInt, + Filter(LessThan(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(0L) + ) + } + + test("cint <= 3") { + validateEstimatedStats( + arInt, + Filter(LessThanOrEqual(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(3L) + ) + } + + test("cint > 6") { + validateEstimatedStats( + arInt, + Filter(GreaterThan(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(5L) + ) + } + + test("cint > 10") { + // This is a corner case since max value is 10. 
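+    // GreaterThan(10) does not overlap the range [1, 10] at all, so the estimated
+    // selectivity is 0 and the column stats are left unchanged.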
+ validateEstimatedStats( + arInt, + Filter(GreaterThan(arInt, Literal(10)), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(0L) + ) + } + + test("cint >= 6") { + validateEstimatedStats( + arInt, + Filter(GreaterThanOrEqual(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(5L) + ) + } + + test("cint IS NULL") { + validateEstimatedStats( + arInt, + Filter(IsNull(arInt), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 0, min = None, max = None, + nullCount = 0, avgLen = 4, maxLen = 4), + Some(0L) + ) + } + + test("cint IS NOT NULL") { + validateEstimatedStats( + arInt, + Filter(IsNotNull(arInt), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(10L) + ) + } + + test("cint > 3 AND cint <= 6") { + val condition = And(GreaterThan(arInt, Literal(3)), LessThanOrEqual(arInt, Literal(6))) + validateEstimatedStats( + arInt, + Filter(condition, childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 3, min = Some(3), max = Some(6), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(4L) + ) + } + + test("cint = 3 OR cint = 6") { + val condition = Or(EqualTo(arInt, Literal(3)), EqualTo(arInt, Literal(6))) + validateEstimatedStats( + arInt, + Filter(condition, childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(2L) + ) + } + + test("cint IN (3, 4, 5)") { + validateEstimatedStats( + arInt, + Filter(InSet(arInt, Set(3, 4, 5)), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 3, min = Some(3), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(3L) + ) + } + + test("cint NOT IN (3, 4, 5)") { + validateEstimatedStats( + arInt, + Filter(Not(InSet(arInt, Set(3, 4, 5))), childStatsTestPlan(Seq(arInt), 10L)), + ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(7L) + ) + } + + test("cdate = cast('2017-01-02' AS DATE)") { + val d20170102 = Date.valueOf("2017-01-02") + validateEstimatedStats( + arDate, + Filter(EqualTo(arDate, Literal(d20170102)), + childStatsTestPlan(Seq(arDate), 10L)), + ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(1L) + ) + } + + test("cdate < cast('2017-01-03' AS DATE)") { + val d20170103 = Date.valueOf("2017-01-03") + validateEstimatedStats( + arDate, + Filter(LessThan(arDate, Literal(d20170103)), + childStatsTestPlan(Seq(arDate), 10L)), + ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(3L) + ) + } + + test("""cdate IN ( cast('2017-01-03' AS DATE), + cast('2017-01-04' AS DATE), cast('2017-01-05' AS DATE) )""") { + val d20170103 = Date.valueOf("2017-01-03") + val d20170104 = Date.valueOf("2017-01-04") + val d20170105 = Date.valueOf("2017-01-05") + validateEstimatedStats( + arDate, + Filter(In(arDate, Seq(Literal(d20170103), Literal(d20170104), Literal(d20170105))), + childStatsTestPlan(Seq(arDate), 10L)), + ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(3L) + ) + } + + test("ctimestamp = cast('2017-01-01 02:00:00' AS TIMESTAMP)") { + 
   val ts2017010102 = Timestamp.valueOf("2017-01-01 02:00:00")
+    validateEstimatedStats(
+      arTimestamp,
+      Filter(EqualTo(arTimestamp, Literal(ts2017010102)),
+        childStatsTestPlan(Seq(arTimestamp), 10L)),
+      ColumnStat(distinctCount = 1, min = Some(ts2017010102), max = Some(ts2017010102),
+        nullCount = 0, avgLen = 8, maxLen = 8),
+      Some(1L)
+    )
+  }
+
+  test("ctimestamp < cast('2017-01-01 03:00:00' AS TIMESTAMP)") {
+    val ts2017010103 = Timestamp.valueOf("2017-01-01 03:00:00")
+    validateEstimatedStats(
+      arTimestamp,
+      Filter(LessThan(arTimestamp, Literal(ts2017010103)),
+        childStatsTestPlan(Seq(arTimestamp), 10L)),
+      ColumnStat(distinctCount = 2, min = Some(tsMin), max = Some(ts2017010103),
+        nullCount = 0, avgLen = 8, maxLen = 8),
+      Some(3L)
+    )
+  }
+
+  test("cdecimal = 0.400000000000000000") {
+    val dec_0_40 = new java.math.BigDecimal("0.400000000000000000")
+    validateEstimatedStats(
+      arDecimal,
+      Filter(EqualTo(arDecimal, Literal(dec_0_40)),
+        childStatsTestPlan(Seq(arDecimal), 4L)),
+      ColumnStat(distinctCount = 1, min = Some(dec_0_40), max = Some(dec_0_40),
+        nullCount = 0, avgLen = 8, maxLen = 8),
+      Some(1L)
+    )
+  }
+
+  test("cdecimal < 0.60") {
+    val dec_0_60 = new java.math.BigDecimal("0.600000000000000000")
+    validateEstimatedStats(
+      arDecimal,
+      Filter(LessThan(arDecimal, Literal(dec_0_60, DecimalType(12, 2))),
+        childStatsTestPlan(Seq(arDecimal), 4L)),
+      ColumnStat(distinctCount = 3, min = Some(decMin), max = Some(dec_0_60),
+        nullCount = 0, avgLen = 8, maxLen = 8),
+      Some(3L)
+    )
+  }
+
+  test("cdouble < 3.0") {
+    validateEstimatedStats(
+      arDouble,
+      Filter(LessThan(arDouble, Literal(3.0)), childStatsTestPlan(Seq(arDouble), 10L)),
+      ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(3.0),
+        nullCount = 0, avgLen = 8, maxLen = 8),
+      Some(3L)
+    )
+  }
+
+  test("cstring = 'A2'") {
+    validateEstimatedStats(
+      arString,
+      Filter(EqualTo(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)),
+      ColumnStat(distinctCount = 1, min = None, max = None,
+        nullCount = 0, avgLen = 2, maxLen = 2),
+      Some(1L)
+    )
+  }
+
+  // There are no min/max statistics for String type. We estimate 10 rows are returned.
+  test("cstring < 'A2'") {
+    validateEstimatedStats(
+      arString,
+      Filter(LessThan(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)),
+      ColumnStat(distinctCount = 10, min = None, max = None,
+        nullCount = 0, avgLen = 2, maxLen = 2),
+      Some(10L)
+    )
+  }
+
+  // This is a corner test case. We want to test whether we can handle the case where the
+  // number of valid values in the IN clause is greater than the number of distinct values
+  // for a given column.
+  // For example, the column has only 2 distinct values, 1 and 6.
+  // The predicate is: column IN (1, 2, 3, 4, 5).
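+  // With [min, max] = [1, 6], validQuerySet = {1, 2, 3, 4, 5}, so the estimated
+  // newNdv = min(5, 2) = 2 and the selectivity is capped at 1.0.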
+  test("cint IN (1, 2, 3, 4, 5)") {
+    val cornerChildColStatInt = ColumnStat(distinctCount = 2, min = Some(1), max = Some(6),
+      nullCount = 0, avgLen = 4, maxLen = 4)
+    val cornerChildStatsTestplan = StatsTestPlan(
+      outputList = Seq(arInt),
+      rowCount = 2L,
+      attributeStats = AttributeMap(Seq(arInt -> cornerChildColStatInt))
+    )
+    validateEstimatedStats(
+      arInt,
+      Filter(InSet(arInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan),
+      ColumnStat(distinctCount = 2, min = Some(1), max = Some(5),
+        nullCount = 0, avgLen = 4, maxLen = 4),
+      Some(2L)
+    )
+  }
+
+  private def childStatsTestPlan(outList: Seq[Attribute], tableRowCount: BigInt): StatsTestPlan = {
+    StatsTestPlan(
+      outputList = outList,
+      rowCount = tableRowCount,
+      attributeStats = AttributeMap(Seq(
+        arInt -> childColStatInt,
+        arDate -> childColStatDate,
+        arTimestamp -> childColStatTimestamp,
+        arDecimal -> childColStatDecimal,
+        arDouble -> childColStatDouble,
+        arString -> childColStatString
+      ))
+    )
+  }
+
+  private def validateEstimatedStats(
+      ar: AttributeReference,
+      filterNode: Filter,
+      expectedColStats: ColumnStat,
+      rowCount: Option[BigInt] = None)
+    : Unit = {
+
+    val expectedRowCount: BigInt = rowCount.getOrElse(0L)
+    val expectedAttrStats = toAttributeMap(Seq(ar.name -> expectedColStats), filterNode)
+    val expectedSizeInBytes = getOutputSize(filterNode.output, expectedRowCount, expectedAttrStats)
+
+    val filteredStats = filterNode.stats(conf)
+    assert(filteredStats.sizeInBytes == expectedSizeInBytes)
+    assert(filteredStats.rowCount == rowCount)
+    ar.dataType match {
+      case DecimalType() =>
+        // Due to the internal transformation for DecimalType within the engine, the new min/max
+        // in ColumnStat may have a different structure even if it contains the right values.
+        // We convert them to Java BigDecimal values so that we can compare the entire object.
+        val generatedColumnStats = filteredStats.attributeStats(ar)
+        val newMax = new java.math.BigDecimal(generatedColumnStats.max.getOrElse(0).toString)
+        val newMin = new java.math.BigDecimal(generatedColumnStats.min.getOrElse(0).toString)
+        val outputColStats = generatedColumnStats.copy(min = Some(newMin), max = Some(newMax))
+        assert(outputColStats == expectedColStats)
+      case _ =>
+        // For all other SQL types, we compare the entire object directly.
+        assert(filteredStats.attributeStats(ar) == expectedColStats)
+    }
+  }
+
+}