From 932196d9e30453e0827ee3cd8a81cb306b7a24d9 Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Mon, 6 Mar 2017 23:53:53 -0800 Subject: [SPARK-17075][SQL][FOLLOWUP] fix filter estimation issues ## What changes were proposed in this pull request? 1. support boolean type in binary expression estimation. 2. deal with compound Not conditions. 3. avoid converting BigInt/BigDecimal directly to double unless it's within range (0, 1). 4. reorganize test code. ## How was this patch tested? modify related test cases. Author: wangzhenhua Author: Zhenhua Wang Closes #17148 from wzhfy/fixFilter. --- .../logical/statsEstimation/FilterEstimation.scala | 174 +++++---- .../statsEstimation/FilterEstimationSuite.scala | 397 +++++++++++---------- 2 files changed, 297 insertions(+), 274 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 0c928832d7..b10785b05d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -19,11 +19,12 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation import scala.collection.immutable.HashSet import scala.collection.mutable +import scala.math.BigDecimal.RoundingMode import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Statistics} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -52,17 +53,19 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo def estimate:
Option[Statistics] = { if (childStats.rowCount.isEmpty) return None - // save a mutable copy of colStats so that we can later change it recursively + // Save a mutable copy of colStats so that we can later change it recursively. colStatsMap.setInitValues(childStats.attributeStats) - // estimate selectivity of this filter predicate - val filterSelectivity: Double = calculateFilterSelectivity(plan.condition) match { - case Some(percent) => percent - // for not-supported condition, set filter selectivity to a conservative estimate 100% - case None => 1.0 - } + // Estimate selectivity of this filter predicate, and update column stats if needed. + // For not-supported condition, set filter selectivity to a conservative estimate 100% + val filterSelectivity: Double = calculateFilterSelectivity(plan.condition).getOrElse(1.0) - val newColStats = colStatsMap.toColumnStats + val newColStats = if (filterSelectivity == 0) { + // The output is empty, we don't need to keep column stats. + AttributeMap[ColumnStat](Nil) + } else { + colStatsMap.toColumnStats + } val filteredRowCount: BigInt = EstimationUtils.ceil(BigDecimal(childStats.rowCount.get) * filterSelectivity) @@ -74,12 +77,14 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo } /** - * Returns a percentage of rows meeting a compound condition in Filter node. - * A compound condition is decomposed into multiple single conditions linked with AND, OR, NOT. + * Returns a percentage of rows meeting a condition in Filter node. + * If it's a single condition, we calculate the percentage directly. + * If it's a compound condition, it is decomposed into multiple single conditions linked with + * AND, OR, NOT. * For logical AND conditions, we need to update stats after a condition estimation * so that the stats will be more accurate for subsequent estimation. 
This is needed for * range condition such as (c > 40 AND c <= 50) - * For logical OR conditions, we do not update stats after a condition estimation. + * For logical OR and NOT conditions, we do not update stats after a condition estimation. * * @param condition the compound logical expression * @param update a boolean flag to specify if we need to update ColumnStat of a column @@ -90,34 +95,29 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo def calculateFilterSelectivity(condition: Expression, update: Boolean = true): Option[Double] = { condition match { case And(cond1, cond2) => - // For ease of debugging, we compute percent1 and percent2 in 2 statements. - val percent1 = calculateFilterSelectivity(cond1, update) - val percent2 = calculateFilterSelectivity(cond2, update) - (percent1, percent2) match { - case (Some(p1), Some(p2)) => Some(p1 * p2) - case (Some(p1), None) => Some(p1) - case (None, Some(p2)) => Some(p2) - case (None, None) => None - } + val percent1 = calculateFilterSelectivity(cond1, update).getOrElse(1.0) + val percent2 = calculateFilterSelectivity(cond2, update).getOrElse(1.0) + Some(percent1 * percent2) case Or(cond1, cond2) => - // For ease of debugging, we compute percent1 and percent2 in 2 statements. 
- val percent1 = calculateFilterSelectivity(cond1, update = false) - val percent2 = calculateFilterSelectivity(cond2, update = false) - (percent1, percent2) match { - case (Some(p1), Some(p2)) => Some(math.min(1.0, p1 + p2 - (p1 * p2))) - case (Some(p1), None) => Some(1.0) - case (None, Some(p2)) => Some(1.0) - case (None, None) => None - } + val percent1 = calculateFilterSelectivity(cond1, update = false).getOrElse(1.0) + val percent2 = calculateFilterSelectivity(cond2, update = false).getOrElse(1.0) + Some(percent1 + percent2 - (percent1 * percent2)) - case Not(cond) => calculateFilterSelectivity(cond, update = false) match { - case Some(percent) => Some(1.0 - percent) - // for not-supported condition, set filter selectivity to a conservative estimate 100% - case None => None - } + case Not(And(cond1, cond2)) => + calculateFilterSelectivity(Or(Not(cond1), Not(cond2)), update = false) + + case Not(Or(cond1, cond2)) => + calculateFilterSelectivity(And(Not(cond1), Not(cond2)), update = false) - case _ => calculateSingleCondition(condition, update) + case Not(cond) => + calculateFilterSelectivity(cond, update = false) match { + case Some(percent) => Some(1.0 - percent) + case None => None + } + + case _ => + calculateSingleCondition(condition, update) } } @@ -225,12 +225,12 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo } val percent = if (isNull) { - nullPercent.toDouble + nullPercent } else { - 1.0 - nullPercent.toDouble + 1.0 - nullPercent } - Some(percent) + Some(percent.toDouble) } /** @@ -249,17 +249,19 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo attr: Attribute, literal: Literal, update: Boolean): Option[Double] = { + if (!colStatsMap.contains(attr)) { + logDebug("[CBO] No statistics for " + attr) + return None + } + attr.dataType match { - case _: NumericType | DateType | TimestampType => + case _: NumericType | DateType | TimestampType | BooleanType => evaluateBinaryForNumeric(op, 
attr, literal, update) case StringType | BinaryType => // TODO: It is difficult to support other binary comparisons for String/Binary // type without min/max and advanced statistics like histogram. logDebug("[CBO] No range comparison statistics for String/Binary type " + attr) None - case _ => - // TODO: support boolean type. - None } } @@ -291,6 +293,10 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * Returns a percentage of rows meeting an equality (=) expression. * This method evaluates the equality predicate for all data types. * + * For EqualNullSafe (<=>), if the literal is not null, result will be the same as EqualTo; + * if the literal is null, the condition will be changed to IsNull after optimization. + * So we don't need specific logic for EqualNullSafe here. + * * @param attr an Attribute (or a column) * @param literal a literal value (or constant) * @param update a boolean flag to specify if we need to update ColumnStat of a given column @@ -323,7 +329,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo colStatsMap(attr) = newStats } - Some(1.0 / ndv.toDouble) + Some((1.0 / BigDecimal(ndv)).toDouble) } else { Some(0.0) } @@ -394,12 +400,12 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo // return the filter selectivity. Without advanced statistics such as histograms, // we have to assume uniform distribution. - Some(math.min(1.0, newNdv.toDouble / ndv.toDouble)) + Some(math.min(1.0, (BigDecimal(newNdv) / BigDecimal(ndv)).toDouble)) } /** * Returns a percentage of rows meeting a binary comparison expression. - * This method evaluate expression for Numeric columns only. + * This method evaluate expression for Numeric/Date/Timestamp/Boolean columns. 
* * @param op a binary comparison operator uch as =, <, <=, >, >= * @param attr an Attribute (or a column) @@ -414,53 +420,66 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo literal: Literal, update: Boolean): Option[Double] = { - var percent = 1.0 val colStat = colStatsMap(attr) - val statsRange = - Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange] + val statsRange = Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange] + val max = BigDecimal(statsRange.max) + val min = BigDecimal(statsRange.min) + val ndv = BigDecimal(colStat.distinctCount) // determine the overlapping degree between predicate range and column's range - val literalValueBD = BigDecimal(literal.value.toString) + val numericLiteral = if (literal.dataType == BooleanType) { + if (literal.value.asInstanceOf[Boolean]) BigDecimal(1) else BigDecimal(0) + } else { + BigDecimal(literal.value.toString) + } val (noOverlap: Boolean, completeOverlap: Boolean) = op match { case _: LessThan => - (literalValueBD <= statsRange.min, literalValueBD > statsRange.max) + (numericLiteral <= min, numericLiteral > max) case _: LessThanOrEqual => - (literalValueBD < statsRange.min, literalValueBD >= statsRange.max) + (numericLiteral < min, numericLiteral >= max) case _: GreaterThan => - (literalValueBD >= statsRange.max, literalValueBD < statsRange.min) + (numericLiteral >= max, numericLiteral < min) case _: GreaterThanOrEqual => - (literalValueBD > statsRange.max, literalValueBD <= statsRange.min) + (numericLiteral > max, numericLiteral <= min) } + var percent = BigDecimal(1.0) if (noOverlap) { percent = 0.0 } else if (completeOverlap) { percent = 1.0 } else { - // this is partial overlap case - val literalDouble = literalValueBD.toDouble - val maxDouble = BigDecimal(statsRange.max).toDouble - val minDouble = BigDecimal(statsRange.min).toDouble - + // This is the partial overlap case: // Without advanced statistics like histogram, we assume 
uniform data distribution. // We just prorate the adjusted range over the initial range to compute filter selectivity. - // For ease of computation, we convert all relevant numeric values to Double. + assert(max > min) percent = op match { case _: LessThan => - (literalDouble - minDouble) / (maxDouble - minDouble) + if (numericLiteral == max) { + // If the literal value is right on the boundary, we can minus the part of the + // boundary value (1/ndv). + 1.0 - 1.0 / ndv + } else { + (numericLiteral - min) / (max - min) + } case _: LessThanOrEqual => - if (literalValueBD == BigDecimal(statsRange.min)) { - 1.0 / colStat.distinctCount.toDouble + if (numericLiteral == min) { + // The boundary value is the only satisfying value. + 1.0 / ndv } else { - (literalDouble - minDouble) / (maxDouble - minDouble) + (numericLiteral - min) / (max - min) } case _: GreaterThan => - (maxDouble - literalDouble) / (maxDouble - minDouble) + if (numericLiteral == min) { + 1.0 - 1.0 / ndv + } else { + (max - numericLiteral) / (max - min) + } case _: GreaterThanOrEqual => - if (literalValueBD == BigDecimal(statsRange.max)) { - 1.0 / colStat.distinctCount.toDouble + if (numericLiteral == max) { + 1.0 / ndv } else { - (maxDouble - literalDouble) / (maxDouble - minDouble) + (max - numericLiteral) / (max - min) } } @@ -469,22 +488,25 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo val newValue = convertBoundValue(attr.dataType, literal.value) var newMax = colStat.max var newMin = colStat.min + var newNdv = (ndv * percent).setScale(0, RoundingMode.HALF_UP).toBigInt() + if (newNdv < 1) newNdv = 1 + op match { - case _: GreaterThan => newMin = newValue - case _: GreaterThanOrEqual => newMin = newValue - case _: LessThan => newMax = newValue - case _: LessThanOrEqual => newMax = newValue + case _: GreaterThan | _: GreaterThanOrEqual => + // If new ndv is 1, then new max must be equal to new min. 
+ newMin = if (newNdv == 1) newMax else newValue + case _: LessThan | _: LessThanOrEqual => + newMax = if (newNdv == 1) newMin else newValue } - val newNdv = math.max(math.round(colStat.distinctCount.toDouble * percent), 1) - val newStats = colStat.copy(distinctCount = newNdv, min = newMin, - max = newMax, nullCount = 0) + val newStats = + colStat.copy(distinctCount = newNdv, min = newMin, max = newMax, nullCount = 0) colStatsMap(attr) = newStats } } - Some(percent) + Some(percent.toDouble) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 8be74ced7b..4691913c8c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.statsEstimation import java.sql.Date import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ import org.apache.spark.sql.types._ @@ -33,219 +33,235 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // Suppose our test table has 10 rows and 6 columns. 
// First column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 // Hence, distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4 - val arInt = AttributeReference("cint", IntegerType)() - val childColStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + val attrInt = AttributeReference("cint", IntegerType)() + val colStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4) // only 2 values - val arBool = AttributeReference("cbool", BooleanType)() - val childColStatBool = ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), + val attrBool = AttributeReference("cbool", BooleanType)() + val colStatBool = ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), nullCount = 0, avgLen = 1, maxLen = 1) // Second column cdate has 10 values from 2017-01-01 through 2017-01-10. val dMin = Date.valueOf("2017-01-01") val dMax = Date.valueOf("2017-01-10") - val arDate = AttributeReference("cdate", DateType)() - val childColStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax), + val attrDate = AttributeReference("cdate", DateType)() + val colStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax), nullCount = 0, avgLen = 4, maxLen = 4) // Fourth column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20. 
val decMin = new java.math.BigDecimal("0.200000000000000000") val decMax = new java.math.BigDecimal("0.800000000000000000") - val arDecimal = AttributeReference("cdecimal", DecimalType(18, 18))() - val childColStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax), + val attrDecimal = AttributeReference("cdecimal", DecimalType(18, 18))() + val colStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax), nullCount = 0, avgLen = 8, maxLen = 8) // Fifth column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 - val arDouble = AttributeReference("cdouble", DoubleType)() - val childColStatDouble = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0), + val attrDouble = AttributeReference("cdouble", DoubleType)() + val colStatDouble = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0), nullCount = 0, avgLen = 8, maxLen = 8) // Sixth column cstring has 10 String values: // "A0", "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9" - val arString = AttributeReference("cstring", StringType)() - val childColStatString = ColumnStat(distinctCount = 10, min = None, max = None, + val attrString = AttributeReference("cstring", StringType)() + val colStatString = ColumnStat(distinctCount = 10, min = None, max = None, nullCount = 0, avgLen = 2, maxLen = 2) + val attributeMap = AttributeMap(Seq( + attrInt -> colStatInt, + attrBool -> colStatBool, + attrDate -> colStatDate, + attrDecimal -> colStatDecimal, + attrDouble -> colStatDouble, + attrString -> colStatString)) + test("cint = 2") { validateEstimatedStats( - arInt, - Filter(EqualTo(arInt, Literal(2)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4), - 1) + Filter(EqualTo(attrInt, Literal(2)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), + nullCount = 0, 
avgLen = 4, maxLen = 4)), + expectedRowCount = 1) } test("cint <=> 2") { validateEstimatedStats( - arInt, - Filter(EqualNullSafe(arInt, Literal(2)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4), - 1) + Filter(EqualNullSafe(attrInt, Literal(2)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 1) } test("cint = 0") { // This is an out-of-range case since 0 is outside the range [min, max] validateEstimatedStats( - arInt, - Filter(EqualTo(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 0) + Filter(EqualTo(attrInt, Literal(0)), childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) } test("cint < 3") { validateEstimatedStats( - arInt, - Filter(LessThan(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - 3) + Filter(LessThan(attrInt, Literal(3)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 3) } test("cint < 0") { // This is a corner case since literal 0 is smaller than min. 
validateEstimatedStats( - arInt, - Filter(LessThan(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 0) + Filter(LessThan(attrInt, Literal(0)), childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) } test("cint <= 3") { validateEstimatedStats( - arInt, - Filter(LessThanOrEqual(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - 3) + Filter(LessThanOrEqual(attrInt, Literal(3)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 3) } test("cint > 6") { validateEstimatedStats( - arInt, - Filter(GreaterThan(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 5) + Filter(GreaterThan(attrInt, Literal(6)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 5) } test("cint > 10") { // This is a corner case since max value is 10. 
validateEstimatedStats( - arInt, - Filter(GreaterThan(arInt, Literal(10)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 0) + Filter(GreaterThan(attrInt, Literal(10)), childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) } test("cint >= 6") { validateEstimatedStats( - arInt, - Filter(GreaterThanOrEqual(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 5) + Filter(GreaterThanOrEqual(attrInt, Literal(6)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 5) } test("cint IS NULL") { validateEstimatedStats( - arInt, - Filter(IsNull(arInt), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 0, min = None, max = None, - nullCount = 0, avgLen = 4, maxLen = 4), - 0) + Filter(IsNull(attrInt), childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) } test("cint IS NOT NULL") { validateEstimatedStats( - arInt, - Filter(IsNotNull(arInt), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 10) + Filter(IsNotNull(attrInt), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 10) } test("cint > 3 AND cint <= 6") { - val condition = And(GreaterThan(arInt, Literal(3)), LessThanOrEqual(arInt, Literal(6))) + val condition = And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6))) validateEstimatedStats( - arInt, - Filter(condition, childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 3, min = Some(3), max = Some(6), - nullCount = 0, avgLen = 
4, maxLen = 4), - 4) + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(3), max = Some(6), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 4) } test("cint = 3 OR cint = 6") { - val condition = Or(EqualTo(arInt, Literal(3)), EqualTo(arInt, Literal(6))) + val condition = Or(EqualTo(attrInt, Literal(3)), EqualTo(attrInt, Literal(6))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 2) + } + + test("Not(cint > 3 AND cint <= 6)") { + val condition = Not(And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6)))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> colStatInt), + expectedRowCount = 6) + } + + test("Not(cint <= 3 OR cint > 6)") { + val condition = Not(Or(LessThanOrEqual(attrInt, Literal(3)), GreaterThan(attrInt, Literal(6)))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> colStatInt), + expectedRowCount = 5) + } + + test("Not(cint = 3 AND cstring < 'A8')") { + val condition = Not(And(EqualTo(attrInt, Literal(3)), LessThan(attrString, Literal("A8")))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt, attrString), 10L)), + Seq(attrInt -> colStatInt, attrString -> colStatString), + expectedRowCount = 10) + } + + test("Not(cint = 3 OR cstring < 'A8')") { + val condition = Not(Or(EqualTo(attrInt, Literal(3)), LessThan(attrString, Literal("A8")))) validateEstimatedStats( - arInt, - Filter(condition, childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 2) + Filter(condition, childStatsTestPlan(Seq(attrInt, attrString), 10L)), + Seq(attrInt -> colStatInt, attrString 
-> colStatString), + expectedRowCount = 9) } test("cint IN (3, 4, 5)") { validateEstimatedStats( - arInt, - Filter(InSet(arInt, Set(3, 4, 5)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 3, min = Some(3), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4), - 3) + Filter(InSet(attrInt, Set(3, 4, 5)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(3), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 3) } test("cint NOT IN (3, 4, 5)") { validateEstimatedStats( - arInt, - Filter(Not(InSet(arInt, Set(3, 4, 5))), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 7) + Filter(Not(InSet(attrInt, Set(3, 4, 5))), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 7) } test("cbool = true") { validateEstimatedStats( - arBool, - Filter(EqualTo(arBool, Literal(true)), childStatsTestPlan(Seq(arBool), 10L)), - ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1), - 5) + Filter(EqualTo(attrBool, Literal(true)), childStatsTestPlan(Seq(attrBool), 10L)), + Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), + nullCount = 0, avgLen = 1, maxLen = 1)), + expectedRowCount = 5) } test("cbool > false") { - // bool comparison is not supported yet, so stats remain same. 
validateEstimatedStats( - arBool, - Filter(GreaterThan(arBool, Literal(false)), childStatsTestPlan(Seq(arBool), 10L)), - ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1), - 10) + Filter(GreaterThan(attrBool, Literal(false)), childStatsTestPlan(Seq(attrBool), 10L)), + Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), + nullCount = 0, avgLen = 1, maxLen = 1)), + expectedRowCount = 5) } test("cdate = cast('2017-01-02' AS DATE)") { val d20170102 = Date.valueOf("2017-01-02") validateEstimatedStats( - arDate, - Filter(EqualTo(arDate, Literal(d20170102)), - childStatsTestPlan(Seq(arDate), 10L)), - ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102), - nullCount = 0, avgLen = 4, maxLen = 4), - 1) + Filter(EqualTo(attrDate, Literal(d20170102)), + childStatsTestPlan(Seq(attrDate), 10L)), + Seq(attrDate -> ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 1) } test("cdate < cast('2017-01-03' AS DATE)") { val d20170103 = Date.valueOf("2017-01-03") validateEstimatedStats( - arDate, - Filter(LessThan(arDate, Literal(d20170103)), - childStatsTestPlan(Seq(arDate), 10L)), - ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103), - nullCount = 0, avgLen = 4, maxLen = 4), - 3) + Filter(LessThan(attrDate, Literal(d20170103)), + childStatsTestPlan(Seq(attrDate), 10L)), + Seq(attrDate -> ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 3) } test("""cdate IN ( cast('2017-01-03' AS DATE), @@ -254,133 +270,118 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val d20170104 = Date.valueOf("2017-01-04") val d20170105 = Date.valueOf("2017-01-05") validateEstimatedStats( - arDate, - Filter(In(arDate, Seq(Literal(d20170103), Literal(d20170104), Literal(d20170105))), - 
childStatsTestPlan(Seq(arDate), 10L)), - ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105), - nullCount = 0, avgLen = 4, maxLen = 4), - 3) + Filter(In(attrDate, Seq(Literal(d20170103), Literal(d20170104), Literal(d20170105))), + childStatsTestPlan(Seq(attrDate), 10L)), + Seq(attrDate -> ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 3) } test("cdecimal = 0.400000000000000000") { val dec_0_40 = new java.math.BigDecimal("0.400000000000000000") validateEstimatedStats( - arDecimal, - Filter(EqualTo(arDecimal, Literal(dec_0_40)), - childStatsTestPlan(Seq(arDecimal), 4L)), - ColumnStat(distinctCount = 1, min = Some(dec_0_40), max = Some(dec_0_40), - nullCount = 0, avgLen = 8, maxLen = 8), - 1) + Filter(EqualTo(attrDecimal, Literal(dec_0_40)), + childStatsTestPlan(Seq(attrDecimal), 4L)), + Seq(attrDecimal -> ColumnStat(distinctCount = 1, min = Some(dec_0_40), max = Some(dec_0_40), + nullCount = 0, avgLen = 8, maxLen = 8)), + expectedRowCount = 1) } test("cdecimal < 0.60 ") { val dec_0_60 = new java.math.BigDecimal("0.600000000000000000") validateEstimatedStats( - arDecimal, - Filter(LessThan(arDecimal, Literal(dec_0_60)), - childStatsTestPlan(Seq(arDecimal), 4L)), - ColumnStat(distinctCount = 3, min = Some(decMin), max = Some(dec_0_60), - nullCount = 0, avgLen = 8, maxLen = 8), - 3) + Filter(LessThan(attrDecimal, Literal(dec_0_60)), + childStatsTestPlan(Seq(attrDecimal), 4L)), + Seq(attrDecimal -> ColumnStat(distinctCount = 3, min = Some(decMin), max = Some(dec_0_60), + nullCount = 0, avgLen = 8, maxLen = 8)), + expectedRowCount = 3) } test("cdouble < 3.0") { validateEstimatedStats( - arDouble, - Filter(LessThan(arDouble, Literal(3.0)), childStatsTestPlan(Seq(arDouble), 10L)), - ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(3.0), - nullCount = 0, avgLen = 8, maxLen = 8), - 3) + Filter(LessThan(attrDouble, Literal(3.0)), 
childStatsTestPlan(Seq(attrDouble), 10L)), + Seq(attrDouble -> ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(3.0), + nullCount = 0, avgLen = 8, maxLen = 8)), + expectedRowCount = 3) } test("cstring = 'A2'") { validateEstimatedStats( - arString, - Filter(EqualTo(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)), - ColumnStat(distinctCount = 1, min = None, max = None, - nullCount = 0, avgLen = 2, maxLen = 2), - 1) + Filter(EqualTo(attrString, Literal("A2")), childStatsTestPlan(Seq(attrString), 10L)), + Seq(attrString -> ColumnStat(distinctCount = 1, min = None, max = None, + nullCount = 0, avgLen = 2, maxLen = 2)), + expectedRowCount = 1) } - // There is no min/max statistics for String type. We estimate 10 rows returned. - test("cstring < 'A2'") { + test("cstring < 'A2' - unsupported condition") { validateEstimatedStats( - arString, - Filter(LessThan(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)), - ColumnStat(distinctCount = 10, min = None, max = None, - nullCount = 0, avgLen = 2, maxLen = 2), - 10) + Filter(LessThan(attrString, Literal("A2")), childStatsTestPlan(Seq(attrString), 10L)), + Seq(attrString -> ColumnStat(distinctCount = 10, min = None, max = None, + nullCount = 0, avgLen = 2, maxLen = 2)), + expectedRowCount = 10) } - // This is a corner test case. We want to test if we can handle the case when the number of - // valid values in IN clause is greater than the number of distinct values for a given column. - // For example, column has only 2 distinct values 1 and 6. - // The predicate is: column IN (1, 2, 3, 4, 5). test("cint IN (1, 2, 3, 4, 5)") { + // This is a corner test case. We want to test if we can handle the case when the number of + // valid values in IN clause is greater than the number of distinct values for a given column. + // For example, column has only 2 distinct values 1 and 6. + // The predicate is: column IN (1, 2, 3, 4, 5). 
val cornerChildColStatInt = ColumnStat(distinctCount = 2, min = Some(1), max = Some(6), nullCount = 0, avgLen = 4, maxLen = 4) val cornerChildStatsTestplan = StatsTestPlan( - outputList = Seq(arInt), + outputList = Seq(attrInt), rowCount = 2L, - attributeStats = AttributeMap(Seq(arInt -> cornerChildColStatInt)) + attributeStats = AttributeMap(Seq(attrInt -> cornerChildColStatInt)) ) validateEstimatedStats( - arInt, - Filter(InSet(arInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan), - ColumnStat(distinctCount = 2, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4), - 2) + Filter(InSet(attrInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan), + Seq(attrInt -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 2) } private def childStatsTestPlan(outList: Seq[Attribute], tableRowCount: BigInt): StatsTestPlan = { StatsTestPlan( outputList = outList, rowCount = tableRowCount, - attributeStats = AttributeMap(Seq( - arInt -> childColStatInt, - arBool -> childColStatBool, - arDate -> childColStatDate, - arDecimal -> childColStatDecimal, - arDouble -> childColStatDouble, - arString -> childColStatString - )) - ) + attributeStats = AttributeMap(outList.map(a => a -> attributeMap(a)))) } private def validateEstimatedStats( - ar: AttributeReference, filterNode: Filter, - expectedColStats: ColumnStat, - rowCount: Int): Unit = { - - val expectedAttrStats = toAttributeMap(Seq(ar.name -> expectedColStats), filterNode) - val expectedSizeInBytes = getOutputSize(filterNode.output, rowCount, expectedAttrStats) - - val filteredStats = filterNode.stats(conf) - assert(filteredStats.sizeInBytes == expectedSizeInBytes) - assert(filteredStats.rowCount.get == rowCount) - assert(filteredStats.attributeStats(ar) == expectedColStats) - - // If the filter has a binary operator (including those nested inside - // AND/OR/NOT), swap the sides of the attribte and the literal, reverse the - // operator, and 
then check again. - val rewrittenFilter = filterNode transformExpressionsDown { - case EqualTo(ar: AttributeReference, l: Literal) => - EqualTo(l, ar) - - case LessThan(ar: AttributeReference, l: Literal) => - GreaterThan(l, ar) - case LessThanOrEqual(ar: AttributeReference, l: Literal) => - GreaterThanOrEqual(l, ar) - - case GreaterThan(ar: AttributeReference, l: Literal) => - LessThan(l, ar) - case GreaterThanOrEqual(ar: AttributeReference, l: Literal) => - LessThanOrEqual(l, ar) + expectedColStats: Seq[(Attribute, ColumnStat)], + expectedRowCount: Int): Unit = { + + // If the filter has a binary operator (including those nested inside AND/OR/NOT), swap the + // sides of the attribute and the literal, reverse the operator, and then check again. + val swappedFilter = filterNode transformExpressionsDown { + case EqualTo(attr: Attribute, l: Literal) => + EqualTo(l, attr) + + case LessThan(attr: Attribute, l: Literal) => + GreaterThan(l, attr) + case LessThanOrEqual(attr: Attribute, l: Literal) => + GreaterThanOrEqual(l, attr) + + case GreaterThan(attr: Attribute, l: Literal) => + LessThan(l, attr) + case GreaterThanOrEqual(attr: Attribute, l: Literal) => + LessThanOrEqual(l, attr) + } + + val testFilters = if (swappedFilter != filterNode) { + Seq(swappedFilter, filterNode) + } else { + Seq(filterNode) } - if (rewrittenFilter != filterNode) { - validateEstimatedStats(ar, rewrittenFilter, expectedColStats, rowCount) + testFilters.foreach { filter => + val expectedAttributeMap = AttributeMap(expectedColStats) + val expectedStats = Statistics( + sizeInBytes = getOutputSize(filter.output, expectedRowCount, expectedAttributeMap), + rowCount = Some(expectedRowCount), + attributeStats = expectedAttributeMap) + assert(filter.stats(conf) == expectedStats) } } } -- cgit v1.2.3