author    | Wenchen Fan <wenchen@databricks.com> | 2017-02-25 23:01:44 -0800
committer | Wenchen Fan <wenchen@databricks.com> | 2017-02-25 23:01:44 -0800
commit    | 89608cf26226e28f331a4695fee89394d0d38ea0
tree      | e86920946dc18f2ce51fdb7f702f513d3b111114
parent    | 6ab60542e8e803b1d91371a92f4aaef6a64106f6
[SPARK-17075][SQL][FOLLOWUP] fix some minor issues and clean up the code
## What changes were proposed in this pull request?
This is a follow-up of https://github.com/apache/spark/pull/16395. It fixes code style issues, naming issues, missing cases in pattern matches, and other minor problems.
## How was this patch tested?
Existing tests.
Author: Wenchen Fan <wenchen@databricks.com>
Closes #17065 from cloud-fan/follow-up.
5 files changed, 296 insertions, 337 deletions
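Before the diff itself, a quick illustration of the logic this patch cleans up. The sketch below is a simplified, standalone approximation of how `calculateFilterSelectivity` combines per-predicate selectivities; the names and the OR formula are illustrative assumptions, not the verbatim Spark code (only the AND combination mirrors the patch directly).

```scala
// Simplified sketch of compound-predicate selectivity combination, loosely
// modeled on calculateFilterSelectivity in the diff below. All names here
// are hypothetical.
object SelectivitySketch {
  sealed trait Pred
  case class And(left: Pred, right: Pred) extends Pred
  case class Or(left: Pred, right: Pred) extends Pred
  case class Leaf(selectivity: Option[Double]) extends Pred // None = unsupported predicate

  def estimate(p: Pred): Option[Double] = p match {
    case And(l, r) =>
      // An unsupported side is treated as selectivity 1.0 (filters nothing out),
      // exactly like the (Some, None) cases in the patched code.
      (estimate(l), estimate(r)) match {
        case (Some(a), Some(b)) => Some(a * b)
        case (Some(a), None) => Some(a)
        case (None, Some(b)) => Some(b)
        case (None, None) => None
      }
    case Or(l, r) =>
      // Inclusion-exclusion under an independence assumption; one common
      // choice, assumed for this sketch.
      for (a <- estimate(l); b <- estimate(r)) yield math.min(1.0, a + b - a * b)
    case Leaf(s) => s
  }
}
```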
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
index 1504a52279..9f4a0f2b70 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -28,7 +28,7 @@ object AttributeMap {
   }
 }
 
-class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)])
+class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)])
   extends Map[Attribute, A] with Serializable {
 
   override def get(k: Attribute): Option[A] = baseMap.get(k.exprId).map(_._2)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
index 37f29ba68a..0c928832d7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
@@ -17,9 +17,7 @@
 
 package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
 
-import java.sql.{Date, Timestamp}
-
-import scala.collection.immutable.{HashSet, Map}
+import scala.collection.immutable.HashSet
 import scala.collection.mutable
 
 import org.apache.spark.internal.Logging
@@ -31,15 +29,16 @@ import org.apache.spark.sql.types._
 
 case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging {
 
+  private val childStats = plan.child.stats(catalystConf)
+
   /**
-   * We use a mutable colStats because we need to update the corresponding ColumnStat
-   * for a column after we apply a predicate condition. For example, column c has
-   * [min, max] value as [0, 100]. In a range condition such as (c > 40 AND c <= 50),
-   * we need to set the column's [min, max] value to [40, 100] after we evaluate the
-   * first condition c > 40. We need to set the column's [min, max] value to [40, 50]
+   * We will update the corresponding ColumnStats for a column after we apply a predicate condition.
+   * For example, column c has [min, max] value as [0, 100]. In a range condition such as
+   * (c > 40 AND c <= 50), we need to set the column's [min, max] value to [40, 100] after we
+   * evaluate the first condition c > 40. We need to set the column's [min, max] value to [40, 50]
    * after we evaluate the second condition c <= 50.
    */
-  private var mutableColStats: mutable.Map[ExprId, ColumnStat] = mutable.Map.empty
+  private val colStatsMap = new ColumnStatsMap
 
   /**
    * Returns an option of Statistics for a Filter logical plan node.
@@ -51,12 +50,10 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
    * @return Option[Statistics] When there is no statistics collected, it returns None.
    */
   def estimate: Option[Statistics] = {
-    // We first copy child node's statistics and then modify it based on filter selectivity.
-    val stats: Statistics = plan.child.stats(catalystConf)
-    if (stats.rowCount.isEmpty) return None
+    if (childStats.rowCount.isEmpty) return None
 
     // save a mutable copy of colStats so that we can later change it recursively
-    mutableColStats = mutable.Map(stats.attributeStats.map(kv => (kv._1.exprId, kv._2)).toSeq: _*)
+    colStatsMap.setInitValues(childStats.attributeStats)
 
     // estimate selectivity of this filter predicate
     val filterSelectivity: Double = calculateFilterSelectivity(plan.condition) match {
@@ -65,22 +62,14 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
       case None => 1.0
     }
 
-    // attributeStats has mapping Attribute-to-ColumnStat.
-    // mutableColStats has mapping ExprId-to-ColumnStat.
-    // We use an ExprId-to-Attribute map to facilitate the mapping Attribute-to-ColumnStat
-    val expridToAttrMap: Map[ExprId, Attribute] =
-      stats.attributeStats.map(kv => (kv._1.exprId, kv._1))
-    // copy mutableColStats contents to an immutable AttributeMap.
-    val mutableAttributeStats: mutable.Map[Attribute, ColumnStat] =
-      mutableColStats.map(kv => expridToAttrMap(kv._1) -> kv._2)
-    val newColStats = AttributeMap(mutableAttributeStats.toSeq)
+    val newColStats = colStatsMap.toColumnStats
 
     val filteredRowCount: BigInt =
-      EstimationUtils.ceil(BigDecimal(stats.rowCount.get) * filterSelectivity)
-    val filteredSizeInBytes =
+      EstimationUtils.ceil(BigDecimal(childStats.rowCount.get) * filterSelectivity)
+    val filteredSizeInBytes: BigInt =
       EstimationUtils.getOutputSize(plan.output, filteredRowCount, newColStats)
 
-    Some(stats.copy(sizeInBytes = filteredSizeInBytes, rowCount = Some(filteredRowCount),
+    Some(childStats.copy(sizeInBytes = filteredSizeInBytes, rowCount = Some(filteredRowCount),
       attributeStats = newColStats))
   }
 
@@ -95,15 +84,16 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
    * @param condition the compound logical expression
    * @param update a boolean flag to specify if we need to update ColumnStat of a column
    *               for subsequent conditions
-   * @return a double value to show the percentage of rows meeting a given condition.
+   * @return an optional double value to show the percentage of rows meeting a given condition.
    *         It returns None if the condition is not supported.
    */
   def calculateFilterSelectivity(condition: Expression, update: Boolean = true): Option[Double] = {
-
     condition match {
       case And(cond1, cond2) =>
-        (calculateFilterSelectivity(cond1, update), calculateFilterSelectivity(cond2, update))
-        match {
+        // For ease of debugging, we compute percent1 and percent2 in 2 statements.
+        val percent1 = calculateFilterSelectivity(cond1, update)
+        val percent2 = calculateFilterSelectivity(cond2, update)
+        (percent1, percent2) match {
          case (Some(p1), Some(p2)) => Some(p1 * p2)
          case (Some(p1), None) => Some(p1)
          case (None, Some(p2)) => Some(p2)
@@ -127,8 +117,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
          case None => None
        }
 
-      case _ =>
-        calculateSingleCondition(condition, update)
+      case _ => calculateSingleCondition(condition, update)
     }
   }
 
@@ -140,7 +129,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
    * @param condition a single logical expression
    * @param update a boolean flag to specify if we need to update ColumnStat of a column
    *               for subsequent conditions
-   * @return Option[Double] value to show the percentage of rows meeting a given condition.
+   * @return an optional double value to show the percentage of rows meeting a given condition.
    *         It returns None if the condition is not supported.
    */
   def calculateSingleCondition(condition: Expression, update: Boolean): Option[Double] = {
@@ -148,33 +137,33 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
     // For evaluateBinary method, we assume the literal on the right side of an operator.
     // So we will change the order if not.
 
-      // EqualTo does not care about the order
-      case op @ EqualTo(ar: AttributeReference, l: Literal) =>
-        evaluateBinary(op, ar, l, update)
-      case op @ EqualTo(l: Literal, ar: AttributeReference) =>
-        evaluateBinary(op, ar, l, update)
+      // EqualTo/EqualNullSafe does not care about the order
+      case op @ Equality(ar: Attribute, l: Literal) =>
+        evaluateEquality(ar, l, update)
+      case op @ Equality(l: Literal, ar: Attribute) =>
+        evaluateEquality(ar, l, update)
 
-      case op @ LessThan(ar: AttributeReference, l: Literal) =>
+      case op @ LessThan(ar: Attribute, l: Literal) =>
        evaluateBinary(op, ar, l, update)
-      case op @ LessThan(l: Literal, ar: AttributeReference) =>
+      case op @ LessThan(l: Literal, ar: Attribute) =>
        evaluateBinary(GreaterThan(ar, l), ar, l, update)
 
-      case op @ LessThanOrEqual(ar: AttributeReference, l: Literal) =>
+      case op @ LessThanOrEqual(ar: Attribute, l: Literal) =>
        evaluateBinary(op, ar, l, update)
-      case op @ LessThanOrEqual(l: Literal, ar: AttributeReference) =>
+      case op @ LessThanOrEqual(l: Literal, ar: Attribute) =>
        evaluateBinary(GreaterThanOrEqual(ar, l), ar, l, update)
 
-      case op @ GreaterThan(ar: AttributeReference, l: Literal) =>
+      case op @ GreaterThan(ar: Attribute, l: Literal) =>
        evaluateBinary(op, ar, l, update)
-      case op @ GreaterThan(l: Literal, ar: AttributeReference) =>
+      case op @ GreaterThan(l: Literal, ar: Attribute) =>
        evaluateBinary(LessThan(ar, l), ar, l, update)
 
-      case op @ GreaterThanOrEqual(ar: AttributeReference, l: Literal) =>
+      case op @ GreaterThanOrEqual(ar: Attribute, l: Literal) =>
        evaluateBinary(op, ar, l, update)
-      case op @ GreaterThanOrEqual(l: Literal, ar: AttributeReference) =>
+      case op @ GreaterThanOrEqual(l: Literal, ar: Attribute) =>
        evaluateBinary(LessThanOrEqual(ar, l), ar, l, update)
 
-      case In(ar: AttributeReference, expList)
+      case In(ar: Attribute, expList)
        if expList.forall(e => e.isInstanceOf[Literal]) =>
        // Expression [In (value, seq[Literal])] will be replaced with optimized version
        // [InSet (value, HashSet[Literal])] in Optimizer, but only for list.size > 10.
@@ -182,14 +171,14 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
        val hSet = expList.map(e => e.eval())
        evaluateInSet(ar, HashSet() ++ hSet, update)
 
-      case InSet(ar: AttributeReference, set) =>
+      case InSet(ar: Attribute, set) =>
        evaluateInSet(ar, set, update)
 
-      case IsNull(ar: AttributeReference) =>
-        evaluateIsNull(ar, isNull = true, update)
+      case IsNull(ar: Attribute) =>
+        evaluateNullCheck(ar, isNull = true, update)
 
-      case IsNotNull(ar: AttributeReference) =>
-        evaluateIsNull(ar, isNull = false, update)
+      case IsNotNull(ar: Attribute) =>
+        evaluateNullCheck(ar, isNull = false, update)
 
       case _ =>
        // TODO: it's difficult to support string operators without advanced statistics.
@@ -203,44 +192,43 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
   /**
    * Returns a percentage of rows meeting "IS NULL" or "IS NOT NULL" condition.
    *
-   * @param attrRef an AttributeReference (or a column)
+   * @param attr an Attribute (or a column)
    * @param isNull set to true for "IS NULL" condition.
    *               set to false for "IS NOT NULL" condition
    * @param update a boolean flag to specify if we need to update ColumnStat of a given column
    *               for subsequent conditions
    * @return an optional double value to show the percentage of rows meeting a given condition
    *         It returns None if no statistics collected for a given column.
    */
-  def evaluateIsNull(
-      attrRef: AttributeReference,
+  def evaluateNullCheck(
+      attr: Attribute,
       isNull: Boolean,
-      update: Boolean)
-    : Option[Double] = {
-    if (!mutableColStats.contains(attrRef.exprId)) {
-      logDebug("[CBO] No statistics for " + attrRef)
+      update: Boolean): Option[Double] = {
+    if (!colStatsMap.contains(attr)) {
+      logDebug("[CBO] No statistics for " + attr)
       return None
     }
-    val aColStat = mutableColStats(attrRef.exprId)
-    val rowCountValue = plan.child.stats(catalystConf).rowCount.get
-    val nullPercent: BigDecimal =
-      if (rowCountValue == 0) 0.0
-      else BigDecimal(aColStat.nullCount) / BigDecimal(rowCountValue)
+    val colStat = colStatsMap(attr)
+    val rowCountValue = childStats.rowCount.get
+    val nullPercent: BigDecimal = if (rowCountValue == 0) {
+      0
+    } else {
+      BigDecimal(colStat.nullCount) / BigDecimal(rowCountValue)
+    }
 
     if (update) {
-      val newStats =
-        if (isNull) aColStat.copy(distinctCount = 0, min = None, max = None)
-        else aColStat.copy(nullCount = 0)
-
-      mutableColStats += (attrRef.exprId -> newStats)
+      val newStats = if (isNull) {
+        colStat.copy(distinctCount = 0, min = None, max = None)
+      } else {
+        colStat.copy(nullCount = 0)
+      }
+      colStatsMap(attr) = newStats
     }
 
-    val percent =
-      if (isNull) {
-        nullPercent.toDouble
-      }
-      else {
-        /** ISNOTNULL(column) */
-        1.0 - nullPercent.toDouble
-      }
+    val percent = if (isNull) {
+      nullPercent.toDouble
+    } else {
+      1.0 - nullPercent.toDouble
+    }
 
     Some(percent)
   }
 
@@ -249,7 +237,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
   * Returns a percentage of rows meeting a binary comparison expression.
    *
    * @param op a binary comparison operator uch as =, <, <=, >, >=
-   * @param attrRef an AttributeReference (or a column)
+   * @param attr an Attribute (or a column)
    * @param literal a literal value (or constant)
    * @param update a boolean flag to specify if we need to update ColumnStat of a given column
    *               for subsequent conditions
@@ -258,27 +246,20 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
    */
   def evaluateBinary(
       op: BinaryComparison,
-      attrRef: AttributeReference,
+      attr: Attribute,
       literal: Literal,
-      update: Boolean)
-    : Option[Double] = {
-    if (!mutableColStats.contains(attrRef.exprId)) {
-      logDebug("[CBO] No statistics for " + attrRef)
-      return None
-    }
-
-    op match {
-      case EqualTo(l, r) => evaluateEqualTo(attrRef, literal, update)
+      update: Boolean): Option[Double] = {
+    attr.dataType match {
+      case _: NumericType | DateType | TimestampType =>
+        evaluateBinaryForNumeric(op, attr, literal, update)
+      case StringType | BinaryType =>
+        // TODO: It is difficult to support other binary comparisons for String/Binary
+        // type without min/max and advanced statistics like histogram.
+        logDebug("[CBO] No range comparison statistics for String/Binary type " + attr)
+        None
       case _ =>
-        attrRef.dataType match {
-          case _: NumericType | DateType | TimestampType =>
-            evaluateBinaryForNumeric(op, attrRef, literal, update)
-          case StringType | BinaryType =>
-            // TODO: It is difficult to support other binary comparisons for String/Binary
-            // type without min/max and advanced statistics like histogram.
-            logDebug("[CBO] No range comparison statistics for String/Binary type " + attrRef)
-            None
-        }
+        // TODO: support boolean type.
+        None
     }
   }
 
@@ -297,6 +278,8 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
         Some(DateTimeUtils.toJavaDate(litValue.toString.toInt))
       case TimestampType =>
         Some(DateTimeUtils.toJavaTimestamp(litValue.toString.toLong))
+      case _: DecimalType =>
+        Some(litValue.asInstanceOf[Decimal].toJavaBigDecimal)
       case StringType | BinaryType =>
         None
       case _ =>
@@ -308,37 +291,36 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
   /**
   * Returns a percentage of rows meeting an equality (=) expression.
   * This method evaluates the equality predicate for all data types.
   *
-   * @param attrRef an AttributeReference (or a column)
+   * @param attr an Attribute (or a column)
   * @param literal a literal value (or constant)
   * @param update a boolean flag to specify if we need to update ColumnStat of a given column
   *               for subsequent conditions
   * @return an optional double value to show the percentage of rows meeting a given condition
   */
-  def evaluateEqualTo(
-      attrRef: AttributeReference,
+  def evaluateEquality(
+      attr: Attribute,
       literal: Literal,
-      update: Boolean)
-    : Option[Double] = {
-
-    val aColStat = mutableColStats(attrRef.exprId)
-    val ndv = aColStat.distinctCount
+      update: Boolean): Option[Double] = {
+    if (!colStatsMap.contains(attr)) {
+      logDebug("[CBO] No statistics for " + attr)
+      return None
+    }
+    val colStat = colStatsMap(attr)
+    val ndv = colStat.distinctCount
 
     // decide if the value is in [min, max] of the column.
     // We currently don't store min/max for binary/string type.
     // Hence, we assume it is in boundary for binary/string type.
-    val statsRange = Range(aColStat.min, aColStat.max, attrRef.dataType)
-    val inBoundary: Boolean = Range.rangeContainsLiteral(statsRange, literal)
-
-    if (inBoundary) {
-
+    val statsRange = Range(colStat.min, colStat.max, attr.dataType)
+    if (statsRange.contains(literal)) {
       if (update) {
        // We update ColumnStat structure after apply this equality predicate.
        // Set distinctCount to 1. Set nullCount to 0.
        // Need to save new min/max using the external type value of the literal
-        val newValue = convertBoundValue(attrRef.dataType, literal.value)
-        val newStats = aColStat.copy(distinctCount = 1, min = newValue,
+        val newValue = convertBoundValue(attr.dataType, literal.value)
+        val newStats = colStat.copy(distinctCount = 1, min = newValue,
          max = newValue, nullCount = 0)
-        mutableColStats += (attrRef.exprId -> newStats)
+        colStatsMap(attr) = newStats
       }
 
       Some(1.0 / ndv.toDouble)
@@ -352,7 +334,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
   /**
   * Returns a percentage of rows meeting "IN" operator expression.
   * This method evaluates the equality predicate for all data types.
   *
-   * @param attrRef an AttributeReference (or a column)
+   * @param attr an Attribute (or a column)
   * @param hSet a set of literal values
   * @param update a boolean flag to specify if we need to update ColumnStat of a given column
   *               for subsequent conditions
@@ -361,57 +343,52 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
   */
   def evaluateInSet(
-      attrRef: AttributeReference,
+      attr: Attribute,
       hSet: Set[Any],
-      update: Boolean)
-    : Option[Double] = {
-    if (!mutableColStats.contains(attrRef.exprId)) {
-      logDebug("[CBO] No statistics for " + attrRef)
+      update: Boolean): Option[Double] = {
+    if (!colStatsMap.contains(attr)) {
+      logDebug("[CBO] No statistics for " + attr)
       return None
     }
 
-    val aColStat = mutableColStats(attrRef.exprId)
-    val ndv = aColStat.distinctCount
-    val aType = attrRef.dataType
-    var newNdv: Long = 0
+    val colStat = colStatsMap(attr)
+    val ndv = colStat.distinctCount
+    val dataType = attr.dataType
+    var newNdv = ndv
 
     // use [min, max] to filter the original hSet
-    aType match {
-      case _: NumericType | DateType | TimestampType =>
-        val statsRange =
-          Range(aColStat.min, aColStat.max, aType).asInstanceOf[NumericRange]
-
-        // To facilitate finding the min and max values in hSet, we map hSet values to BigDecimal.
-        // Using hSetBigdec, we can find the min and max values quickly in the ordered hSetBigdec.
-        val hSetBigdec = hSet.map(e => BigDecimal(e.toString))
-        val validQuerySet = hSetBigdec.filter(e => e >= statsRange.min && e <= statsRange.max)
-        // We use hSetBigdecToAnyMap to help us find the original hSet value.
-        val hSetBigdecToAnyMap: Map[BigDecimal, Any] =
-          hSet.map(e => BigDecimal(e.toString) -> e).toMap
+    dataType match {
+      case _: NumericType | BooleanType | DateType | TimestampType =>
+        val statsRange = Range(colStat.min, colStat.max, dataType).asInstanceOf[NumericRange]
+        val validQuerySet = hSet.filter { v =>
+          v != null && statsRange.contains(Literal(v, dataType))
+        }
 
        if (validQuerySet.isEmpty) {
          return Some(0.0)
        }
 
        // Need to save new min/max using the external type value of the literal
-        val newMax = convertBoundValue(attrRef.dataType, hSetBigdecToAnyMap(validQuerySet.max))
-        val newMin = convertBoundValue(attrRef.dataType, hSetBigdecToAnyMap(validQuerySet.min))
+        val newMax = convertBoundValue(
+          attr.dataType, validQuerySet.maxBy(v => BigDecimal(v.toString)))
+        val newMin = convertBoundValue(
+          attr.dataType, validQuerySet.minBy(v => BigDecimal(v.toString)))
 
        // newNdv should not be greater than the old ndv. For example, column has only 2 values
        // 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5.
-        newNdv = math.min(validQuerySet.size.toLong, ndv.longValue())
+        newNdv = ndv.min(BigInt(validQuerySet.size))
        if (update) {
-          val newStats = aColStat.copy(distinctCount = newNdv, min = newMin,
+          val newStats = colStat.copy(distinctCount = newNdv, min = newMin,
            max = newMax, nullCount = 0)
-          mutableColStats += (attrRef.exprId -> newStats)
+          colStatsMap(attr) = newStats
        }
 
      // We assume the whole set since there is no min/max information for String/Binary type
      case StringType | BinaryType =>
-        newNdv = math.min(hSet.size.toLong, ndv.longValue())
+        newNdv = ndv.min(BigInt(hSet.size))
        if (update) {
-          val newStats = aColStat.copy(distinctCount = newNdv, nullCount = 0)
-          mutableColStats += (attrRef.exprId -> newStats)
+          val newStats = colStat.copy(distinctCount = newNdv, nullCount = 0)
+          colStatsMap(attr) = newStats
        }
    }
 
@@ -425,7 +402,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
   * This method evaluate expression for Numeric columns only.
   *
   * @param op a binary comparison operator uch as =, <, <=, >, >=
-   * @param attrRef an AttributeReference (or a column)
+   * @param attr an Attribute (or a column)
   * @param literal a literal value (or constant)
   * @param update a boolean flag to specify if we need to update ColumnStat of a given column
   *               for subsequent conditions
@@ -433,16 +410,14 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
   */
   def evaluateBinaryForNumeric(
      op: BinaryComparison,
-      attrRef: AttributeReference,
+      attr: Attribute,
      literal: Literal,
-      update: Boolean)
-    : Option[Double] = {
+      update: Boolean): Option[Double] = {
 
    var percent = 1.0
-    val aColStat = mutableColStats(attrRef.exprId)
-    val ndv = aColStat.distinctCount
+    val colStat = colStatsMap(attr)
    val statsRange =
-      Range(aColStat.min, aColStat.max, attrRef.dataType).asInstanceOf[NumericRange]
+      Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange]
 
    // determine the overlapping degree between predicate range and column's range
    val literalValueBD = BigDecimal(literal.value.toString)
@@ -463,33 +438,37 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
      percent = 1.0
    } else {
      // this is partial overlap case
-      var newMax = aColStat.max
-      var newMin = aColStat.min
-      var newNdv = ndv
-      val literalToDouble = literalValueBD.toDouble
-      val maxToDouble = BigDecimal(statsRange.max).toDouble
-      val minToDouble = BigDecimal(statsRange.min).toDouble
+      val literalDouble = literalValueBD.toDouble
+      val maxDouble = BigDecimal(statsRange.max).toDouble
+      val minDouble = BigDecimal(statsRange.min).toDouble
 
      // Without advanced statistics like histogram, we assume uniform data distribution.
      // We just prorate the adjusted range over the initial range to compute filter selectivity.
      // For ease of computation, we convert all relevant numeric values to Double.
      percent = op match {
        case _: LessThan =>
-          (literalToDouble - minToDouble) / (maxToDouble - minToDouble)
+          (literalDouble - minDouble) / (maxDouble - minDouble)
        case _: LessThanOrEqual =>
-          if (literalValueBD == BigDecimal(statsRange.min)) 1.0 / ndv.toDouble
-          else (literalToDouble - minToDouble) / (maxToDouble - minToDouble)
+          if (literalValueBD == BigDecimal(statsRange.min)) {
+            1.0 / colStat.distinctCount.toDouble
+          } else {
+            (literalDouble - minDouble) / (maxDouble - minDouble)
+          }
        case _: GreaterThan =>
-          (maxToDouble - literalToDouble) / (maxToDouble - minToDouble)
+          (maxDouble - literalDouble) / (maxDouble - minDouble)
        case _: GreaterThanOrEqual =>
-          if (literalValueBD == BigDecimal(statsRange.max)) 1.0 / ndv.toDouble
-          else (maxToDouble - literalToDouble) / (maxToDouble - minToDouble)
+          if (literalValueBD == BigDecimal(statsRange.max)) {
+            1.0 / colStat.distinctCount.toDouble
+          } else {
+            (maxDouble - literalDouble) / (maxDouble - minDouble)
+          }
      }
 
-      // Need to save new min/max using the external type value of the literal
-      val newValue = convertBoundValue(attrRef.dataType, literal.value)
-
      if (update) {
+        // Need to save new min/max using the external type value of the literal
+        val newValue = convertBoundValue(attr.dataType, literal.value)
+        var newMax = colStat.max
+        var newMin = colStat.min
        op match {
          case _: GreaterThan => newMin = newValue
          case _: GreaterThanOrEqual => newMin = newValue
@@ -497,11 +476,11 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
          case _: LessThanOrEqual => newMax = newValue
        }
 
-        newNdv = math.max(math.round(ndv.toDouble * percent), 1)
-        val newStats = aColStat.copy(distinctCount = newNdv, min = newMin,
+        val newNdv = math.max(math.round(colStat.distinctCount.toDouble * percent), 1)
+        val newStats = colStat.copy(distinctCount = newNdv, min = newMin,
          max = newMax, nullCount = 0)
 
-        mutableColStats += (attrRef.exprId -> newStats)
+        colStatsMap(attr) = newStats
      }
    }
 
@@ -509,3 +488,20 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
 
   }
 }
+
+class ColumnStatsMap {
+  private val baseMap: mutable.Map[ExprId, (Attribute, ColumnStat)] = mutable.HashMap.empty
+
+  def setInitValues(colStats: AttributeMap[ColumnStat]): Unit = {
+    baseMap.clear()
+    baseMap ++= colStats.baseMap
+  }
+
+  def contains(a: Attribute): Boolean = baseMap.contains(a.exprId)
+
+  def apply(a: Attribute): ColumnStat = baseMap(a.exprId)._2
+
+  def update(a: Attribute, stats: ColumnStat): Unit = baseMap.update(a.exprId, a -> stats)
+
+  def toColumnStats: AttributeMap[ColumnStat] = AttributeMap(baseMap.values.toSeq)
+}
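The `ColumnStatsMap` added above replaces the hand-maintained `ExprId`-to-`ColumnStat` map and the reverse `ExprId`-to-`Attribute` lookup that `estimate` used to rebuild. A minimal standalone sketch of the idea, with hypothetical stand-ins for the Catalyst types:

```scala
import scala.collection.mutable

// Hypothetical, simplified stand-ins for Catalyst's types, only to show why
// the map is keyed by exprId: two Attribute instances referring to the same
// underlying column (same exprId) must share one statistics entry.
case class ExprId(id: Long)
case class Attribute(name: String, exprId: ExprId)
case class ColumnStat(distinctCount: BigInt, nullCount: BigInt)

class ColumnStatsMapSketch {
  private val baseMap = mutable.HashMap.empty[ExprId, (Attribute, ColumnStat)]

  def contains(a: Attribute): Boolean = baseMap.contains(a.exprId)
  def apply(a: Attribute): ColumnStat = baseMap(a.exprId)._2
  // Defining `update` gives callers the `map(attr) = stats` assignment syntax
  // used throughout the patched FilterEstimation.
  def update(a: Attribute, s: ColumnStat): Unit = baseMap.update(a.exprId, a -> s)
}
```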
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
index 982a5a8bb8..9782c0bb0a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
@@ -59,7 +59,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
     case _ if !rowCountsExist(conf, join.left, join.right) =>
       None
 
-    case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) =>
+    case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _) =>
       // 1. Compute join selectivity
       val joinKeyPairs = extractJoinKeysWithColStats(leftKeys, rightKeys)
       val selectivity = joinSelectivity(joinKeyPairs)
@@ -94,9 +94,9 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
       val outputStats: Seq[(Attribute, ColumnStat)] = if (outputRows == 0) {
         // The output is empty, we don't need to keep column stats.
         Nil
-      } else if (innerJoinedRows == 0) {
+      } else if (selectivity == 0) {
         joinType match {
-          // For outer joins, if the inner join part is empty, the number of output rows is the
+          // For outer joins, if the join selectivity is 0, the number of output rows is the
           // same as that of the outer side. And column stats of join keys from the outer side
           // keep unchanged, while column stats of join keys from the other side should be updated
           // based on added null values.
@@ -116,6 +116,9 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
           }
           case _ => Nil
         }
+      } else if (selectivity == 1) {
+        // Cartesian product, just propagate the original column stats
+        inputAttrStats.toSeq
       } else {
         val joinKeyStats = getIntersectedStats(joinKeyPairs)
         join.joinType match {
@@ -138,8 +141,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
       Some(Statistics(
         sizeInBytes = getOutputSize(join.output, outputRows, outputAttrStats),
         rowCount = Some(outputRows),
-        attributeStats = outputAttrStats,
-        isBroadcastable = false))
+        attributeStats = outputAttrStats))
 
     case _ =>
       // When there is no equi-join condition, we do estimation like cartesian product.
@@ -150,8 +152,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
       Some(Statistics(
         sizeInBytes = getOutputSize(join.output, outputRows, inputAttrStats),
         rowCount = Some(outputRows),
-        attributeStats = inputAttrStats,
-        isBroadcastable = false))
+        attributeStats = inputAttrStats))
   }
 
   // scalastyle:off
@@ -189,8 +190,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
     }
 
     if (ndvDenom < 0) {
-      // There isn't join keys or column stats for any of the join key pairs, we do estimation like
-      // cartesian product.
+      // We can't find any join key pairs with column stats, estimate it as cartesian join.
       1
     } else if (ndvDenom == 0) {
       // One of the join key pairs is disjoint, thus the two sides of join is disjoint.
@@ -202,9 +202,6 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
 
   /**
    * Propagate or update column stats for output attributes.
-   * 1. For cartesian product, all values are preserved, so there's no need to change column stats.
-   * 2. For other cases, a) update max/min of join keys based on their intersected range. b) update
-   * distinct count of other attributes based on output rows after join.
    */
   private def updateAttrStats(
       outputRows: BigInt,
@@ -214,35 +211,38 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
     val outputAttrStats = new ArrayBuffer[(Attribute, ColumnStat)]()
     val leftRows = leftStats.rowCount.get
     val rightRows = rightStats.rowCount.get
-    if (outputRows == leftRows * rightRows) {
-      // Cartesian product, just propagate the original column stats
-      attributes.foreach(a => outputAttrStats += a -> oldAttrStats(a))
-    } else {
-      val leftRatio =
-        if (leftRows != 0) BigDecimal(outputRows) / BigDecimal(leftRows) else BigDecimal(0)
-      val rightRatio =
-        if (rightRows != 0) BigDecimal(outputRows) / BigDecimal(rightRows) else BigDecimal(0)
-      attributes.foreach { a =>
-        // check if this attribute is a join key
-        if (joinKeyStats.contains(a)) {
-          outputAttrStats += a -> joinKeyStats(a)
+
+    attributes.foreach { a =>
+      // check if this attribute is a join key
+      if (joinKeyStats.contains(a)) {
+        outputAttrStats += a -> joinKeyStats(a)
+      } else {
+        val leftRatio = if (leftRows != 0) {
+          BigDecimal(outputRows) / BigDecimal(leftRows)
+        } else {
+          BigDecimal(0)
+        }
+        val rightRatio = if (rightRows != 0) {
+          BigDecimal(outputRows) / BigDecimal(rightRows)
         } else {
-          val oldColStat = oldAttrStats(a)
-          val oldNdv = oldColStat.distinctCount
-          // We only change (scale down) the number of distinct values if the number of rows
-          // decreases after join, because join won't produce new values even if the number of
-          // rows increases.
+        val newNdv = if (join.left.outputSet.contains(a) && leftRatio < 1) {
+          ceil(BigDecimal(oldNdv) * leftRatio)
+        } else if (join.right.outputSet.contains(a) && rightRatio < 1) {
+          ceil(BigDecimal(oldNdv) * rightRatio)
+        } else {
+          oldNdv
+        }
+        // TODO: support nullCount updates for specific outer joins
+        outputAttrStats += a -> oldColStat.copy(distinctCount = newNdv)
       }
+    }
     outputAttrStats
   }
@@ -263,12 +263,14 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
       // Update intersected column stats
       assert(leftKey.dataType.sameType(rightKey.dataType))
-      val minNdv = leftKeyStats.distinctCount.min(rightKeyStats.distinctCount)
+      val newNdv = leftKeyStats.distinctCount.min(rightKeyStats.distinctCount)
       val (newMin, newMax) = Range.intersect(lRange, rRange, leftKey.dataType)
-      intersectedStats.put(leftKey,
-        leftKeyStats.copy(distinctCount = minNdv, min = newMin, max = newMax, nullCount = 0))
-      intersectedStats.put(rightKey,
-        rightKeyStats.copy(distinctCount = minNdv, min = newMin, max = newMax, nullCount = 0))
+      val newMaxLen = math.min(leftKeyStats.maxLen, rightKeyStats.maxLen)
+      val newAvgLen = (leftKeyStats.avgLen + rightKeyStats.avgLen) / 2
+      val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen)
+
+      intersectedStats.put(leftKey, newStats)
+      intersectedStats.put(rightKey, newStats)
     }
     AttributeMap(intersectedStats.toSeq)
   }
@@ -298,8 +300,7 @@ case class LeftSemiAntiEstimation(conf: CatalystConf, join: Join) {
       Some(Statistics(
         sizeInBytes = getOutputSize(join.output, outputRows, leftStats.attributeStats),
         rowCount = Some(outputRows),
-        attributeStats = leftStats.attributeStats,
-        isBroadcastable = false))
+        attributeStats = leftStats.attributeStats))
     } else {
       None
     }
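A detail worth calling out in `updateAttrStats` above: distinct counts are only ever scaled down, never up, because a join cannot invent new values. A hedged sketch of that rule with a worked example (the helper name is hypothetical):

```scala
import scala.math.BigDecimal.RoundingMode

// Sketch of the ndv scaling rule from updateAttrStats: shrink the distinct
// count with the row-count ratio when rows decrease, keep it otherwise.
def scaledNdv(oldNdv: BigInt, outputRows: BigInt, inputRows: BigInt): BigInt = {
  if (inputRows == 0) return BigInt(0)
  val ratio = BigDecimal(outputRows) / BigDecimal(inputRows)
  if (ratio < 1) {
    // Round up, mirroring the ceil(...) call in the diff.
    (BigDecimal(oldNdv) * ratio).setScale(0, RoundingMode.CEILING).toBigInt
  } else {
    oldNdv
  }
}

// scaledNdv(100, 200, 1000)  == 20   (rows shrank to 20%, ndv scales down)
// scaledNdv(100, 5000, 1000) == 100  (rows grew, ndv stays capped at 100)
```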
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala
index 4557114532..3d13967cb6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala
@@ -26,19 +26,33 @@ import org.apache.spark.sql.types.{BooleanType, DateType, TimestampType, _}
 
 /** Value range of a column. */
-trait Range
+trait Range {
+  def contains(l: Literal): Boolean
+}
 
 /** For simplicity we use decimal to unify operations of numeric ranges. */
-case class NumericRange(min: JDecimal, max: JDecimal) extends Range
+case class NumericRange(min: JDecimal, max: JDecimal) extends Range {
+  override def contains(l: Literal): Boolean = {
+    val decimal = l.dataType match {
+      case BooleanType => if (l.value.asInstanceOf[Boolean]) new JDecimal(1) else new JDecimal(0)
+      case _ => new JDecimal(l.value.toString)
+    }
+    min.compareTo(decimal) <= 0 && max.compareTo(decimal) >= 0
+  }
+}
 
 /**
  * This version of Spark does not have min/max for binary/string types, we define their default
  * behaviors by this class.
  */
-class DefaultRange extends Range
+class DefaultRange extends Range {
+  override def contains(l: Literal): Boolean = true
+}
 
 /** This is for columns with only null values. */
-class NullRange extends Range
+class NullRange extends Range {
+  override def contains(l: Literal): Boolean = false
+}
 
 object Range {
   def apply(min: Option[Any], max: Option[Any], dataType: DataType): Range = dataType match {
@@ -58,20 +72,6 @@ object Range {
     n1.min.compareTo(n2.max) <= 0 && n1.max.compareTo(n2.min) >= 0
   }
 
-  def rangeContainsLiteral(r: Range, lit: Literal): Boolean = r match {
-    case _: DefaultRange => true
-    case _: NullRange => false
-    case n: NumericRange =>
-      val literalValue = if (lit.dataType.isInstanceOf[BooleanType]) {
-        if (lit.value.asInstanceOf[Boolean]) new JDecimal(1) else new JDecimal(0)
-      } else {
-        assert(lit.dataType.isInstanceOf[NumericType] || lit.dataType.isInstanceOf[DateType] ||
-          lit.dataType.isInstanceOf[TimestampType])
-        new JDecimal(lit.value.toString)
-      }
-      n.min.compareTo(literalValue) <= 0 && n.max.compareTo(literalValue) >= 0
-  }
-
   /**
    * Intersected results of two ranges. This is only for two overlapped ranges.
    * The outputs are the intersected min/max values.
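The `NumericRange.contains` override above folds booleans into the decimal machinery by mapping false/true onto 0/1, which is what lets boolean columns participate in range checks. A standalone sketch of just that mapping (not the Catalyst code; `Literal` is omitted here):

```scala
import java.math.{BigDecimal => JDecimal}

// Sketch of the boolean handling in NumericRange.contains: booleans live on
// the numeric range [0, 1], so an all-true column (min = max = 1) does not
// contain the literal false, while a (false, true) column contains both.
def boolAsDecimal(b: Boolean): JDecimal =
  if (b) new JDecimal(1) else new JDecimal(0)

def rangeContains(min: JDecimal, max: JDecimal, b: Boolean): Boolean = {
  val d = boolAsDecimal(b)
  min.compareTo(d) <= 0 && max.compareTo(d) >= 0
}

// rangeContains(new JDecimal(1), new JDecimal(1), false) == false
// rangeContains(new JDecimal(0), new JDecimal(1), true)  == true
```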
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
index f5e306f9e5..8be74ced7b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
@@ -17,12 +17,11 @@
 
 package org.apache.spark.sql.catalyst.statsEstimation
 
-import java.sql.{Date, Timestamp}
+import java.sql.Date
 
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 
 /**
@@ -38,6 +37,11 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
   val childColStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
     nullCount = 0, avgLen = 4, maxLen = 4)
 
+  // only 2 values
+  val arBool = AttributeReference("cbool", BooleanType)()
+  val childColStatBool = ColumnStat(distinctCount = 2, min = Some(false), max = Some(true),
+    nullCount = 0, avgLen = 1, maxLen = 1)
+
   // Second column cdate has 10 values from 2017-01-01 through 2017-01-10.
   val dMin = Date.valueOf("2017-01-01")
   val dMax = Date.valueOf("2017-01-10")
@@ -45,14 +49,6 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
   val childColStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax),
     nullCount = 0, avgLen = 4, maxLen = 4)
 
-  // Third column ctimestamp has 10 values from "2017-01-01 01:00:00" through
-  // "2017-01-01 10:00:00" for 10 distinct timestamps (or hours).
-  val tsMin = Timestamp.valueOf("2017-01-01 01:00:00")
-  val tsMax = Timestamp.valueOf("2017-01-01 10:00:00")
-  val arTimestamp = AttributeReference("ctimestamp", TimestampType)()
-  val childColStatTimestamp = ColumnStat(distinctCount = 10, min = Some(tsMin), max = Some(tsMax),
-    nullCount = 0, avgLen = 8, maxLen = 8)
-
   // Fourth column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20.
   val decMin = new java.math.BigDecimal("0.200000000000000000")
   val decMax = new java.math.BigDecimal("0.800000000000000000")
@@ -77,8 +73,16 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
       Filter(EqualTo(arInt, Literal(2)), childStatsTestPlan(Seq(arInt), 10L)),
       ColumnStat(distinctCount = 1, min = Some(2), max = Some(2),
         nullCount = 0, avgLen = 4, maxLen = 4),
-      Some(1L)
-    )
+      1)
+  }
+
+  test("cint <=> 2") {
+    validateEstimatedStats(
+      arInt,
+      Filter(EqualNullSafe(arInt, Literal(2)), childStatsTestPlan(Seq(arInt), 10L)),
+      ColumnStat(distinctCount = 1, min = Some(2), max = Some(2),
+        nullCount = 0, avgLen = 4, maxLen = 4),
+      1)
   }
 
   test("cint = 0") {
@@ -88,8 +92,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
       Filter(EqualTo(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)),
       ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
         nullCount = 0, avgLen = 4, maxLen = 4),
-      Some(0L)
-    )
+      0)
   }
 
   test("cint < 3") {
@@ -98,8 +101,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
       Filter(LessThan(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)),
       ColumnStat(distinctCount = 2, min = Some(1), max = Some(3),
         nullCount = 0, avgLen = 4, maxLen = 4),
-      Some(3L)
-    )
+      3)
   }
 
   test("cint < 0") {
@@ -109,8 +111,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
       Filter(LessThan(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)),
       ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
         nullCount = 0, avgLen = 4, maxLen = 4),
-      Some(0L)
-    )
+      0)
   }
 
   test("cint <= 3") {
@@ -119,8 +120,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
       Filter(LessThanOrEqual(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)),
       ColumnStat(distinctCount = 2, min = Some(1), max = Some(3),
         nullCount = 0, avgLen = 4, maxLen = 4),
-      Some(3L)
-    )
+      3)
   }
 
   test("cint > 6") {
@@ -129,8 +129,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
       Filter(GreaterThan(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)),
       ColumnStat(distinctCount = 4, min = Some(6), max = Some(10),
         nullCount = 0, avgLen = 4, maxLen = 4),
-      Some(5L)
-    )
+      5)
   }
 
   test("cint > 10") {
@@ -140,8 +139,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
       Filter(GreaterThan(arInt, Literal(10)), childStatsTestPlan(Seq(arInt), 10L)),
       ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
         nullCount = 0, avgLen = 4, maxLen = 4),
-      Some(0L)
-    )
+      0)
   }
 
   test("cint >= 6") {
@@ -150,8 +148,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
       Filter(GreaterThanOrEqual(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)),
       ColumnStat(distinctCount = 4, min = Some(6), max = Some(10),
         nullCount = 0, avgLen = 4, maxLen = 4),
-      Some(5L)
-    )
+      5)
   }
 
   test("cint IS NULL") {
@@ -160,8 +157,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
       Filter(IsNull(arInt), childStatsTestPlan(Seq(arInt), 10L)),
       ColumnStat(distinctCount = 0, min = None, max = None,
         nullCount = 0, avgLen = 4, maxLen = 4),
-      Some(0L)
-    )
+      0)
   }
 
   test("cint IS NOT NULL") {
@@ -170,8 +166,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
       Filter(IsNotNull(arInt), childStatsTestPlan(Seq(arInt), 10L)),
       ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
         nullCount = 0, avgLen = 4, maxLen = 4),
-      Some(10L)
-    )
+      10)
   }
 
   test("cint > 3 AND cint <= 6") {
@@ -181,8 +176,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
       Filter(condition, childStatsTestPlan(Seq(arInt), 10L)),
       ColumnStat(distinctCount = 3, min = Some(3), max = Some(6),
         nullCount = 0, avgLen = 4, maxLen = 4),
-      Some(4L)
-    )
+      4)
   }
 
   test("cint = 3 OR cint = 6") {
@@ -192,8 +186,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
       Filter(condition, childStatsTestPlan(Seq(arInt), 10L)),
       ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
         nullCount = 0, avgLen = 4, maxLen = 4),
-      Some(2L)
-    )
+      2)
   }
 
   test("cint IN (3, 4, 5)") {
@@ -202,8 +195,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
       Filter(InSet(arInt, Set(3, 4, 5)), childStatsTestPlan(Seq(arInt), 10L)),
       ColumnStat(distinctCount = 3, min = Some(3), max = Some(5),
         nullCount = 0, avgLen = 4, maxLen = 4),
-      Some(3L)
-    )
+      3)
   }
 
   test("cint NOT IN (3, 4, 5)") {
@@ -212,8 +204,26 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
       Filter(Not(InSet(arInt, Set(3, 4, 5))), childStatsTestPlan(Seq(arInt), 10L)),
       ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
         nullCount = 0, avgLen = 4, maxLen = 4),
-      Some(7L)
-    )
+      7)
+  }
+
+  test("cbool = true") {
+    validateEstimatedStats(
+      arBool,
+      Filter(EqualTo(arBool, Literal(true)), childStatsTestPlan(Seq(arBool), 10L)),
+      ColumnStat(distinctCount = 1, min = Some(true), max = Some(true),
+        nullCount = 0, avgLen = 1, maxLen = 1),
+      5)
+  }
+
+  test("cbool > false") {
+    // bool comparison is not supported yet, so stats remain same.
+    validateEstimatedStats(
+      arBool,
+      Filter(GreaterThan(arBool, Literal(false)), childStatsTestPlan(Seq(arBool), 10L)),
+      ColumnStat(distinctCount = 2, min = Some(false), max = Some(true),
+        nullCount = 0, avgLen = 1, maxLen = 1),
+      10)
   }
 
   test("cdate = cast('2017-01-02' AS DATE)") {
@@ -224,8 +234,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
         childStatsTestPlan(Seq(arDate), 10L)),
       ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102),
         nullCount = 0, avgLen = 4, maxLen = 4),
-      Some(1L)
-    )
+      1)
   }
 
   test("cdate < cast('2017-01-03' AS DATE)") {
@@ -236,8 +245,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
         childStatsTestPlan(Seq(arDate), 10L)),
       ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103),
         nullCount = 0, avgLen = 4, maxLen = 4),
-      Some(3L)
-    )
+      3)
   }
 
   test("""cdate IN ( cast('2017-01-03' AS DATE),
@@ -251,32 +259,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
         childStatsTestPlan(Seq(arDate), 10L)),
       ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105),
         nullCount = 0, avgLen = 4, maxLen = 4),
-      Some(3L)
-    )
-  }
-
-  test("ctimestamp = cast('2017-01-01 02:00:00' AS TIMESTAMP)") {
-    val ts2017010102 = Timestamp.valueOf("2017-01-01 02:00:00")
-    validateEstimatedStats(
-      arTimestamp,
-      Filter(EqualTo(arTimestamp, Literal(ts2017010102)),
-        childStatsTestPlan(Seq(arTimestamp), 10L)),
-      ColumnStat(distinctCount = 1, min = Some(ts2017010102), max = Some(ts2017010102),
-        nullCount = 0, avgLen = 8, maxLen = 8),
-      Some(1L)
-    )
-  }
-
-  test("ctimestamp < cast('2017-01-01 03:00:00' AS TIMESTAMP)") {
-    val ts2017010103 = Timestamp.valueOf("2017-01-01 03:00:00")
-    validateEstimatedStats(
-      arTimestamp,
-      Filter(LessThan(arTimestamp, Literal(ts2017010103)),
-        childStatsTestPlan(Seq(arTimestamp), 10L)),
-      ColumnStat(distinctCount = 2, min = Some(tsMin), max = Some(ts2017010103),
-        nullCount = 0, avgLen = 8, maxLen = 8),
-      Some(3L)
-    )
+      3)
   }
 
   test("cdecimal = 0.400000000000000000") {
@@ -287,20 +270,18 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
         childStatsTestPlan(Seq(arDecimal), 4L)),
       ColumnStat(distinctCount = 1, min = Some(dec_0_40), max = Some(dec_0_40),
         nullCount = 0, avgLen = 8, maxLen = 8),
-      Some(1L)
-    )
+      1)
   }
 
   test("cdecimal < 0.60 ") {
     val dec_0_60 = new java.math.BigDecimal("0.600000000000000000")
     validateEstimatedStats(
       arDecimal,
-      Filter(LessThan(arDecimal, Literal(dec_0_60, DecimalType(12, 2))),
+      Filter(LessThan(arDecimal, Literal(dec_0_60)),
         childStatsTestPlan(Seq(arDecimal), 4L)),
       ColumnStat(distinctCount = 3, min = Some(decMin), max = Some(dec_0_60),
         nullCount = 0, avgLen = 8, maxLen = 8),
-      Some(3L)
-    )
+      3)
   }
 
   test("cdouble < 3.0") {
@@ -309,8 +290,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
       Filter(LessThan(arDouble, Literal(3.0)), childStatsTestPlan(Seq(arDouble), 10L)),
       ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(3.0),
         nullCount = 0, avgLen = 8, maxLen = 8),
-      Some(3L)
-    )
+      3)
  }
 
   test("cstring = 'A2'") {
@@ -319,8 +299,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
       Filter(EqualTo(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)),
       ColumnStat(distinctCount = 1, min = None, max = None,
         nullCount = 0, avgLen = 2, maxLen = 2),
-      Some(1L)
-    )
+      1)
   }
 
   // There is no min/max statistics for String type. We estimate 10 rows returned.
   test("cstring < 'A2'") {
@@ -330,8 +309,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
       Filter(LessThan(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)),
       ColumnStat(distinctCount = 10, min = None, max = None,
         nullCount = 0, avgLen = 2, maxLen = 2),
-      Some(10L)
-    )
+      10)
   }
 
   // This is a corner test case. We want to test if we can handle the case when the number of
@@ -351,8 +329,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
       Filter(InSet(arInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan),
       ColumnStat(distinctCount = 2, min = Some(1), max = Some(5),
         nullCount = 0, avgLen = 4, maxLen = 4),
-      Some(2L)
-    )
+      2)
   }
 
   private def childStatsTestPlan(outList: Seq[Attribute], tableRowCount: BigInt): StatsTestPlan = {
@@ -361,8 +338,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
       rowCount = tableRowCount,
       attributeStats = AttributeMap(Seq(
         arInt -> childColStatInt,
+        arBool -> childColStatBool,
         arDate -> childColStatDate,
-        arTimestamp -> childColStatTimestamp,
         arDecimal -> childColStatDecimal,
         arDouble -> childColStatDouble,
         arString -> childColStatString
@@ -374,46 +351,31 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
       ar: AttributeReference,
       filterNode: Filter,
       expectedColStats: ColumnStat,
-      rowCount: Option[BigInt] = None)
-    : Unit = {
+      rowCount: Int): Unit = {
 
-    val expectedRowCount: BigInt = rowCount.getOrElse(0L)
     val expectedAttrStats = toAttributeMap(Seq(ar.name -> expectedColStats), filterNode)
-    val expectedSizeInBytes = getOutputSize(filterNode.output, expectedRowCount, expectedAttrStats)
+    val expectedSizeInBytes = getOutputSize(filterNode.output, rowCount, expectedAttrStats)
 
     val filteredStats = filterNode.stats(conf)
     assert(filteredStats.sizeInBytes == expectedSizeInBytes)
-    assert(filteredStats.rowCount == rowCount)
-    ar.dataType match {
-      case DecimalType() =>
-        // Due to the internal transformation for DecimalType within engine, the new min/max
-        // in ColumnStat may have a different structure even it contains the right values.
-        // We convert them to Java BigDecimal values so that we can compare the entire object.
-        val generatedColumnStats = filteredStats.attributeStats(ar)
-        val newMax = new java.math.BigDecimal(generatedColumnStats.max.getOrElse(0).toString)
-        val newMin = new java.math.BigDecimal(generatedColumnStats.min.getOrElse(0).toString)
-        val outputColStats = generatedColumnStats.copy(min = Some(newMin), max = Some(newMax))
-        assert(outputColStats == expectedColStats)
-      case _ =>
-        // For all other SQL types, we compare the entire object directly.
-        assert(filteredStats.attributeStats(ar) == expectedColStats)
-    }
+    assert(filteredStats.rowCount.get == rowCount)
+    assert(filteredStats.attributeStats(ar) == expectedColStats)
 
     // If the filter has a binary operator (including those nested inside
     // AND/OR/NOT), swap the sides of the attribte and the literal, reverse the
     // operator, and then check again.
     val rewrittenFilter = filterNode transformExpressionsDown {
-      case op @ EqualTo(ar: AttributeReference, l: Literal) =>
+      case EqualTo(ar: AttributeReference, l: Literal) =>
         EqualTo(l, ar)
 
-      case op @ LessThan(ar: AttributeReference, l: Literal) =>
+      case LessThan(ar: AttributeReference, l: Literal) =>
         GreaterThan(l, ar)
 
-      case op @ LessThanOrEqual(ar: AttributeReference, l: Literal) =>
+      case LessThanOrEqual(ar: AttributeReference, l: Literal) =>
         GreaterThanOrEqual(l, ar)
 
-      case op @ GreaterThan(ar: AttributeReference, l: Literal) =>
+      case GreaterThan(ar: AttributeReference, l: Literal) =>
         LessThan(l, ar)
 
-      case op @ GreaterThanOrEqual(ar: AttributeReference, l: Literal) =>
+      case GreaterThanOrEqual(ar: AttributeReference, l: Literal) =>
         LessThanOrEqual(l, ar)
     }
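For reference, the `Equality(...)` patterns used throughout the rewritten `calculateSingleCondition` rely on a Catalyst extractor that matches both `EqualTo` and `EqualNullSafe`. Its shape is approximately the following; a sketch, not the verbatim Spark definition:

```scala
import org.apache.spark.sql.catalyst.expressions._

// Approximate shape of the Equality extractor: one pattern that covers both
// `=` and `<=>`, which is what lets the patch collapse the duplicated EqualTo
// cases and pick up the previously missing EqualNullSafe case.
object EqualitySketch {
  def unapply(e: Expression): Option[(Expression, Expression)] = e match {
    case EqualTo(l, r) => Some((l, r))
    case EqualNullSafe(l, r) => Some((l, r))
    case _ => None
  }
}
```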