aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorwangzhenhua <wangzhenhua@huawei.com>2017-03-06 23:53:53 -0800
committerWenchen Fan <wenchen@databricks.com>2017-03-06 23:53:53 -0800
commit932196d9e30453e0827ee3cd8a81cb306b7a24d9 (patch)
tree9306be5bd1fa2d72aa50a8a3c11518e13a635826 /sql
parente52499ea9c32326b399b50bf0e3f26278da3feb2 (diff)
downloadspark-932196d9e30453e0827ee3cd8a81cb306b7a24d9.tar.gz
spark-932196d9e30453e0827ee3cd8a81cb306b7a24d9.tar.bz2
spark-932196d9e30453e0827ee3cd8a81cb306b7a24d9.zip
[SPARK-17075][SQL][FOLLOWUP] fix filter estimation issues
## What changes were proposed in this pull request? 1. support boolean type in binary expression estimation. 2. deal with compound Not conditions. 3. avoid convert BigInt/BigDecimal directly to double unless it's within range (0, 1). 4. reorganize test code. ## How was this patch tested? modify related test cases. Author: wangzhenhua <wangzhenhua@huawei.com> Author: Zhenhua Wang <wzh_zju@163.com> Closes #17148 from wzhfy/fixFilter.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala174
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala397
2 files changed, 297 insertions, 274 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
index 0c928832d7..b10785b05d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
@@ -19,11 +19,12 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
import scala.collection.immutable.HashSet
import scala.collection.mutable
+import scala.math.BigDecimal.RoundingMode
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.CatalystConf
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Statistics}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
@@ -52,17 +53,19 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
def estimate: Option[Statistics] = {
if (childStats.rowCount.isEmpty) return None
- // save a mutable copy of colStats so that we can later change it recursively
+ // Save a mutable copy of colStats so that we can later change it recursively.
colStatsMap.setInitValues(childStats.attributeStats)
- // estimate selectivity of this filter predicate
- val filterSelectivity: Double = calculateFilterSelectivity(plan.condition) match {
- case Some(percent) => percent
- // for not-supported condition, set filter selectivity to a conservative estimate 100%
- case None => 1.0
- }
+ // Estimate selectivity of this filter predicate, and update column stats if needed.
+ // For not-supported condition, set filter selectivity to a conservative estimate 100%
+ val filterSelectivity: Double = calculateFilterSelectivity(plan.condition).getOrElse(1.0)
- val newColStats = colStatsMap.toColumnStats
+ val newColStats = if (filterSelectivity == 0) {
+ // The output is empty, we don't need to keep column stats.
+ AttributeMap[ColumnStat](Nil)
+ } else {
+ colStatsMap.toColumnStats
+ }
val filteredRowCount: BigInt =
EstimationUtils.ceil(BigDecimal(childStats.rowCount.get) * filterSelectivity)
@@ -74,12 +77,14 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
}
/**
- * Returns a percentage of rows meeting a compound condition in Filter node.
- * A compound condition is decomposed into multiple single conditions linked with AND, OR, NOT.
+ * Returns a percentage of rows meeting a condition in Filter node.
+ * If it's a single condition, we calculate the percentage directly.
+ * If it's a compound condition, it is decomposed into multiple single conditions linked with
+ * AND, OR, NOT.
* For logical AND conditions, we need to update stats after a condition estimation
* so that the stats will be more accurate for subsequent estimation. This is needed for
* range condition such as (c > 40 AND c <= 50)
- * For logical OR conditions, we do not update stats after a condition estimation.
+ * For logical OR and NOT conditions, we do not update stats after a condition estimation.
*
* @param condition the compound logical expression
* @param update a boolean flag to specify if we need to update ColumnStat of a column
@@ -90,34 +95,29 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
def calculateFilterSelectivity(condition: Expression, update: Boolean = true): Option[Double] = {
condition match {
case And(cond1, cond2) =>
- // For ease of debugging, we compute percent1 and percent2 in 2 statements.
- val percent1 = calculateFilterSelectivity(cond1, update)
- val percent2 = calculateFilterSelectivity(cond2, update)
- (percent1, percent2) match {
- case (Some(p1), Some(p2)) => Some(p1 * p2)
- case (Some(p1), None) => Some(p1)
- case (None, Some(p2)) => Some(p2)
- case (None, None) => None
- }
+ val percent1 = calculateFilterSelectivity(cond1, update).getOrElse(1.0)
+ val percent2 = calculateFilterSelectivity(cond2, update).getOrElse(1.0)
+ Some(percent1 * percent2)
case Or(cond1, cond2) =>
- // For ease of debugging, we compute percent1 and percent2 in 2 statements.
- val percent1 = calculateFilterSelectivity(cond1, update = false)
- val percent2 = calculateFilterSelectivity(cond2, update = false)
- (percent1, percent2) match {
- case (Some(p1), Some(p2)) => Some(math.min(1.0, p1 + p2 - (p1 * p2)))
- case (Some(p1), None) => Some(1.0)
- case (None, Some(p2)) => Some(1.0)
- case (None, None) => None
- }
+ val percent1 = calculateFilterSelectivity(cond1, update = false).getOrElse(1.0)
+ val percent2 = calculateFilterSelectivity(cond2, update = false).getOrElse(1.0)
+ Some(percent1 + percent2 - (percent1 * percent2))
- case Not(cond) => calculateFilterSelectivity(cond, update = false) match {
- case Some(percent) => Some(1.0 - percent)
- // for not-supported condition, set filter selectivity to a conservative estimate 100%
- case None => None
- }
+ case Not(And(cond1, cond2)) =>
+ calculateFilterSelectivity(Or(Not(cond1), Not(cond2)), update = false)
+
+ case Not(Or(cond1, cond2)) =>
+ calculateFilterSelectivity(And(Not(cond1), Not(cond2)), update = false)
- case _ => calculateSingleCondition(condition, update)
+ case Not(cond) =>
+ calculateFilterSelectivity(cond, update = false) match {
+ case Some(percent) => Some(1.0 - percent)
+ case None => None
+ }
+
+ case _ =>
+ calculateSingleCondition(condition, update)
}
}
@@ -225,12 +225,12 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
}
val percent = if (isNull) {
- nullPercent.toDouble
+ nullPercent
} else {
- 1.0 - nullPercent.toDouble
+ 1.0 - nullPercent
}
- Some(percent)
+ Some(percent.toDouble)
}
/**
@@ -249,17 +249,19 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
attr: Attribute,
literal: Literal,
update: Boolean): Option[Double] = {
+ if (!colStatsMap.contains(attr)) {
+ logDebug("[CBO] No statistics for " + attr)
+ return None
+ }
+
attr.dataType match {
- case _: NumericType | DateType | TimestampType =>
+ case _: NumericType | DateType | TimestampType | BooleanType =>
evaluateBinaryForNumeric(op, attr, literal, update)
case StringType | BinaryType =>
// TODO: It is difficult to support other binary comparisons for String/Binary
// type without min/max and advanced statistics like histogram.
logDebug("[CBO] No range comparison statistics for String/Binary type " + attr)
None
- case _ =>
- // TODO: support boolean type.
- None
}
}
@@ -291,6 +293,10 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
* Returns a percentage of rows meeting an equality (=) expression.
* This method evaluates the equality predicate for all data types.
*
+ * For EqualNullSafe (<=>), if the literal is not null, result will be the same as EqualTo;
+ * if the literal is null, the condition will be changed to IsNull after optimization.
+ * So we don't need specific logic for EqualNullSafe here.
+ *
* @param attr an Attribute (or a column)
* @param literal a literal value (or constant)
* @param update a boolean flag to specify if we need to update ColumnStat of a given column
@@ -323,7 +329,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
colStatsMap(attr) = newStats
}
- Some(1.0 / ndv.toDouble)
+ Some((1.0 / BigDecimal(ndv)).toDouble)
} else {
Some(0.0)
}
@@ -394,12 +400,12 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
// return the filter selectivity. Without advanced statistics such as histograms,
// we have to assume uniform distribution.
- Some(math.min(1.0, newNdv.toDouble / ndv.toDouble))
+ Some(math.min(1.0, (BigDecimal(newNdv) / BigDecimal(ndv)).toDouble))
}
/**
* Returns a percentage of rows meeting a binary comparison expression.
- * This method evaluate expression for Numeric columns only.
+ * This method evaluates expressions for Numeric/Date/Timestamp/Boolean columns.
*
* @param op a binary comparison operator such as =, <, <=, >, >=
* @param attr an Attribute (or a column)
@@ -414,53 +420,66 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
literal: Literal,
update: Boolean): Option[Double] = {
- var percent = 1.0
val colStat = colStatsMap(attr)
- val statsRange =
- Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange]
+ val statsRange = Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange]
+ val max = BigDecimal(statsRange.max)
+ val min = BigDecimal(statsRange.min)
+ val ndv = BigDecimal(colStat.distinctCount)
// determine the overlapping degree between predicate range and column's range
- val literalValueBD = BigDecimal(literal.value.toString)
+ val numericLiteral = if (literal.dataType == BooleanType) {
+ if (literal.value.asInstanceOf[Boolean]) BigDecimal(1) else BigDecimal(0)
+ } else {
+ BigDecimal(literal.value.toString)
+ }
val (noOverlap: Boolean, completeOverlap: Boolean) = op match {
case _: LessThan =>
- (literalValueBD <= statsRange.min, literalValueBD > statsRange.max)
+ (numericLiteral <= min, numericLiteral > max)
case _: LessThanOrEqual =>
- (literalValueBD < statsRange.min, literalValueBD >= statsRange.max)
+ (numericLiteral < min, numericLiteral >= max)
case _: GreaterThan =>
- (literalValueBD >= statsRange.max, literalValueBD < statsRange.min)
+ (numericLiteral >= max, numericLiteral < min)
case _: GreaterThanOrEqual =>
- (literalValueBD > statsRange.max, literalValueBD <= statsRange.min)
+ (numericLiteral > max, numericLiteral <= min)
}
+ var percent = BigDecimal(1.0)
if (noOverlap) {
percent = 0.0
} else if (completeOverlap) {
percent = 1.0
} else {
- // this is partial overlap case
- val literalDouble = literalValueBD.toDouble
- val maxDouble = BigDecimal(statsRange.max).toDouble
- val minDouble = BigDecimal(statsRange.min).toDouble
-
+ // This is the partial overlap case:
// Without advanced statistics like histogram, we assume uniform data distribution.
// We just prorate the adjusted range over the initial range to compute filter selectivity.
- // For ease of computation, we convert all relevant numeric values to Double.
+ assert(max > min)
percent = op match {
case _: LessThan =>
- (literalDouble - minDouble) / (maxDouble - minDouble)
+ if (numericLiteral == max) {
+ // If the literal value is right on the boundary, we can minus the part of the
+ // boundary value (1/ndv).
+ 1.0 - 1.0 / ndv
+ } else {
+ (numericLiteral - min) / (max - min)
+ }
case _: LessThanOrEqual =>
- if (literalValueBD == BigDecimal(statsRange.min)) {
- 1.0 / colStat.distinctCount.toDouble
+ if (numericLiteral == min) {
+ // The boundary value is the only satisfying value.
+ 1.0 / ndv
} else {
- (literalDouble - minDouble) / (maxDouble - minDouble)
+ (numericLiteral - min) / (max - min)
}
case _: GreaterThan =>
- (maxDouble - literalDouble) / (maxDouble - minDouble)
+ if (numericLiteral == min) {
+ 1.0 - 1.0 / ndv
+ } else {
+ (max - numericLiteral) / (max - min)
+ }
case _: GreaterThanOrEqual =>
- if (literalValueBD == BigDecimal(statsRange.max)) {
- 1.0 / colStat.distinctCount.toDouble
+ if (numericLiteral == max) {
+ 1.0 / ndv
} else {
- (maxDouble - literalDouble) / (maxDouble - minDouble)
+ (max - numericLiteral) / (max - min)
}
}
@@ -469,22 +488,25 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
val newValue = convertBoundValue(attr.dataType, literal.value)
var newMax = colStat.max
var newMin = colStat.min
+ var newNdv = (ndv * percent).setScale(0, RoundingMode.HALF_UP).toBigInt()
+ if (newNdv < 1) newNdv = 1
+
op match {
- case _: GreaterThan => newMin = newValue
- case _: GreaterThanOrEqual => newMin = newValue
- case _: LessThan => newMax = newValue
- case _: LessThanOrEqual => newMax = newValue
+ case _: GreaterThan | _: GreaterThanOrEqual =>
+ // If new ndv is 1, then new max must be equal to new min.
+ newMin = if (newNdv == 1) newMax else newValue
+ case _: LessThan | _: LessThanOrEqual =>
+ newMax = if (newNdv == 1) newMin else newValue
}
- val newNdv = math.max(math.round(colStat.distinctCount.toDouble * percent), 1)
- val newStats = colStat.copy(distinctCount = newNdv, min = newMin,
- max = newMax, nullCount = 0)
+ val newStats =
+ colStat.copy(distinctCount = newNdv, min = newMin, max = newMax, nullCount = 0)
colStatsMap(attr) = newStats
}
}
- Some(percent)
+ Some(percent.toDouble)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
index 8be74ced7b..4691913c8c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.statsEstimation
import java.sql.Date
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Statistics}
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
import org.apache.spark.sql.types._
@@ -33,219 +33,235 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
// Suppose our test table has 10 rows and 6 columns.
// First column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
// Hence, distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4
- val arInt = AttributeReference("cint", IntegerType)()
- val childColStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+ val attrInt = AttributeReference("cint", IntegerType)()
+ val colStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
nullCount = 0, avgLen = 4, maxLen = 4)
// only 2 values
- val arBool = AttributeReference("cbool", BooleanType)()
- val childColStatBool = ColumnStat(distinctCount = 2, min = Some(false), max = Some(true),
+ val attrBool = AttributeReference("cbool", BooleanType)()
+ val colStatBool = ColumnStat(distinctCount = 2, min = Some(false), max = Some(true),
nullCount = 0, avgLen = 1, maxLen = 1)
// Second column cdate has 10 values from 2017-01-01 through 2017-01-10.
val dMin = Date.valueOf("2017-01-01")
val dMax = Date.valueOf("2017-01-10")
- val arDate = AttributeReference("cdate", DateType)()
- val childColStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax),
+ val attrDate = AttributeReference("cdate", DateType)()
+ val colStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax),
nullCount = 0, avgLen = 4, maxLen = 4)
// Fourth column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20.
val decMin = new java.math.BigDecimal("0.200000000000000000")
val decMax = new java.math.BigDecimal("0.800000000000000000")
- val arDecimal = AttributeReference("cdecimal", DecimalType(18, 18))()
- val childColStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax),
+ val attrDecimal = AttributeReference("cdecimal", DecimalType(18, 18))()
+ val colStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax),
nullCount = 0, avgLen = 8, maxLen = 8)
// Fifth column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0
- val arDouble = AttributeReference("cdouble", DoubleType)()
- val childColStatDouble = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0),
+ val attrDouble = AttributeReference("cdouble", DoubleType)()
+ val colStatDouble = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0),
nullCount = 0, avgLen = 8, maxLen = 8)
// Sixth column cstring has 10 String values:
// "A0", "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9"
- val arString = AttributeReference("cstring", StringType)()
- val childColStatString = ColumnStat(distinctCount = 10, min = None, max = None,
+ val attrString = AttributeReference("cstring", StringType)()
+ val colStatString = ColumnStat(distinctCount = 10, min = None, max = None,
nullCount = 0, avgLen = 2, maxLen = 2)
+ val attributeMap = AttributeMap(Seq(
+ attrInt -> colStatInt,
+ attrBool -> colStatBool,
+ attrDate -> colStatDate,
+ attrDecimal -> colStatDecimal,
+ attrDouble -> colStatDouble,
+ attrString -> colStatString))
+
test("cint = 2") {
validateEstimatedStats(
- arInt,
- Filter(EqualTo(arInt, Literal(2)), childStatsTestPlan(Seq(arInt), 10L)),
- ColumnStat(distinctCount = 1, min = Some(2), max = Some(2),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 1)
+ Filter(EqualTo(attrInt, Literal(2)), childStatsTestPlan(Seq(attrInt), 10L)),
+ Seq(attrInt -> ColumnStat(distinctCount = 1, min = Some(2), max = Some(2),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 1)
}
test("cint <=> 2") {
validateEstimatedStats(
- arInt,
- Filter(EqualNullSafe(arInt, Literal(2)), childStatsTestPlan(Seq(arInt), 10L)),
- ColumnStat(distinctCount = 1, min = Some(2), max = Some(2),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 1)
+ Filter(EqualNullSafe(attrInt, Literal(2)), childStatsTestPlan(Seq(attrInt), 10L)),
+ Seq(attrInt -> ColumnStat(distinctCount = 1, min = Some(2), max = Some(2),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 1)
}
test("cint = 0") {
// This is an out-of-range case since 0 is outside the range [min, max]
validateEstimatedStats(
- arInt,
- Filter(EqualTo(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)),
- ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 0)
+ Filter(EqualTo(attrInt, Literal(0)), childStatsTestPlan(Seq(attrInt), 10L)),
+ Nil,
+ expectedRowCount = 0)
}
test("cint < 3") {
validateEstimatedStats(
- arInt,
- Filter(LessThan(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)),
- ColumnStat(distinctCount = 2, min = Some(1), max = Some(3),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 3)
+ Filter(LessThan(attrInt, Literal(3)), childStatsTestPlan(Seq(attrInt), 10L)),
+ Seq(attrInt -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(3),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 3)
}
test("cint < 0") {
// This is a corner case since literal 0 is smaller than min.
validateEstimatedStats(
- arInt,
- Filter(LessThan(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)),
- ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 0)
+ Filter(LessThan(attrInt, Literal(0)), childStatsTestPlan(Seq(attrInt), 10L)),
+ Nil,
+ expectedRowCount = 0)
}
test("cint <= 3") {
validateEstimatedStats(
- arInt,
- Filter(LessThanOrEqual(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)),
- ColumnStat(distinctCount = 2, min = Some(1), max = Some(3),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 3)
+ Filter(LessThanOrEqual(attrInt, Literal(3)), childStatsTestPlan(Seq(attrInt), 10L)),
+ Seq(attrInt -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(3),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 3)
}
test("cint > 6") {
validateEstimatedStats(
- arInt,
- Filter(GreaterThan(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)),
- ColumnStat(distinctCount = 4, min = Some(6), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 5)
+ Filter(GreaterThan(attrInt, Literal(6)), childStatsTestPlan(Seq(attrInt), 10L)),
+ Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(6), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 5)
}
test("cint > 10") {
// This is a corner case since max value is 10.
validateEstimatedStats(
- arInt,
- Filter(GreaterThan(arInt, Literal(10)), childStatsTestPlan(Seq(arInt), 10L)),
- ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 0)
+ Filter(GreaterThan(attrInt, Literal(10)), childStatsTestPlan(Seq(attrInt), 10L)),
+ Nil,
+ expectedRowCount = 0)
}
test("cint >= 6") {
validateEstimatedStats(
- arInt,
- Filter(GreaterThanOrEqual(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)),
- ColumnStat(distinctCount = 4, min = Some(6), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 5)
+ Filter(GreaterThanOrEqual(attrInt, Literal(6)), childStatsTestPlan(Seq(attrInt), 10L)),
+ Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(6), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 5)
}
test("cint IS NULL") {
validateEstimatedStats(
- arInt,
- Filter(IsNull(arInt), childStatsTestPlan(Seq(arInt), 10L)),
- ColumnStat(distinctCount = 0, min = None, max = None,
- nullCount = 0, avgLen = 4, maxLen = 4),
- 0)
+ Filter(IsNull(attrInt), childStatsTestPlan(Seq(attrInt), 10L)),
+ Nil,
+ expectedRowCount = 0)
}
test("cint IS NOT NULL") {
validateEstimatedStats(
- arInt,
- Filter(IsNotNull(arInt), childStatsTestPlan(Seq(arInt), 10L)),
- ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 10)
+ Filter(IsNotNull(attrInt), childStatsTestPlan(Seq(attrInt), 10L)),
+ Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 10)
}
test("cint > 3 AND cint <= 6") {
- val condition = And(GreaterThan(arInt, Literal(3)), LessThanOrEqual(arInt, Literal(6)))
+ val condition = And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6)))
validateEstimatedStats(
- arInt,
- Filter(condition, childStatsTestPlan(Seq(arInt), 10L)),
- ColumnStat(distinctCount = 3, min = Some(3), max = Some(6),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 4)
+ Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
+ Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(3), max = Some(6),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 4)
}
test("cint = 3 OR cint = 6") {
- val condition = Or(EqualTo(arInt, Literal(3)), EqualTo(arInt, Literal(6)))
+ val condition = Or(EqualTo(attrInt, Literal(3)), EqualTo(attrInt, Literal(6)))
+ validateEstimatedStats(
+ Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
+ Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 2)
+ }
+
+ test("Not(cint > 3 AND cint <= 6)") {
+ val condition = Not(And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6))))
+ validateEstimatedStats(
+ Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
+ Seq(attrInt -> colStatInt),
+ expectedRowCount = 6)
+ }
+
+ test("Not(cint <= 3 OR cint > 6)") {
+ val condition = Not(Or(LessThanOrEqual(attrInt, Literal(3)), GreaterThan(attrInt, Literal(6))))
+ validateEstimatedStats(
+ Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
+ Seq(attrInt -> colStatInt),
+ expectedRowCount = 5)
+ }
+
+ test("Not(cint = 3 AND cstring < 'A8')") {
+ val condition = Not(And(EqualTo(attrInt, Literal(3)), LessThan(attrString, Literal("A8"))))
+ validateEstimatedStats(
+ Filter(condition, childStatsTestPlan(Seq(attrInt, attrString), 10L)),
+ Seq(attrInt -> colStatInt, attrString -> colStatString),
+ expectedRowCount = 10)
+ }
+
+ test("Not(cint = 3 OR cstring < 'A8')") {
+ val condition = Not(Or(EqualTo(attrInt, Literal(3)), LessThan(attrString, Literal("A8"))))
validateEstimatedStats(
- arInt,
- Filter(condition, childStatsTestPlan(Seq(arInt), 10L)),
- ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 2)
+ Filter(condition, childStatsTestPlan(Seq(attrInt, attrString), 10L)),
+ Seq(attrInt -> colStatInt, attrString -> colStatString),
+ expectedRowCount = 9)
}
test("cint IN (3, 4, 5)") {
validateEstimatedStats(
- arInt,
- Filter(InSet(arInt, Set(3, 4, 5)), childStatsTestPlan(Seq(arInt), 10L)),
- ColumnStat(distinctCount = 3, min = Some(3), max = Some(5),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 3)
+ Filter(InSet(attrInt, Set(3, 4, 5)), childStatsTestPlan(Seq(attrInt), 10L)),
+ Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(3), max = Some(5),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 3)
}
test("cint NOT IN (3, 4, 5)") {
validateEstimatedStats(
- arInt,
- Filter(Not(InSet(arInt, Set(3, 4, 5))), childStatsTestPlan(Seq(arInt), 10L)),
- ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 7)
+ Filter(Not(InSet(attrInt, Set(3, 4, 5))), childStatsTestPlan(Seq(attrInt), 10L)),
+ Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 7)
}
test("cbool = true") {
validateEstimatedStats(
- arBool,
- Filter(EqualTo(arBool, Literal(true)), childStatsTestPlan(Seq(arBool), 10L)),
- ColumnStat(distinctCount = 1, min = Some(true), max = Some(true),
- nullCount = 0, avgLen = 1, maxLen = 1),
- 5)
+ Filter(EqualTo(attrBool, Literal(true)), childStatsTestPlan(Seq(attrBool), 10L)),
+ Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true),
+ nullCount = 0, avgLen = 1, maxLen = 1)),
+ expectedRowCount = 5)
}
test("cbool > false") {
- // bool comparison is not supported yet, so stats remain same.
validateEstimatedStats(
- arBool,
- Filter(GreaterThan(arBool, Literal(false)), childStatsTestPlan(Seq(arBool), 10L)),
- ColumnStat(distinctCount = 2, min = Some(false), max = Some(true),
- nullCount = 0, avgLen = 1, maxLen = 1),
- 10)
+ Filter(GreaterThan(attrBool, Literal(false)), childStatsTestPlan(Seq(attrBool), 10L)),
+ Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true),
+ nullCount = 0, avgLen = 1, maxLen = 1)),
+ expectedRowCount = 5)
}
test("cdate = cast('2017-01-02' AS DATE)") {
val d20170102 = Date.valueOf("2017-01-02")
validateEstimatedStats(
- arDate,
- Filter(EqualTo(arDate, Literal(d20170102)),
- childStatsTestPlan(Seq(arDate), 10L)),
- ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 1)
+ Filter(EqualTo(attrDate, Literal(d20170102)),
+ childStatsTestPlan(Seq(attrDate), 10L)),
+ Seq(attrDate -> ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 1)
}
test("cdate < cast('2017-01-03' AS DATE)") {
val d20170103 = Date.valueOf("2017-01-03")
validateEstimatedStats(
- arDate,
- Filter(LessThan(arDate, Literal(d20170103)),
- childStatsTestPlan(Seq(arDate), 10L)),
- ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 3)
+ Filter(LessThan(attrDate, Literal(d20170103)),
+ childStatsTestPlan(Seq(attrDate), 10L)),
+ Seq(attrDate -> ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 3)
}
test("""cdate IN ( cast('2017-01-03' AS DATE),
@@ -254,133 +270,118 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
val d20170104 = Date.valueOf("2017-01-04")
val d20170105 = Date.valueOf("2017-01-05")
validateEstimatedStats(
- arDate,
- Filter(In(arDate, Seq(Literal(d20170103), Literal(d20170104), Literal(d20170105))),
- childStatsTestPlan(Seq(arDate), 10L)),
- ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 3)
+ Filter(In(attrDate, Seq(Literal(d20170103), Literal(d20170104), Literal(d20170105))),
+ childStatsTestPlan(Seq(attrDate), 10L)),
+ Seq(attrDate -> ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 3)
}
test("cdecimal = 0.400000000000000000") {
val dec_0_40 = new java.math.BigDecimal("0.400000000000000000")
validateEstimatedStats(
- arDecimal,
- Filter(EqualTo(arDecimal, Literal(dec_0_40)),
- childStatsTestPlan(Seq(arDecimal), 4L)),
- ColumnStat(distinctCount = 1, min = Some(dec_0_40), max = Some(dec_0_40),
- nullCount = 0, avgLen = 8, maxLen = 8),
- 1)
+ Filter(EqualTo(attrDecimal, Literal(dec_0_40)),
+ childStatsTestPlan(Seq(attrDecimal), 4L)),
+ Seq(attrDecimal -> ColumnStat(distinctCount = 1, min = Some(dec_0_40), max = Some(dec_0_40),
+ nullCount = 0, avgLen = 8, maxLen = 8)),
+ expectedRowCount = 1)
}
test("cdecimal < 0.60 ") {
val dec_0_60 = new java.math.BigDecimal("0.600000000000000000")
validateEstimatedStats(
- arDecimal,
- Filter(LessThan(arDecimal, Literal(dec_0_60)),
- childStatsTestPlan(Seq(arDecimal), 4L)),
- ColumnStat(distinctCount = 3, min = Some(decMin), max = Some(dec_0_60),
- nullCount = 0, avgLen = 8, maxLen = 8),
- 3)
+ Filter(LessThan(attrDecimal, Literal(dec_0_60)),
+ childStatsTestPlan(Seq(attrDecimal), 4L)),
+ Seq(attrDecimal -> ColumnStat(distinctCount = 3, min = Some(decMin), max = Some(dec_0_60),
+ nullCount = 0, avgLen = 8, maxLen = 8)),
+ expectedRowCount = 3)
}
test("cdouble < 3.0") {
validateEstimatedStats(
- arDouble,
- Filter(LessThan(arDouble, Literal(3.0)), childStatsTestPlan(Seq(arDouble), 10L)),
- ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(3.0),
- nullCount = 0, avgLen = 8, maxLen = 8),
- 3)
+ Filter(LessThan(attrDouble, Literal(3.0)), childStatsTestPlan(Seq(attrDouble), 10L)),
+ Seq(attrDouble -> ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(3.0),
+ nullCount = 0, avgLen = 8, maxLen = 8)),
+ expectedRowCount = 3)
}
test("cstring = 'A2'") {
validateEstimatedStats(
- arString,
- Filter(EqualTo(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)),
- ColumnStat(distinctCount = 1, min = None, max = None,
- nullCount = 0, avgLen = 2, maxLen = 2),
- 1)
+ Filter(EqualTo(attrString, Literal("A2")), childStatsTestPlan(Seq(attrString), 10L)),
+ Seq(attrString -> ColumnStat(distinctCount = 1, min = None, max = None,
+ nullCount = 0, avgLen = 2, maxLen = 2)),
+ expectedRowCount = 1)
}
- // There is no min/max statistics for String type. We estimate 10 rows returned.
- test("cstring < 'A2'") {
+ test("cstring < 'A2' - unsupported condition") {
validateEstimatedStats(
- arString,
- Filter(LessThan(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)),
- ColumnStat(distinctCount = 10, min = None, max = None,
- nullCount = 0, avgLen = 2, maxLen = 2),
- 10)
+ Filter(LessThan(attrString, Literal("A2")), childStatsTestPlan(Seq(attrString), 10L)),
+ Seq(attrString -> ColumnStat(distinctCount = 10, min = None, max = None,
+ nullCount = 0, avgLen = 2, maxLen = 2)),
+ expectedRowCount = 10)
}
- // This is a corner test case. We want to test if we can handle the case when the number of
- // valid values in IN clause is greater than the number of distinct values for a given column.
- // For example, column has only 2 distinct values 1 and 6.
- // The predicate is: column IN (1, 2, 3, 4, 5).
test("cint IN (1, 2, 3, 4, 5)") {
+ // This is a corner test case. We want to test if we can handle the case when the number of
+ // valid values in IN clause is greater than the number of distinct values for a given column.
+ // For example, column has only 2 distinct values 1 and 6.
+ // The predicate is: column IN (1, 2, 3, 4, 5).
val cornerChildColStatInt = ColumnStat(distinctCount = 2, min = Some(1), max = Some(6),
nullCount = 0, avgLen = 4, maxLen = 4)
val cornerChildStatsTestplan = StatsTestPlan(
- outputList = Seq(arInt),
+ outputList = Seq(attrInt),
rowCount = 2L,
- attributeStats = AttributeMap(Seq(arInt -> cornerChildColStatInt))
+ attributeStats = AttributeMap(Seq(attrInt -> cornerChildColStatInt))
)
validateEstimatedStats(
- arInt,
- Filter(InSet(arInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan),
- ColumnStat(distinctCount = 2, min = Some(1), max = Some(5),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 2)
+ Filter(InSet(attrInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan),
+ Seq(attrInt -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(5),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 2)
}
private def childStatsTestPlan(outList: Seq[Attribute], tableRowCount: BigInt): StatsTestPlan = {
StatsTestPlan(
outputList = outList,
rowCount = tableRowCount,
- attributeStats = AttributeMap(Seq(
- arInt -> childColStatInt,
- arBool -> childColStatBool,
- arDate -> childColStatDate,
- arDecimal -> childColStatDecimal,
- arDouble -> childColStatDouble,
- arString -> childColStatString
- ))
- )
+ attributeStats = AttributeMap(outList.map(a => a -> attributeMap(a))))
}
private def validateEstimatedStats(
- ar: AttributeReference,
filterNode: Filter,
- expectedColStats: ColumnStat,
- rowCount: Int): Unit = {
-
- val expectedAttrStats = toAttributeMap(Seq(ar.name -> expectedColStats), filterNode)
- val expectedSizeInBytes = getOutputSize(filterNode.output, rowCount, expectedAttrStats)
-
- val filteredStats = filterNode.stats(conf)
- assert(filteredStats.sizeInBytes == expectedSizeInBytes)
- assert(filteredStats.rowCount.get == rowCount)
- assert(filteredStats.attributeStats(ar) == expectedColStats)
-
- // If the filter has a binary operator (including those nested inside
- // AND/OR/NOT), swap the sides of the attribte and the literal, reverse the
- // operator, and then check again.
- val rewrittenFilter = filterNode transformExpressionsDown {
- case EqualTo(ar: AttributeReference, l: Literal) =>
- EqualTo(l, ar)
-
- case LessThan(ar: AttributeReference, l: Literal) =>
- GreaterThan(l, ar)
- case LessThanOrEqual(ar: AttributeReference, l: Literal) =>
- GreaterThanOrEqual(l, ar)
-
- case GreaterThan(ar: AttributeReference, l: Literal) =>
- LessThan(l, ar)
- case GreaterThanOrEqual(ar: AttributeReference, l: Literal) =>
- LessThanOrEqual(l, ar)
+ expectedColStats: Seq[(Attribute, ColumnStat)],
+ expectedRowCount: Int): Unit = {
+
+ // If the filter has a binary operator (including those nested inside AND/OR/NOT), swap the
+ // sides of the attribute and the literal, reverse the operator, and then check again.
+ val swappedFilter = filterNode transformExpressionsDown {
+ case EqualTo(attr: Attribute, l: Literal) =>
+ EqualTo(l, attr)
+
+ case LessThan(attr: Attribute, l: Literal) =>
+ GreaterThan(l, attr)
+ case LessThanOrEqual(attr: Attribute, l: Literal) =>
+ GreaterThanOrEqual(l, attr)
+
+ case GreaterThan(attr: Attribute, l: Literal) =>
+ LessThan(l, attr)
+ case GreaterThanOrEqual(attr: Attribute, l: Literal) =>
+ LessThanOrEqual(l, attr)
+ }
+
+ val testFilters = if (swappedFilter != filterNode) {
+ Seq(swappedFilter, filterNode)
+ } else {
+ Seq(filterNode)
}
- if (rewrittenFilter != filterNode) {
- validateEstimatedStats(ar, rewrittenFilter, expectedColStats, rowCount)
+ testFilters.foreach { filter =>
+ val expectedAttributeMap = AttributeMap(expectedColStats)
+ val expectedStats = Statistics(
+ sizeInBytes = getOutputSize(filter.output, expectedRowCount, expectedAttributeMap),
+ rowCount = Some(expectedRowCount),
+ attributeStats = expectedAttributeMap)
+ assert(filter.stats(conf) == expectedStats)
}
}
}