path: root/sql/catalyst
author    Wenchen Fan <wenchen@databricks.com>  2017-02-25 23:01:44 -0800
committer Wenchen Fan <wenchen@databricks.com>  2017-02-25 23:01:44 -0800
commit    89608cf26226e28f331a4695fee89394d0d38ea0 (patch)
tree      e86920946dc18f2ce51fdb7f702f513d3b111114 /sql/catalyst
parent    6ab60542e8e803b1d91371a92f4aaef6a64106f6 (diff)
[SPARK-17075][SQL][FOLLOWUP] fix some minor issues and clean up the code
## What changes were proposed in this pull request?

This is a follow-up of https://github.com/apache/spark/pull/16395. It fixes some code style issues, naming issues, missing cases in pattern matching, etc.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #17065 from cloud-fan/follow-up.
Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala                        |   2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala  | 330
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala    |  91
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala             |  36
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala           | 174
5 files changed, 296 insertions(+), 337 deletions(-)
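One of the "missing cases in pattern matching" this follow-up fills is null-safe equality: FilterEstimation now matches with the Equality extractor, which covers both EqualTo (=) and EqualNullSafe (<=>). Below is a minimal, self-contained sketch of that extractor pattern, using toy expression classes rather than Spark's:

sealed trait Expression
case class Attribute(name: String) extends Expression
case class Literal(value: Any) extends Expression
case class EqualTo(left: Expression, right: Expression) extends Expression
case class EqualNullSafe(left: Expression, right: Expression) extends Expression

object Equality {
  // Extracts (left, right) from either equality operator.
  def unapply(e: Expression): Option[(Expression, Expression)] = e match {
    case EqualTo(l, r)       => Some((l, r))
    case EqualNullSafe(l, r) => Some((l, r))
    case _                   => None
  }
}

object EqualityDemo extends App {
  def describe(cond: Expression): String = cond match {
    // One case now handles `a = 2` as well as `a <=> 2`, in either operand order.
    case Equality(a: Attribute, l: Literal) => s"${a.name} = ${l.value}"
    case Equality(l: Literal, a: Attribute) => s"${a.name} = ${l.value}"
    case _                                  => "unsupported"
  }
  println(describe(EqualTo(Attribute("cint"), Literal(2))))        // cint = 2
  println(describe(EqualNullSafe(Literal(2), Attribute("cint"))))  // cint = 2
}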
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
index 1504a52279..9f4a0f2b70 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -28,7 +28,7 @@ object AttributeMap {
}
}
-class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)])
+class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)])
extends Map[Attribute, A] with Serializable {
override def get(k: Attribute): Option[A] = baseMap.get(k.exprId).map(_._2)
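The one-word change above (baseMap becomes a val) exposes the ExprId-keyed map as a field; the new ColumnStatsMap added to FilterEstimation.scala below relies on it to seed itself in one shot. A hedged sketch with simplified types, not Spark's actual classes:

import scala.collection.mutable

object BaseMapDemo extends App {
  case class ExprId(id: Long)
  case class Attribute(name: String, exprId: ExprId)

  // `val` turns the constructor parameter into a public field.
  class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)]) {
    def get(k: Attribute): Option[A] = baseMap.get(k.exprId).map(_._2)
  }

  val cint = Attribute("cint", ExprId(1))
  val attrStats = new AttributeMap(Map(cint.exprId -> (cint, "col-stats-for-cint")))

  // Seeding an ExprId-keyed mutable map in one step, as ColumnStatsMap.setInitValues
  // does, is only possible because baseMap is now accessible.
  val seeded = mutable.HashMap.empty[ExprId, (Attribute, String)]
  seeded ++= attrStats.baseMap
  assert(seeded(cint.exprId)._2 == "col-stats-for-cint")
}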
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
index 37f29ba68a..0c928832d7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
@@ -17,9 +17,7 @@
package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
-import java.sql.{Date, Timestamp}
-
-import scala.collection.immutable.{HashSet, Map}
+import scala.collection.immutable.HashSet
import scala.collection.mutable
import org.apache.spark.internal.Logging
@@ -31,15 +29,16 @@ import org.apache.spark.sql.types._
case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging {
+ private val childStats = plan.child.stats(catalystConf)
+
/**
- * We use a mutable colStats because we need to update the corresponding ColumnStat
- * for a column after we apply a predicate condition. For example, column c has
- * [min, max] value as [0, 100]. In a range condition such as (c > 40 AND c <= 50),
- * we need to set the column's [min, max] value to [40, 100] after we evaluate the
- * first condition c > 40. We need to set the column's [min, max] value to [40, 50]
+ * We will update the corresponding ColumnStats for a column after we apply a predicate condition.
+ * For example, column c has [min, max] value as [0, 100]. In a range condition such as
+ * (c > 40 AND c <= 50), we need to set the column's [min, max] value to [40, 100] after we
+ * evaluate the first condition c > 40. We need to set the column's [min, max] value to [40, 50]
* after we evaluate the second condition c <= 50.
*/
- private var mutableColStats: mutable.Map[ExprId, ColumnStat] = mutable.Map.empty
+ private val colStatsMap = new ColumnStatsMap
/**
* Returns an option of Statistics for a Filter logical plan node.
@@ -51,12 +50,10 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging {
* @return Option[Statistics] When there is no statistics collected, it returns None.
*/
def estimate: Option[Statistics] = {
- // We first copy child node's statistics and then modify it based on filter selectivity.
- val stats: Statistics = plan.child.stats(catalystConf)
- if (stats.rowCount.isEmpty) return None
+ if (childStats.rowCount.isEmpty) return None
// save a mutable copy of colStats so that we can later change it recursively
- mutableColStats = mutable.Map(stats.attributeStats.map(kv => (kv._1.exprId, kv._2)).toSeq: _*)
+ colStatsMap.setInitValues(childStats.attributeStats)
// estimate selectivity of this filter predicate
val filterSelectivity: Double = calculateFilterSelectivity(plan.condition) match {
@@ -65,22 +62,14 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging {
case None => 1.0
}
- // attributeStats has mapping Attribute-to-ColumnStat.
- // mutableColStats has mapping ExprId-to-ColumnStat.
- // We use an ExprId-to-Attribute map to facilitate the mapping Attribute-to-ColumnStat
- val expridToAttrMap: Map[ExprId, Attribute] =
- stats.attributeStats.map(kv => (kv._1.exprId, kv._1))
- // copy mutableColStats contents to an immutable AttributeMap.
- val mutableAttributeStats: mutable.Map[Attribute, ColumnStat] =
- mutableColStats.map(kv => expridToAttrMap(kv._1) -> kv._2)
- val newColStats = AttributeMap(mutableAttributeStats.toSeq)
+ val newColStats = colStatsMap.toColumnStats
val filteredRowCount: BigInt =
- EstimationUtils.ceil(BigDecimal(stats.rowCount.get) * filterSelectivity)
- val filteredSizeInBytes =
+ EstimationUtils.ceil(BigDecimal(childStats.rowCount.get) * filterSelectivity)
+ val filteredSizeInBytes: BigInt =
EstimationUtils.getOutputSize(plan.output, filteredRowCount, newColStats)
- Some(stats.copy(sizeInBytes = filteredSizeInBytes, rowCount = Some(filteredRowCount),
+ Some(childStats.copy(sizeInBytes = filteredSizeInBytes, rowCount = Some(filteredRowCount),
attributeStats = newColStats))
}
@@ -95,15 +84,16 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging {
* @param condition the compound logical expression
* @param update a boolean flag to specify if we need to update ColumnStat of a column
* for subsequent conditions
- * @return a double value to show the percentage of rows meeting a given condition.
+ * @return an optional double value to show the percentage of rows meeting a given condition.
* It returns None if the condition is not supported.
*/
def calculateFilterSelectivity(condition: Expression, update: Boolean = true): Option[Double] = {
-
condition match {
case And(cond1, cond2) =>
- (calculateFilterSelectivity(cond1, update), calculateFilterSelectivity(cond2, update))
- match {
+ // For ease of debugging, we compute percent1 and percent2 in 2 statements.
+ val percent1 = calculateFilterSelectivity(cond1, update)
+ val percent2 = calculateFilterSelectivity(cond2, update)
+ (percent1, percent2) match {
case (Some(p1), Some(p2)) => Some(p1 * p2)
case (Some(p1), None) => Some(p1)
case (None, Some(p2)) => Some(p2)
@@ -127,8 +117,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging {
case None => None
}
- case _ =>
- calculateSingleCondition(condition, update)
+ case _ => calculateSingleCondition(condition, update)
}
}
@@ -140,7 +129,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging {
* @param condition a single logical expression
* @param update a boolean flag to specify if we need to update ColumnStat of a column
* for subsequent conditions
- * @return Option[Double] value to show the percentage of rows meeting a given condition.
+ * @return an optional double value to show the percentage of rows meeting a given condition.
* It returns None if the condition is not supported.
*/
def calculateSingleCondition(condition: Expression, update: Boolean): Option[Double] = {
@@ -148,33 +137,33 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging {
    // For the evaluateBinary method, we assume the literal is on the right side of the operator.
    // So we swap the operand order if it is not.
- // EqualTo does not care about the order
- case op @ EqualTo(ar: AttributeReference, l: Literal) =>
- evaluateBinary(op, ar, l, update)
- case op @ EqualTo(l: Literal, ar: AttributeReference) =>
- evaluateBinary(op, ar, l, update)
+ // EqualTo/EqualNullSafe does not care about the order
+ case op @ Equality(ar: Attribute, l: Literal) =>
+ evaluateEquality(ar, l, update)
+ case op @ Equality(l: Literal, ar: Attribute) =>
+ evaluateEquality(ar, l, update)
- case op @ LessThan(ar: AttributeReference, l: Literal) =>
+ case op @ LessThan(ar: Attribute, l: Literal) =>
evaluateBinary(op, ar, l, update)
- case op @ LessThan(l: Literal, ar: AttributeReference) =>
+ case op @ LessThan(l: Literal, ar: Attribute) =>
evaluateBinary(GreaterThan(ar, l), ar, l, update)
- case op @ LessThanOrEqual(ar: AttributeReference, l: Literal) =>
+ case op @ LessThanOrEqual(ar: Attribute, l: Literal) =>
evaluateBinary(op, ar, l, update)
- case op @ LessThanOrEqual(l: Literal, ar: AttributeReference) =>
+ case op @ LessThanOrEqual(l: Literal, ar: Attribute) =>
evaluateBinary(GreaterThanOrEqual(ar, l), ar, l, update)
- case op @ GreaterThan(ar: AttributeReference, l: Literal) =>
+ case op @ GreaterThan(ar: Attribute, l: Literal) =>
evaluateBinary(op, ar, l, update)
- case op @ GreaterThan(l: Literal, ar: AttributeReference) =>
+ case op @ GreaterThan(l: Literal, ar: Attribute) =>
evaluateBinary(LessThan(ar, l), ar, l, update)
- case op @ GreaterThanOrEqual(ar: AttributeReference, l: Literal) =>
+ case op @ GreaterThanOrEqual(ar: Attribute, l: Literal) =>
evaluateBinary(op, ar, l, update)
- case op @ GreaterThanOrEqual(l: Literal, ar: AttributeReference) =>
+ case op @ GreaterThanOrEqual(l: Literal, ar: Attribute) =>
evaluateBinary(LessThanOrEqual(ar, l), ar, l, update)
- case In(ar: AttributeReference, expList)
+ case In(ar: Attribute, expList)
if expList.forall(e => e.isInstanceOf[Literal]) =>
        // Expression [In (value, seq[Literal])] will be replaced with the optimized version
        // [InSet (value, HashSet[Literal])] in the Optimizer, but only when list.size > 10.
@@ -182,14 +171,14 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging {
val hSet = expList.map(e => e.eval())
evaluateInSet(ar, HashSet() ++ hSet, update)
- case InSet(ar: AttributeReference, set) =>
+ case InSet(ar: Attribute, set) =>
evaluateInSet(ar, set, update)
- case IsNull(ar: AttributeReference) =>
- evaluateIsNull(ar, isNull = true, update)
+ case IsNull(ar: Attribute) =>
+ evaluateNullCheck(ar, isNull = true, update)
- case IsNotNull(ar: AttributeReference) =>
- evaluateIsNull(ar, isNull = false, update)
+ case IsNotNull(ar: Attribute) =>
+ evaluateNullCheck(ar, isNull = false, update)
case _ =>
// TODO: it's difficult to support string operators without advanced statistics.
@@ -203,44 +192,43 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging {
/**
* Returns a percentage of rows meeting "IS NULL" or "IS NOT NULL" condition.
*
- * @param attrRef an AttributeReference (or a column)
+ * @param attr an Attribute (or a column)
* @param isNull set to true for "IS NULL" condition. set to false for "IS NOT NULL" condition
* @param update a boolean flag to specify if we need to update ColumnStat of a given column
* for subsequent conditions
* @return an optional double value to show the percentage of rows meeting a given condition
* It returns None if no statistics collected for a given column.
*/
- def evaluateIsNull(
- attrRef: AttributeReference,
+ def evaluateNullCheck(
+ attr: Attribute,
isNull: Boolean,
- update: Boolean)
- : Option[Double] = {
- if (!mutableColStats.contains(attrRef.exprId)) {
- logDebug("[CBO] No statistics for " + attrRef)
+ update: Boolean): Option[Double] = {
+ if (!colStatsMap.contains(attr)) {
+ logDebug("[CBO] No statistics for " + attr)
return None
}
- val aColStat = mutableColStats(attrRef.exprId)
- val rowCountValue = plan.child.stats(catalystConf).rowCount.get
- val nullPercent: BigDecimal =
- if (rowCountValue == 0) 0.0
- else BigDecimal(aColStat.nullCount) / BigDecimal(rowCountValue)
+ val colStat = colStatsMap(attr)
+ val rowCountValue = childStats.rowCount.get
+ val nullPercent: BigDecimal = if (rowCountValue == 0) {
+ 0
+ } else {
+ BigDecimal(colStat.nullCount) / BigDecimal(rowCountValue)
+ }
if (update) {
- val newStats =
- if (isNull) aColStat.copy(distinctCount = 0, min = None, max = None)
- else aColStat.copy(nullCount = 0)
-
- mutableColStats += (attrRef.exprId -> newStats)
+ val newStats = if (isNull) {
+ colStat.copy(distinctCount = 0, min = None, max = None)
+ } else {
+ colStat.copy(nullCount = 0)
+ }
+ colStatsMap(attr) = newStats
}
- val percent =
- if (isNull) {
- nullPercent.toDouble
- }
- else {
- /** ISNOTNULL(column) */
- 1.0 - nullPercent.toDouble
- }
+ val percent = if (isNull) {
+ nullPercent.toDouble
+ } else {
+ 1.0 - nullPercent.toDouble
+ }
Some(percent)
}
@@ -249,7 +237,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging {
* Returns a percentage of rows meeting a binary comparison expression.
*
   * @param op a binary comparison operator such as =, <, <=, >, >=
- * @param attrRef an AttributeReference (or a column)
+ * @param attr an Attribute (or a column)
* @param literal a literal value (or constant)
* @param update a boolean flag to specify if we need to update ColumnStat of a given column
* for subsequent conditions
@@ -258,27 +246,20 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging {
*/
def evaluateBinary(
op: BinaryComparison,
- attrRef: AttributeReference,
+ attr: Attribute,
literal: Literal,
- update: Boolean)
- : Option[Double] = {
- if (!mutableColStats.contains(attrRef.exprId)) {
- logDebug("[CBO] No statistics for " + attrRef)
- return None
- }
-
- op match {
- case EqualTo(l, r) => evaluateEqualTo(attrRef, literal, update)
+ update: Boolean): Option[Double] = {
+ attr.dataType match {
+ case _: NumericType | DateType | TimestampType =>
+ evaluateBinaryForNumeric(op, attr, literal, update)
+ case StringType | BinaryType =>
+ // TODO: It is difficult to support other binary comparisons for String/Binary
+ // type without min/max and advanced statistics like histogram.
+ logDebug("[CBO] No range comparison statistics for String/Binary type " + attr)
+ None
case _ =>
- attrRef.dataType match {
- case _: NumericType | DateType | TimestampType =>
- evaluateBinaryForNumeric(op, attrRef, literal, update)
- case StringType | BinaryType =>
- // TODO: It is difficult to support other binary comparisons for String/Binary
- // type without min/max and advanced statistics like histogram.
- logDebug("[CBO] No range comparison statistics for String/Binary type " + attrRef)
- None
- }
+ // TODO: support boolean type.
+ None
}
}
@@ -297,6 +278,8 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging {
Some(DateTimeUtils.toJavaDate(litValue.toString.toInt))
case TimestampType =>
Some(DateTimeUtils.toJavaTimestamp(litValue.toString.toLong))
+ case _: DecimalType =>
+ Some(litValue.asInstanceOf[Decimal].toJavaBigDecimal)
case StringType | BinaryType =>
None
case _ =>
@@ -308,37 +291,36 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging {
* Returns a percentage of rows meeting an equality (=) expression.
* This method evaluates the equality predicate for all data types.
*
- * @param attrRef an AttributeReference (or a column)
+ * @param attr an Attribute (or a column)
* @param literal a literal value (or constant)
* @param update a boolean flag to specify if we need to update ColumnStat of a given column
* for subsequent conditions
* @return an optional double value to show the percentage of rows meeting a given condition
*/
- def evaluateEqualTo(
- attrRef: AttributeReference,
+ def evaluateEquality(
+ attr: Attribute,
literal: Literal,
- update: Boolean)
- : Option[Double] = {
-
- val aColStat = mutableColStats(attrRef.exprId)
- val ndv = aColStat.distinctCount
+ update: Boolean): Option[Double] = {
+ if (!colStatsMap.contains(attr)) {
+ logDebug("[CBO] No statistics for " + attr)
+ return None
+ }
+ val colStat = colStatsMap(attr)
+ val ndv = colStat.distinctCount
// decide if the value is in [min, max] of the column.
// We currently don't store min/max for binary/string type.
    // Hence, we assume it is within bounds for binary/string type.
- val statsRange = Range(aColStat.min, aColStat.max, attrRef.dataType)
- val inBoundary: Boolean = Range.rangeContainsLiteral(statsRange, literal)
-
- if (inBoundary) {
-
+ val statsRange = Range(colStat.min, colStat.max, attr.dataType)
+ if (statsRange.contains(literal)) {
if (update) {
        // We update the ColumnStat structure after applying this equality predicate.
// Set distinctCount to 1. Set nullCount to 0.
// Need to save new min/max using the external type value of the literal
- val newValue = convertBoundValue(attrRef.dataType, literal.value)
- val newStats = aColStat.copy(distinctCount = 1, min = newValue,
+ val newValue = convertBoundValue(attr.dataType, literal.value)
+ val newStats = colStat.copy(distinctCount = 1, min = newValue,
max = newValue, nullCount = 0)
- mutableColStats += (attrRef.exprId -> newStats)
+ colStatsMap(attr) = newStats
}
Some(1.0 / ndv.toDouble)
@@ -352,7 +334,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging {
* Returns a percentage of rows meeting "IN" operator expression.
* This method evaluates the equality predicate for all data types.
*
- * @param attrRef an AttributeReference (or a column)
+ * @param attr an Attribute (or a column)
* @param hSet a set of literal values
* @param update a boolean flag to specify if we need to update ColumnStat of a given column
* for subsequent conditions
@@ -361,57 +343,52 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging {
*/
def evaluateInSet(
- attrRef: AttributeReference,
+ attr: Attribute,
hSet: Set[Any],
- update: Boolean)
- : Option[Double] = {
- if (!mutableColStats.contains(attrRef.exprId)) {
- logDebug("[CBO] No statistics for " + attrRef)
+ update: Boolean): Option[Double] = {
+ if (!colStatsMap.contains(attr)) {
+ logDebug("[CBO] No statistics for " + attr)
return None
}
- val aColStat = mutableColStats(attrRef.exprId)
- val ndv = aColStat.distinctCount
- val aType = attrRef.dataType
- var newNdv: Long = 0
+ val colStat = colStatsMap(attr)
+ val ndv = colStat.distinctCount
+ val dataType = attr.dataType
+ var newNdv = ndv
// use [min, max] to filter the original hSet
- aType match {
- case _: NumericType | DateType | TimestampType =>
- val statsRange =
- Range(aColStat.min, aColStat.max, aType).asInstanceOf[NumericRange]
-
- // To facilitate finding the min and max values in hSet, we map hSet values to BigDecimal.
- // Using hSetBigdec, we can find the min and max values quickly in the ordered hSetBigdec.
- val hSetBigdec = hSet.map(e => BigDecimal(e.toString))
- val validQuerySet = hSetBigdec.filter(e => e >= statsRange.min && e <= statsRange.max)
- // We use hSetBigdecToAnyMap to help us find the original hSet value.
- val hSetBigdecToAnyMap: Map[BigDecimal, Any] =
- hSet.map(e => BigDecimal(e.toString) -> e).toMap
+ dataType match {
+ case _: NumericType | BooleanType | DateType | TimestampType =>
+ val statsRange = Range(colStat.min, colStat.max, dataType).asInstanceOf[NumericRange]
+ val validQuerySet = hSet.filter { v =>
+ v != null && statsRange.contains(Literal(v, dataType))
+ }
if (validQuerySet.isEmpty) {
return Some(0.0)
}
// Need to save new min/max using the external type value of the literal
- val newMax = convertBoundValue(attrRef.dataType, hSetBigdecToAnyMap(validQuerySet.max))
- val newMin = convertBoundValue(attrRef.dataType, hSetBigdecToAnyMap(validQuerySet.min))
+ val newMax = convertBoundValue(
+ attr.dataType, validQuerySet.maxBy(v => BigDecimal(v.toString)))
+ val newMin = convertBoundValue(
+ attr.dataType, validQuerySet.minBy(v => BigDecimal(v.toString)))
        // newNdv should not be greater than the old ndv. For example, a column has only 2 values,
        // 1 and 6. For the predicate column IN (1, 2, 3, 4, 5), validQuerySet.size is 5.
- newNdv = math.min(validQuerySet.size.toLong, ndv.longValue())
+ newNdv = ndv.min(BigInt(validQuerySet.size))
if (update) {
- val newStats = aColStat.copy(distinctCount = newNdv, min = newMin,
+ val newStats = colStat.copy(distinctCount = newNdv, min = newMin,
max = newMax, nullCount = 0)
- mutableColStats += (attrRef.exprId -> newStats)
+ colStatsMap(attr) = newStats
}
      // We assume the whole set is in range since there is no min/max information for String/Binary type
case StringType | BinaryType =>
- newNdv = math.min(hSet.size.toLong, ndv.longValue())
+ newNdv = ndv.min(BigInt(hSet.size))
if (update) {
- val newStats = aColStat.copy(distinctCount = newNdv, nullCount = 0)
- mutableColStats += (attrRef.exprId -> newStats)
+ val newStats = colStat.copy(distinctCount = newNdv, nullCount = 0)
+ colStatsMap(attr) = newStats
}
}
@@ -425,7 +402,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging {
   * This method evaluates the expression for numeric columns only.
*
   * @param op a binary comparison operator such as =, <, <=, >, >=
- * @param attrRef an AttributeReference (or a column)
+ * @param attr an Attribute (or a column)
* @param literal a literal value (or constant)
* @param update a boolean flag to specify if we need to update ColumnStat of a given column
* for subsequent conditions
@@ -433,16 +410,14 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging {
*/
def evaluateBinaryForNumeric(
op: BinaryComparison,
- attrRef: AttributeReference,
+ attr: Attribute,
literal: Literal,
- update: Boolean)
- : Option[Double] = {
+ update: Boolean): Option[Double] = {
var percent = 1.0
- val aColStat = mutableColStats(attrRef.exprId)
- val ndv = aColStat.distinctCount
+ val colStat = colStatsMap(attr)
val statsRange =
- Range(aColStat.min, aColStat.max, attrRef.dataType).asInstanceOf[NumericRange]
+ Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange]
    // determine the degree of overlap between the predicate range and the column's range
val literalValueBD = BigDecimal(literal.value.toString)
@@ -463,33 +438,37 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging {
percent = 1.0
} else {
// this is partial overlap case
- var newMax = aColStat.max
- var newMin = aColStat.min
- var newNdv = ndv
- val literalToDouble = literalValueBD.toDouble
- val maxToDouble = BigDecimal(statsRange.max).toDouble
- val minToDouble = BigDecimal(statsRange.min).toDouble
+ val literalDouble = literalValueBD.toDouble
+ val maxDouble = BigDecimal(statsRange.max).toDouble
+ val minDouble = BigDecimal(statsRange.min).toDouble
// Without advanced statistics like histogram, we assume uniform data distribution.
// We just prorate the adjusted range over the initial range to compute filter selectivity.
// For ease of computation, we convert all relevant numeric values to Double.
percent = op match {
case _: LessThan =>
- (literalToDouble - minToDouble) / (maxToDouble - minToDouble)
+ (literalDouble - minDouble) / (maxDouble - minDouble)
case _: LessThanOrEqual =>
- if (literalValueBD == BigDecimal(statsRange.min)) 1.0 / ndv.toDouble
- else (literalToDouble - minToDouble) / (maxToDouble - minToDouble)
+ if (literalValueBD == BigDecimal(statsRange.min)) {
+ 1.0 / colStat.distinctCount.toDouble
+ } else {
+ (literalDouble - minDouble) / (maxDouble - minDouble)
+ }
case _: GreaterThan =>
- (maxToDouble - literalToDouble) / (maxToDouble - minToDouble)
+ (maxDouble - literalDouble) / (maxDouble - minDouble)
case _: GreaterThanOrEqual =>
- if (literalValueBD == BigDecimal(statsRange.max)) 1.0 / ndv.toDouble
- else (maxToDouble - literalToDouble) / (maxToDouble - minToDouble)
+ if (literalValueBD == BigDecimal(statsRange.max)) {
+ 1.0 / colStat.distinctCount.toDouble
+ } else {
+ (maxDouble - literalDouble) / (maxDouble - minDouble)
+ }
}
- // Need to save new min/max using the external type value of the literal
- val newValue = convertBoundValue(attrRef.dataType, literal.value)
-
if (update) {
+ // Need to save new min/max using the external type value of the literal
+ val newValue = convertBoundValue(attr.dataType, literal.value)
+ var newMax = colStat.max
+ var newMin = colStat.min
op match {
case _: GreaterThan => newMin = newValue
case _: GreaterThanOrEqual => newMin = newValue
@@ -497,11 +476,11 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging {
case _: LessThanOrEqual => newMax = newValue
}
- newNdv = math.max(math.round(ndv.toDouble * percent), 1)
- val newStats = aColStat.copy(distinctCount = newNdv, min = newMin,
+ val newNdv = math.max(math.round(colStat.distinctCount.toDouble * percent), 1)
+ val newStats = colStat.copy(distinctCount = newNdv, min = newMin,
max = newMax, nullCount = 0)
- mutableColStats += (attrRef.exprId -> newStats)
+ colStatsMap(attr) = newStats
}
}
@@ -509,3 +488,20 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging {
}
}
+
+class ColumnStatsMap {
+ private val baseMap: mutable.Map[ExprId, (Attribute, ColumnStat)] = mutable.HashMap.empty
+
+ def setInitValues(colStats: AttributeMap[ColumnStat]): Unit = {
+ baseMap.clear()
+ baseMap ++= colStats.baseMap
+ }
+
+ def contains(a: Attribute): Boolean = baseMap.contains(a.exprId)
+
+ def apply(a: Attribute): ColumnStat = baseMap(a.exprId)._2
+
+ def update(a: Attribute, stats: ColumnStat): Unit = baseMap.update(a.exprId, a -> stats)
+
+ def toColumnStats: AttributeMap[ColumnStat] = AttributeMap(baseMap.values.toSeq)
+}
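A hedged usage sketch of the new ColumnStatsMap (with toy stand-ins for ExprId, Attribute, and ColumnStat): estimate() seeds it once from the child's statistics, each predicate handler reads and updates entries keyed by ExprId, and the example from the class comment (c in [0, 100] under c > 40 AND c <= 50) narrows the range in two steps:

import scala.collection.mutable

object ColumnStatsMapDemo extends App {
  case class ExprId(id: Long)
  case class Attribute(name: String, exprId: ExprId)
  case class ColumnStat(distinctCount: BigInt, min: Option[Any], max: Option[Any])

  // Simplified stand-in for the ColumnStatsMap added in this patch.
  class ColumnStatsMap {
    private val baseMap = mutable.HashMap.empty[ExprId, (Attribute, ColumnStat)]
    def setInitValues(colStats: Map[ExprId, (Attribute, ColumnStat)]): Unit = {
      baseMap.clear(); baseMap ++= colStats
    }
    def contains(a: Attribute): Boolean = baseMap.contains(a.exprId)
    def apply(a: Attribute): ColumnStat = baseMap(a.exprId)._2
    def update(a: Attribute, stats: ColumnStat): Unit = baseMap.update(a.exprId, a -> stats)
  }

  // Walk through the class-comment example: column c has [min, max] = [0, 100].
  val c = Attribute("c", ExprId(1))
  val m = new ColumnStatsMap
  m.setInitValues(Map(c.exprId -> (c, ColumnStat(100, Some(0), Some(100)))))
  m(c) = m(c).copy(min = Some(40))  // after evaluating c > 40  -> [40, 100]
  m(c) = m(c).copy(max = Some(50))  // after evaluating c <= 50 -> [40, 50]
  assert(m(c) == ColumnStat(100, Some(40), Some(50)))
}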
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
index 982a5a8bb8..9782c0bb0a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
@@ -59,7 +59,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
case _ if !rowCountsExist(conf, join.left, join.right) =>
None
- case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) =>
+ case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _) =>
// 1. Compute join selectivity
val joinKeyPairs = extractJoinKeysWithColStats(leftKeys, rightKeys)
val selectivity = joinSelectivity(joinKeyPairs)
@@ -94,9 +94,9 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
val outputStats: Seq[(Attribute, ColumnStat)] = if (outputRows == 0) {
// The output is empty, we don't need to keep column stats.
Nil
- } else if (innerJoinedRows == 0) {
+ } else if (selectivity == 0) {
joinType match {
- // For outer joins, if the inner join part is empty, the number of output rows is the
+ // For outer joins, if the join selectivity is 0, the number of output rows is the
// same as that of the outer side. And column stats of join keys from the outer side
// keep unchanged, while column stats of join keys from the other side should be updated
// based on added null values.
@@ -116,6 +116,9 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
}
case _ => Nil
}
+ } else if (selectivity == 1) {
+ // Cartesian product, just propagate the original column stats
+ inputAttrStats.toSeq
} else {
val joinKeyStats = getIntersectedStats(joinKeyPairs)
join.joinType match {
@@ -138,8 +141,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
Some(Statistics(
sizeInBytes = getOutputSize(join.output, outputRows, outputAttrStats),
rowCount = Some(outputRows),
- attributeStats = outputAttrStats,
- isBroadcastable = false))
+ attributeStats = outputAttrStats))
case _ =>
// When there is no equi-join condition, we do estimation like cartesian product.
@@ -150,8 +152,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
Some(Statistics(
sizeInBytes = getOutputSize(join.output, outputRows, inputAttrStats),
rowCount = Some(outputRows),
- attributeStats = inputAttrStats,
- isBroadcastable = false))
+ attributeStats = inputAttrStats))
}
// scalastyle:off
@@ -189,8 +190,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
}
if (ndvDenom < 0) {
- // There isn't join keys or column stats for any of the join key pairs, we do estimation like
- // cartesian product.
+ // We can't find any join key pairs with column stats, estimate it as cartesian join.
1
} else if (ndvDenom == 0) {
// One of the join key pairs is disjoint, thus the two sides of join is disjoint.
@@ -202,9 +202,6 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
/**
* Propagate or update column stats for output attributes.
- * 1. For cartesian product, all values are preserved, so there's no need to change column stats.
- * 2. For other cases, a) update max/min of join keys based on their intersected range. b) update
- * distinct count of other attributes based on output rows after join.
*/
private def updateAttrStats(
outputRows: BigInt,
@@ -214,35 +211,38 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
val outputAttrStats = new ArrayBuffer[(Attribute, ColumnStat)]()
val leftRows = leftStats.rowCount.get
val rightRows = rightStats.rowCount.get
- if (outputRows == leftRows * rightRows) {
- // Cartesian product, just propagate the original column stats
- attributes.foreach(a => outputAttrStats += a -> oldAttrStats(a))
- } else {
- val leftRatio =
- if (leftRows != 0) BigDecimal(outputRows) / BigDecimal(leftRows) else BigDecimal(0)
- val rightRatio =
- if (rightRows != 0) BigDecimal(outputRows) / BigDecimal(rightRows) else BigDecimal(0)
- attributes.foreach { a =>
- // check if this attribute is a join key
- if (joinKeyStats.contains(a)) {
- outputAttrStats += a -> joinKeyStats(a)
+
+ attributes.foreach { a =>
+ // check if this attribute is a join key
+ if (joinKeyStats.contains(a)) {
+ outputAttrStats += a -> joinKeyStats(a)
+ } else {
+ val leftRatio = if (leftRows != 0) {
+ BigDecimal(outputRows) / BigDecimal(leftRows)
+ } else {
+ BigDecimal(0)
+ }
+ val rightRatio = if (rightRows != 0) {
+ BigDecimal(outputRows) / BigDecimal(rightRows)
} else {
- val oldColStat = oldAttrStats(a)
- val oldNdv = oldColStat.distinctCount
- // We only change (scale down) the number of distinct values if the number of rows
- // decreases after join, because join won't produce new values even if the number of
- // rows increases.
- val newNdv = if (join.left.outputSet.contains(a) && leftRatio < 1) {
- ceil(BigDecimal(oldNdv) * leftRatio)
- } else if (join.right.outputSet.contains(a) && rightRatio < 1) {
- ceil(BigDecimal(oldNdv) * rightRatio)
- } else {
- oldNdv
- }
- // TODO: support nullCount updates for specific outer joins
- outputAttrStats += a -> oldColStat.copy(distinctCount = newNdv)
+ BigDecimal(0)
}
+ val oldColStat = oldAttrStats(a)
+ val oldNdv = oldColStat.distinctCount
+ // We only change (scale down) the number of distinct values if the number of rows
+ // decreases after join, because join won't produce new values even if the number of
+ // rows increases.
+ val newNdv = if (join.left.outputSet.contains(a) && leftRatio < 1) {
+ ceil(BigDecimal(oldNdv) * leftRatio)
+ } else if (join.right.outputSet.contains(a) && rightRatio < 1) {
+ ceil(BigDecimal(oldNdv) * rightRatio)
+ } else {
+ oldNdv
+ }
+ // TODO: support nullCount updates for specific outer joins
+ outputAttrStats += a -> oldColStat.copy(distinctCount = newNdv)
}
+
}
outputAttrStats
}
@@ -263,12 +263,14 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging
// Update intersected column stats
assert(leftKey.dataType.sameType(rightKey.dataType))
- val minNdv = leftKeyStats.distinctCount.min(rightKeyStats.distinctCount)
+ val newNdv = leftKeyStats.distinctCount.min(rightKeyStats.distinctCount)
val (newMin, newMax) = Range.intersect(lRange, rRange, leftKey.dataType)
- intersectedStats.put(leftKey,
- leftKeyStats.copy(distinctCount = minNdv, min = newMin, max = newMax, nullCount = 0))
- intersectedStats.put(rightKey,
- rightKeyStats.copy(distinctCount = minNdv, min = newMin, max = newMax, nullCount = 0))
+ val newMaxLen = math.min(leftKeyStats.maxLen, rightKeyStats.maxLen)
+ val newAvgLen = (leftKeyStats.avgLen + rightKeyStats.avgLen) / 2
+ val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen)
+
+ intersectedStats.put(leftKey, newStats)
+ intersectedStats.put(rightKey, newStats)
}
AttributeMap(intersectedStats.toSeq)
}
@@ -298,8 +300,7 @@ case class LeftSemiAntiEstimation(conf: CatalystConf, join: Join) {
Some(Statistics(
sizeInBytes = getOutputSize(join.output, outputRows, leftStats.attributeStats),
rowCount = Some(outputRows),
- attributeStats = leftStats.attributeStats,
- isBroadcastable = false))
+ attributeStats = leftStats.attributeStats))
} else {
None
}
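The refactored updateAttrStats keeps one behavioral rule worth spelling out: for non-key attributes, the distinct count only scales down when the join reduces the row count, never up, because a join produces no new values. A worked sketch of that rule (the ceil helper below mimics EstimationUtils.ceil; it is an illustrative stand-in, not the Spark implementation):

object NdvScalingDemo extends App {
  def ceil(d: BigDecimal): BigInt = d.setScale(0, BigDecimal.RoundingMode.CEILING).toBigInt

  def scaledNdv(oldNdv: BigInt, outputRows: BigInt, inputRows: BigInt): BigInt = {
    val ratio =
      if (inputRows != 0) BigDecimal(outputRows) / BigDecimal(inputRows) else BigDecimal(0)
    // Only scale down; a join cannot create values that were not in the input.
    if (ratio < 1) ceil(BigDecimal(oldNdv) * ratio) else oldNdv
  }

  println(scaledNdv(oldNdv = 50, outputRows = 30, inputRows = 100))   // 15: rows shrank
  println(scaledNdv(oldNdv = 50, outputRows = 300, inputRows = 100))  // 50: rows grew, ndv capped
}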
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala
index 4557114532..3d13967cb6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala
@@ -26,19 +26,33 @@ import org.apache.spark.sql.types.{BooleanType, DateType, TimestampType, _}
/** Value range of a column. */
-trait Range
+trait Range {
+ def contains(l: Literal): Boolean
+}
/** For simplicity we use decimal to unify operations of numeric ranges. */
-case class NumericRange(min: JDecimal, max: JDecimal) extends Range
+case class NumericRange(min: JDecimal, max: JDecimal) extends Range {
+ override def contains(l: Literal): Boolean = {
+ val decimal = l.dataType match {
+ case BooleanType => if (l.value.asInstanceOf[Boolean]) new JDecimal(1) else new JDecimal(0)
+ case _ => new JDecimal(l.value.toString)
+ }
+ min.compareTo(decimal) <= 0 && max.compareTo(decimal) >= 0
+ }
+}
/**
 * This version of Spark does not have min/max for binary/string types, so we define their default
 * behaviors with this class.
*/
-class DefaultRange extends Range
+class DefaultRange extends Range {
+ override def contains(l: Literal): Boolean = true
+}
/** This is for columns with only null values. */
-class NullRange extends Range
+class NullRange extends Range {
+ override def contains(l: Literal): Boolean = false
+}
object Range {
def apply(min: Option[Any], max: Option[Any], dataType: DataType): Range = dataType match {
@@ -58,20 +72,6 @@ object Range {
n1.min.compareTo(n2.max) <= 0 && n1.max.compareTo(n2.min) >= 0
}
- def rangeContainsLiteral(r: Range, lit: Literal): Boolean = r match {
- case _: DefaultRange => true
- case _: NullRange => false
- case n: NumericRange =>
- val literalValue = if (lit.dataType.isInstanceOf[BooleanType]) {
- if (lit.value.asInstanceOf[Boolean]) new JDecimal(1) else new JDecimal(0)
- } else {
- assert(lit.dataType.isInstanceOf[NumericType] || lit.dataType.isInstanceOf[DateType] ||
- lit.dataType.isInstanceOf[TimestampType])
- new JDecimal(lit.value.toString)
- }
- n.min.compareTo(literalValue) <= 0 && n.max.compareTo(literalValue) >= 0
- }
-
/**
* Intersected results of two ranges. This is only for two overlapped ranges.
* The outputs are the intersected min/max values.
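The Range change replaces the standalone rangeContainsLiteral helper with a contains method on each Range subclass, so call sites read statsRange.contains(literal). A minimal sketch of the same shape with toy types (Spark's Literal carries a DataType; here a flag stands in for the boolean special case):

object RangeDemo extends App {
  import java.math.{BigDecimal => JDecimal}

  case class Literal(value: Any, isBoolean: Boolean = false)

  sealed trait Range { def contains(l: Literal): Boolean }

  // Numeric ranges unify all comparisons through decimal, as in the patch.
  case class NumericRange(min: JDecimal, max: JDecimal) extends Range {
    override def contains(l: Literal): Boolean = {
      val d =
        if (l.isBoolean) new JDecimal(if (l.value.asInstanceOf[Boolean]) 1 else 0)
        else new JDecimal(l.value.toString)
      min.compareTo(d) <= 0 && max.compareTo(d) >= 0
    }
  }
  class DefaultRange extends Range { override def contains(l: Literal) = true }  // no min/max stats
  class NullRange extends Range { override def contains(l: Literal) = false }    // all-null column

  val r = NumericRange(new JDecimal(1), new JDecimal(10))
  println(r.contains(Literal(5)))                      // true
  println(r.contains(Literal(42)))                     // false
  println(r.contains(Literal(true, isBoolean = true))) // true: booleans map to 0/1
}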
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
index f5e306f9e5..8be74ced7b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.sql.catalyst.statsEstimation
-import java.sql.{Date, Timestamp}
+import java.sql.Date
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
/**
@@ -38,6 +37,11 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
val childColStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
nullCount = 0, avgLen = 4, maxLen = 4)
+ // only 2 values
+ val arBool = AttributeReference("cbool", BooleanType)()
+ val childColStatBool = ColumnStat(distinctCount = 2, min = Some(false), max = Some(true),
+ nullCount = 0, avgLen = 1, maxLen = 1)
+
// Second column cdate has 10 values from 2017-01-01 through 2017-01-10.
val dMin = Date.valueOf("2017-01-01")
val dMax = Date.valueOf("2017-01-10")
@@ -45,14 +49,6 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
val childColStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax),
nullCount = 0, avgLen = 4, maxLen = 4)
- // Third column ctimestamp has 10 values from "2017-01-01 01:00:00" through
- // "2017-01-01 10:00:00" for 10 distinct timestamps (or hours).
- val tsMin = Timestamp.valueOf("2017-01-01 01:00:00")
- val tsMax = Timestamp.valueOf("2017-01-01 10:00:00")
- val arTimestamp = AttributeReference("ctimestamp", TimestampType)()
- val childColStatTimestamp = ColumnStat(distinctCount = 10, min = Some(tsMin), max = Some(tsMax),
- nullCount = 0, avgLen = 8, maxLen = 8)
-
// Fourth column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20.
val decMin = new java.math.BigDecimal("0.200000000000000000")
val decMax = new java.math.BigDecimal("0.800000000000000000")
@@ -77,8 +73,16 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(EqualTo(arInt, Literal(2)), childStatsTestPlan(Seq(arInt), 10L)),
ColumnStat(distinctCount = 1, min = Some(2), max = Some(2),
nullCount = 0, avgLen = 4, maxLen = 4),
- Some(1L)
- )
+ 1)
+ }
+
+ test("cint <=> 2") {
+ validateEstimatedStats(
+ arInt,
+ Filter(EqualNullSafe(arInt, Literal(2)), childStatsTestPlan(Seq(arInt), 10L)),
+ ColumnStat(distinctCount = 1, min = Some(2), max = Some(2),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ 1)
}
test("cint = 0") {
@@ -88,8 +92,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(EqualTo(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)),
ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
nullCount = 0, avgLen = 4, maxLen = 4),
- Some(0L)
- )
+ 0)
}
test("cint < 3") {
@@ -98,8 +101,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(LessThan(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)),
ColumnStat(distinctCount = 2, min = Some(1), max = Some(3),
nullCount = 0, avgLen = 4, maxLen = 4),
- Some(3L)
- )
+ 3)
}
test("cint < 0") {
@@ -109,8 +111,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(LessThan(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)),
ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
nullCount = 0, avgLen = 4, maxLen = 4),
- Some(0L)
- )
+ 0)
}
test("cint <= 3") {
@@ -119,8 +120,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(LessThanOrEqual(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)),
ColumnStat(distinctCount = 2, min = Some(1), max = Some(3),
nullCount = 0, avgLen = 4, maxLen = 4),
- Some(3L)
- )
+ 3)
}
test("cint > 6") {
@@ -129,8 +129,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(GreaterThan(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)),
ColumnStat(distinctCount = 4, min = Some(6), max = Some(10),
nullCount = 0, avgLen = 4, maxLen = 4),
- Some(5L)
- )
+ 5)
}
test("cint > 10") {
@@ -140,8 +139,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(GreaterThan(arInt, Literal(10)), childStatsTestPlan(Seq(arInt), 10L)),
ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
nullCount = 0, avgLen = 4, maxLen = 4),
- Some(0L)
- )
+ 0)
}
test("cint >= 6") {
@@ -150,8 +148,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(GreaterThanOrEqual(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)),
ColumnStat(distinctCount = 4, min = Some(6), max = Some(10),
nullCount = 0, avgLen = 4, maxLen = 4),
- Some(5L)
- )
+ 5)
}
test("cint IS NULL") {
@@ -160,8 +157,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(IsNull(arInt), childStatsTestPlan(Seq(arInt), 10L)),
ColumnStat(distinctCount = 0, min = None, max = None,
nullCount = 0, avgLen = 4, maxLen = 4),
- Some(0L)
- )
+ 0)
}
test("cint IS NOT NULL") {
@@ -170,8 +166,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(IsNotNull(arInt), childStatsTestPlan(Seq(arInt), 10L)),
ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
nullCount = 0, avgLen = 4, maxLen = 4),
- Some(10L)
- )
+ 10)
}
test("cint > 3 AND cint <= 6") {
@@ -181,8 +176,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(condition, childStatsTestPlan(Seq(arInt), 10L)),
ColumnStat(distinctCount = 3, min = Some(3), max = Some(6),
nullCount = 0, avgLen = 4, maxLen = 4),
- Some(4L)
- )
+ 4)
}
test("cint = 3 OR cint = 6") {
@@ -192,8 +186,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(condition, childStatsTestPlan(Seq(arInt), 10L)),
ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
nullCount = 0, avgLen = 4, maxLen = 4),
- Some(2L)
- )
+ 2)
}
test("cint IN (3, 4, 5)") {
@@ -202,8 +195,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(InSet(arInt, Set(3, 4, 5)), childStatsTestPlan(Seq(arInt), 10L)),
ColumnStat(distinctCount = 3, min = Some(3), max = Some(5),
nullCount = 0, avgLen = 4, maxLen = 4),
- Some(3L)
- )
+ 3)
}
test("cint NOT IN (3, 4, 5)") {
@@ -212,8 +204,26 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(Not(InSet(arInt, Set(3, 4, 5))), childStatsTestPlan(Seq(arInt), 10L)),
ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
nullCount = 0, avgLen = 4, maxLen = 4),
- Some(7L)
- )
+ 7)
+ }
+
+ test("cbool = true") {
+ validateEstimatedStats(
+ arBool,
+ Filter(EqualTo(arBool, Literal(true)), childStatsTestPlan(Seq(arBool), 10L)),
+ ColumnStat(distinctCount = 1, min = Some(true), max = Some(true),
+ nullCount = 0, avgLen = 1, maxLen = 1),
+ 5)
+ }
+
+ test("cbool > false") {
+    // bool comparison is not supported yet, so stats remain the same.
+ validateEstimatedStats(
+ arBool,
+ Filter(GreaterThan(arBool, Literal(false)), childStatsTestPlan(Seq(arBool), 10L)),
+ ColumnStat(distinctCount = 2, min = Some(false), max = Some(true),
+ nullCount = 0, avgLen = 1, maxLen = 1),
+ 10)
}
test("cdate = cast('2017-01-02' AS DATE)") {
@@ -224,8 +234,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
childStatsTestPlan(Seq(arDate), 10L)),
ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102),
nullCount = 0, avgLen = 4, maxLen = 4),
- Some(1L)
- )
+ 1)
}
test("cdate < cast('2017-01-03' AS DATE)") {
@@ -236,8 +245,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
childStatsTestPlan(Seq(arDate), 10L)),
ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103),
nullCount = 0, avgLen = 4, maxLen = 4),
- Some(3L)
- )
+ 3)
}
test("""cdate IN ( cast('2017-01-03' AS DATE),
@@ -251,32 +259,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
childStatsTestPlan(Seq(arDate), 10L)),
ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105),
nullCount = 0, avgLen = 4, maxLen = 4),
- Some(3L)
- )
- }
-
- test("ctimestamp = cast('2017-01-01 02:00:00' AS TIMESTAMP)") {
- val ts2017010102 = Timestamp.valueOf("2017-01-01 02:00:00")
- validateEstimatedStats(
- arTimestamp,
- Filter(EqualTo(arTimestamp, Literal(ts2017010102)),
- childStatsTestPlan(Seq(arTimestamp), 10L)),
- ColumnStat(distinctCount = 1, min = Some(ts2017010102), max = Some(ts2017010102),
- nullCount = 0, avgLen = 8, maxLen = 8),
- Some(1L)
- )
- }
-
- test("ctimestamp < cast('2017-01-01 03:00:00' AS TIMESTAMP)") {
- val ts2017010103 = Timestamp.valueOf("2017-01-01 03:00:00")
- validateEstimatedStats(
- arTimestamp,
- Filter(LessThan(arTimestamp, Literal(ts2017010103)),
- childStatsTestPlan(Seq(arTimestamp), 10L)),
- ColumnStat(distinctCount = 2, min = Some(tsMin), max = Some(ts2017010103),
- nullCount = 0, avgLen = 8, maxLen = 8),
- Some(3L)
- )
+ 3)
}
test("cdecimal = 0.400000000000000000") {
@@ -287,20 +270,18 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
childStatsTestPlan(Seq(arDecimal), 4L)),
ColumnStat(distinctCount = 1, min = Some(dec_0_40), max = Some(dec_0_40),
nullCount = 0, avgLen = 8, maxLen = 8),
- Some(1L)
- )
+ 1)
}
test("cdecimal < 0.60 ") {
val dec_0_60 = new java.math.BigDecimal("0.600000000000000000")
validateEstimatedStats(
arDecimal,
- Filter(LessThan(arDecimal, Literal(dec_0_60, DecimalType(12, 2))),
+ Filter(LessThan(arDecimal, Literal(dec_0_60)),
childStatsTestPlan(Seq(arDecimal), 4L)),
ColumnStat(distinctCount = 3, min = Some(decMin), max = Some(dec_0_60),
nullCount = 0, avgLen = 8, maxLen = 8),
- Some(3L)
- )
+ 3)
}
test("cdouble < 3.0") {
@@ -309,8 +290,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(LessThan(arDouble, Literal(3.0)), childStatsTestPlan(Seq(arDouble), 10L)),
ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(3.0),
nullCount = 0, avgLen = 8, maxLen = 8),
- Some(3L)
- )
+ 3)
}
test("cstring = 'A2'") {
@@ -319,8 +299,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(EqualTo(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)),
ColumnStat(distinctCount = 1, min = None, max = None,
nullCount = 0, avgLen = 2, maxLen = 2),
- Some(1L)
- )
+ 1)
}
// There are no min/max statistics for String type. We estimate 10 rows returned.
@@ -330,8 +309,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(LessThan(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)),
ColumnStat(distinctCount = 10, min = None, max = None,
nullCount = 0, avgLen = 2, maxLen = 2),
- Some(10L)
- )
+ 10)
}
// This is a corner test case. We want to test if we can handle the case when the number of
@@ -351,8 +329,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Filter(InSet(arInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan),
ColumnStat(distinctCount = 2, min = Some(1), max = Some(5),
nullCount = 0, avgLen = 4, maxLen = 4),
- Some(2L)
- )
+ 2)
}
private def childStatsTestPlan(outList: Seq[Attribute], tableRowCount: BigInt): StatsTestPlan = {
@@ -361,8 +338,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
rowCount = tableRowCount,
attributeStats = AttributeMap(Seq(
arInt -> childColStatInt,
+ arBool -> childColStatBool,
arDate -> childColStatDate,
- arTimestamp -> childColStatTimestamp,
arDecimal -> childColStatDecimal,
arDouble -> childColStatDouble,
arString -> childColStatString
@@ -374,46 +351,31 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
ar: AttributeReference,
filterNode: Filter,
expectedColStats: ColumnStat,
- rowCount: Option[BigInt] = None)
- : Unit = {
+ rowCount: Int): Unit = {
- val expectedRowCount: BigInt = rowCount.getOrElse(0L)
val expectedAttrStats = toAttributeMap(Seq(ar.name -> expectedColStats), filterNode)
- val expectedSizeInBytes = getOutputSize(filterNode.output, expectedRowCount, expectedAttrStats)
+ val expectedSizeInBytes = getOutputSize(filterNode.output, rowCount, expectedAttrStats)
val filteredStats = filterNode.stats(conf)
assert(filteredStats.sizeInBytes == expectedSizeInBytes)
- assert(filteredStats.rowCount == rowCount)
- ar.dataType match {
- case DecimalType() =>
- // Due to the internal transformation for DecimalType within engine, the new min/max
- // in ColumnStat may have a different structure even it contains the right values.
- // We convert them to Java BigDecimal values so that we can compare the entire object.
- val generatedColumnStats = filteredStats.attributeStats(ar)
- val newMax = new java.math.BigDecimal(generatedColumnStats.max.getOrElse(0).toString)
- val newMin = new java.math.BigDecimal(generatedColumnStats.min.getOrElse(0).toString)
- val outputColStats = generatedColumnStats.copy(min = Some(newMin), max = Some(newMax))
- assert(outputColStats == expectedColStats)
- case _ =>
- // For all other SQL types, we compare the entire object directly.
- assert(filteredStats.attributeStats(ar) == expectedColStats)
- }
+ assert(filteredStats.rowCount.get == rowCount)
+ assert(filteredStats.attributeStats(ar) == expectedColStats)
// If the filter has a binary operator (including those nested inside
// AND/OR/NOT), swap the sides of the attribute and the literal, reverse the
// operator, and then check again.
val rewrittenFilter = filterNode transformExpressionsDown {
- case op @ EqualTo(ar: AttributeReference, l: Literal) =>
+ case EqualTo(ar: AttributeReference, l: Literal) =>
EqualTo(l, ar)
- case op @ LessThan(ar: AttributeReference, l: Literal) =>
+ case LessThan(ar: AttributeReference, l: Literal) =>
GreaterThan(l, ar)
- case op @ LessThanOrEqual(ar: AttributeReference, l: Literal) =>
+ case LessThanOrEqual(ar: AttributeReference, l: Literal) =>
GreaterThanOrEqual(l, ar)
- case op @ GreaterThan(ar: AttributeReference, l: Literal) =>
+ case GreaterThan(ar: AttributeReference, l: Literal) =>
LessThan(l, ar)
- case op @ GreaterThanOrEqual(ar: AttributeReference, l: Literal) =>
+ case GreaterThanOrEqual(ar: AttributeReference, l: Literal) =>
LessThanOrEqual(l, ar)
}
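The expected values in these tests follow from the uniform-distribution proration in evaluateBinaryForNumeric. A quick check of the cint < 3 case above (column range [1, 10], 10 rows, 10 distinct values), reproducing the arithmetic with plain doubles; the variable names here are illustrative only:

object SelectivityCheck extends App {
  val (min, max, rowCount, ndv) = (1.0, 10.0, 10L, 10L)
  val literal = 3.0

  val percent = (literal - min) / (max - min)              // (3 - 1) / (10 - 1) = 0.222...
  val filteredRows = math.ceil(rowCount * percent).toLong  // ceil(2.22) = 3, as the test expects
  val newNdv = math.max(math.round(ndv * percent), 1L)     // round(2.22) = 2 distinct values

  println(s"rows = $filteredRows, ndv = $newNdv")          // rows = 3, ndv = 2
}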