aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorRon Hu <ron.hu@huawei.com>2017-04-03 17:27:12 -0700
committerXiao Li <gatorsmile@gmail.com>2017-04-03 17:27:12 -0700
commite7877fd4728ed41e440d7c4d8b6b02bd0d9e873e (patch)
tree9d9619970eb8f9392edfe2c4d94f1d5234e8093d /sql/catalyst
parent58c9e6e77ae26345291dd9fce2c57aadcc36f66c (diff)
downloadspark-e7877fd4728ed41e440d7c4d8b6b02bd0d9e873e.tar.gz
spark-e7877fd4728ed41e440d7c4d8b6b02bd0d9e873e.tar.bz2
spark-e7877fd4728ed41e440d7c4d8b6b02bd0d9e873e.zip
[SPARK-19408][SQL] filter estimation on two columns of same table
## What changes were proposed in this pull request? In SQL queries, we also see predicate expressions involving two columns such as "column-1 (op) column-2" where column-1 and column-2 belong to same table. Note that, if column-1 and column-2 belong to different tables, then it is a join operator's work, NOT a filter operator's work. This PR estimates filter selectivity on two columns of same table. For example, multiple tpc-h queries have this predicate "WHERE l_commitdate < l_receiptdate" ## How was this patch tested? We added 6 new test cases to test various logical predicates involving two columns of same table. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Ron Hu <ron.hu@huawei.com> Author: U-CHINA\r00754707 <r00754707@R00754707-SC04.china.huawei.com> Closes #17415 from ron8hu/filterTwoColumns.
Diffstat (limited to 'sql/catalyst')
-rwxr-xr-x[-rw-r--r--]sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala233
-rwxr-xr-x[-rw-r--r--]sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala140
2 files changed, 363 insertions, 10 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
index b32374c574..03c76cd41d 100644..100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
@@ -201,6 +201,21 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
case IsNotNull(ar: Attribute) if plan.child.isInstanceOf[LeafNode] =>
evaluateNullCheck(ar, isNull = false, update)
+ case op @ Equality(attrLeft: Attribute, attrRight: Attribute) =>
+ evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update)
+
+ case op @ LessThan(attrLeft: Attribute, attrRight: Attribute) =>
+ evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update)
+
+ case op @ LessThanOrEqual(attrLeft: Attribute, attrRight: Attribute) =>
+ evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update)
+
+ case op @ GreaterThan(attrLeft: Attribute, attrRight: Attribute) =>
+ evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update)
+
+ case op @ GreaterThanOrEqual(attrLeft: Attribute, attrRight: Attribute) =>
+ evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update)
+
case _ =>
// TODO: it's difficult to support string operators without advanced statistics.
// Hence, these string operators Like(_, _) | Contains(_, _) | StartsWith(_, _)
@@ -257,7 +272,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
/**
* Returns a percentage of rows meeting a binary comparison expression.
*
- * @param op a binary comparison operator uch as =, <, <=, >, >=
+ * @param op a binary comparison operator such as =, <, <=, >, >=
* @param attr an Attribute (or a column)
* @param literal a literal value (or constant)
* @param update a boolean flag to specify if we need to update ColumnStat of a given column
@@ -448,7 +463,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
* Returns a percentage of rows meeting a binary comparison expression.
* This method evaluate expression for Numeric/Date/Timestamp/Boolean columns.
*
- * @param op a binary comparison operator uch as =, <, <=, >, >=
+ * @param op a binary comparison operator such as =, <, <=, >, >=
* @param attr an Attribute (or a column)
* @param literal a literal value (or constant)
* @param update a boolean flag to specify if we need to update ColumnStat of a given column
@@ -550,6 +565,220 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
Some(percent.toDouble)
}
+ /**
+ * Returns a percentage of rows meeting a binary comparison expression containing two columns.
+ * In SQL queries, we also see predicate expressions involving two columns
+ * such as "column-1 (op) column-2" where column-1 and column-2 belong to same table.
+ * Note that, if column-1 and column-2 belong to different tables, then it is a join
+ * operator's work, NOT a filter operator's work.
+ *
+ * @param op a binary comparison operator, including =, <=>, <, <=, >, >=
+ * @param attrLeft the left Attribute (or a column)
+ * @param attrRight the right Attribute (or a column)
+ * @param update a boolean flag to specify if we need to update ColumnStat of the given columns
+ * for subsequent conditions
+ * @return an optional double value to show the percentage of rows meeting a given condition
+ */
+ def evaluateBinaryForTwoColumns(
+ op: BinaryComparison,
+ attrLeft: Attribute,
+ attrRight: Attribute,
+ update: Boolean): Option[Double] = {
+
+ if (!colStatsMap.contains(attrLeft)) {
+ logDebug("[CBO] No statistics for " + attrLeft)
+ return None
+ }
+ if (!colStatsMap.contains(attrRight)) {
+ logDebug("[CBO] No statistics for " + attrRight)
+ return None
+ }
+
+ attrLeft.dataType match {
+ case StringType | BinaryType =>
+ // TODO: It is difficult to support other binary comparisons for String/Binary
+ // type without min/max and advanced statistics like histogram.
+ logDebug("[CBO] No range comparison statistics for String/Binary type " + attrLeft)
+ return None
+ case _ =>
+ }
+
+ val colStatLeft = colStatsMap(attrLeft)
+ val statsRangeLeft = Range(colStatLeft.min, colStatLeft.max, attrLeft.dataType)
+ .asInstanceOf[NumericRange]
+ val maxLeft = BigDecimal(statsRangeLeft.max)
+ val minLeft = BigDecimal(statsRangeLeft.min)
+
+ val colStatRight = colStatsMap(attrRight)
+ val statsRangeRight = Range(colStatRight.min, colStatRight.max, attrRight.dataType)
+ .asInstanceOf[NumericRange]
+ val maxRight = BigDecimal(statsRangeRight.max)
+ val minRight = BigDecimal(statsRangeRight.min)
+
+ // determine the overlapping degree between predicate range and column's range
+ val allNotNull = (colStatLeft.nullCount == 0) && (colStatRight.nullCount == 0)
+ val (noOverlap: Boolean, completeOverlap: Boolean) = op match {
+ // Left < Right or Left <= Right
+ // - no overlap:
+ // minRight maxRight minLeft maxLeft
+ // --------+------------------+------------+-------------+------->
+ // - complete overlap: (If null values exists, we set it to partial overlap.)
+ // minLeft maxLeft minRight maxRight
+ // --------+------------------+------------+-------------+------->
+ case _: LessThan =>
+ (minLeft >= maxRight, (maxLeft < minRight) && allNotNull)
+ case _: LessThanOrEqual =>
+ (minLeft > maxRight, (maxLeft <= minRight) && allNotNull)
+
+ // Left > Right or Left >= Right
+ // - no overlap:
+ // minLeft maxLeft minRight maxRight
+ // --------+------------------+------------+-------------+------->
+ // - complete overlap: (If null values exists, we set it to partial overlap.)
+ // minRight maxRight minLeft maxLeft
+ // --------+------------------+------------+-------------+------->
+ case _: GreaterThan =>
+ (maxLeft <= minRight, (minLeft > maxRight) && allNotNull)
+ case _: GreaterThanOrEqual =>
+ (maxLeft < minRight, (minLeft >= maxRight) && allNotNull)
+
+ // Left = Right or Left <=> Right
+ // - no overlap:
+ // minLeft maxLeft minRight maxRight
+ // --------+------------------+------------+-------------+------->
+ // minRight maxRight minLeft maxLeft
+ // --------+------------------+------------+-------------+------->
+ // - complete overlap:
+ // minLeft maxLeft
+ // minRight maxRight
+ // --------+------------------+------->
+ case _: EqualTo =>
+ ((maxLeft < minRight) || (maxRight < minLeft),
+ (minLeft == minRight) && (maxLeft == maxRight) && allNotNull
+ && (colStatLeft.distinctCount == colStatRight.distinctCount)
+ )
+ case _: EqualNullSafe =>
+ // For null-safe equality, we use a very restrictive condition to evaluate its overlap.
+ // If null values exists, we set it to partial overlap.
+ (((maxLeft < minRight) || (maxRight < minLeft)) && allNotNull,
+ (minLeft == minRight) && (maxLeft == maxRight) && allNotNull
+ && (colStatLeft.distinctCount == colStatRight.distinctCount)
+ )
+ }
+
+ var percent = BigDecimal(1.0)
+ if (noOverlap) {
+ percent = 0.0
+ } else if (completeOverlap) {
+ percent = 1.0
+ } else {
+ // For partial overlap, we use an empirical value 1/3 as suggested by the book
+ // "Database Systems, the complete book".
+ percent = 1.0 / 3.0
+
+ if (update) {
+ // Need to adjust new min/max after the filter condition is applied
+
+ val ndvLeft = BigDecimal(colStatLeft.distinctCount)
+ var newNdvLeft = (ndvLeft * percent).setScale(0, RoundingMode.HALF_UP).toBigInt()
+ if (newNdvLeft < 1) newNdvLeft = 1
+ val ndvRight = BigDecimal(colStatRight.distinctCount)
+ var newNdvRight = (ndvRight * percent).setScale(0, RoundingMode.HALF_UP).toBigInt()
+ if (newNdvRight < 1) newNdvRight = 1
+
+ var newMaxLeft = colStatLeft.max
+ var newMinLeft = colStatLeft.min
+ var newMaxRight = colStatRight.max
+ var newMinRight = colStatRight.min
+
+ op match {
+ case _: LessThan | _: LessThanOrEqual =>
+ // the left side should be less than the right side.
+ // If not, we need to adjust it to narrow the range.
+ // Left < Right or Left <= Right
+ // minRight < minLeft
+ // --------+******************+------->
+ // filtered ^
+ // |
+ // newMinRight
+ //
+ // maxRight < maxLeft
+ // --------+******************+------->
+ // ^ filtered
+ // |
+ // newMaxLeft
+ if (minLeft > minRight) newMinRight = colStatLeft.min
+ if (maxLeft > maxRight) newMaxLeft = colStatRight.max
+
+ case _: GreaterThan | _: GreaterThanOrEqual =>
+ // the left side should be greater than the right side.
+ // If not, we need to adjust it to narrow the range.
+ // Left > Right or Left >= Right
+ // minLeft < minRight
+ // --------+******************+------->
+ // filtered ^
+ // |
+ // newMinLeft
+ //
+ // maxLeft < maxRight
+ // --------+******************+------->
+ // ^ filtered
+ // |
+ // newMaxRight
+ if (minLeft < minRight) newMinLeft = colStatRight.min
+ if (maxLeft < maxRight) newMaxRight = colStatLeft.max
+
+ case _: EqualTo | _: EqualNullSafe =>
+ // need to set new min to the larger min value, and
+ // set the new max to the smaller max value.
+ // Left = Right or Left <=> Right
+ // minLeft < minRight
+ // --------+******************+------->
+ // filtered ^
+ // |
+ // newMinLeft
+ //
+ // minRight <= minLeft
+ // --------+******************+------->
+ // filtered ^
+ // |
+ // newMinRight
+ //
+ // maxLeft < maxRight
+ // --------+******************+------->
+ // ^ filtered
+ // |
+ // newMaxRight
+ //
+ // maxRight <= maxLeft
+ // --------+******************+------->
+ // ^ filtered
+ // |
+ // newMaxLeft
+ if (minLeft < minRight) {
+ newMinLeft = colStatRight.min
+ } else {
+ newMinRight = colStatLeft.min
+ }
+ if (maxLeft < maxRight) {
+ newMaxRight = colStatLeft.max
+ } else {
+ newMaxLeft = colStatRight.max
+ }
+ }
+
+ val newStatsLeft = colStatLeft.copy(distinctCount = newNdvLeft, min = newMinLeft,
+ max = newMaxLeft)
+ colStatsMap(attrLeft) = newStatsLeft
+ val newStatsRight = colStatRight.copy(distinctCount = newNdvRight, min = newMinRight,
+ max = newMaxRight)
+ colStatsMap(attrRight) = newStatsRight
+ }
+ }
+
+ Some(percent.toDouble)
+ }
+
}
class ColumnStatsMap {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
index 1966c96c05..cffb0d8739 100644..100755
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
@@ -33,49 +33,74 @@ import org.apache.spark.sql.types._
class FilterEstimationSuite extends StatsEstimationTestBase {
// Suppose our test table has 10 rows and 6 columns.
- // First column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
+ // column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
// Hence, distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4
val attrInt = AttributeReference("cint", IntegerType)()
val colStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
nullCount = 0, avgLen = 4, maxLen = 4)
- // only 2 values
+ // column cbool has only 2 distinct values
val attrBool = AttributeReference("cbool", BooleanType)()
val colStatBool = ColumnStat(distinctCount = 2, min = Some(false), max = Some(true),
nullCount = 0, avgLen = 1, maxLen = 1)
- // Second column cdate has 10 values from 2017-01-01 through 2017-01-10.
+ // column cdate has 10 values from 2017-01-01 through 2017-01-10.
val dMin = Date.valueOf("2017-01-01")
val dMax = Date.valueOf("2017-01-10")
val attrDate = AttributeReference("cdate", DateType)()
val colStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax),
nullCount = 0, avgLen = 4, maxLen = 4)
- // Fourth column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20.
+ // column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20.
val decMin = new java.math.BigDecimal("0.200000000000000000")
val decMax = new java.math.BigDecimal("0.800000000000000000")
val attrDecimal = AttributeReference("cdecimal", DecimalType(18, 18))()
val colStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax),
nullCount = 0, avgLen = 8, maxLen = 8)
- // Fifth column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0
+ // column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0
val attrDouble = AttributeReference("cdouble", DoubleType)()
val colStatDouble = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0),
nullCount = 0, avgLen = 8, maxLen = 8)
- // Sixth column cstring has 10 String values:
+ // column cstring has 10 String values:
// "A0", "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9"
val attrString = AttributeReference("cstring", StringType)()
val colStatString = ColumnStat(distinctCount = 10, min = None, max = None,
nullCount = 0, avgLen = 2, maxLen = 2)
+ // column cint2 has values: 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
+ // Hence, distinctCount:10, min:7, max:16, nullCount:0, avgLen:4, maxLen:4
+ // This column is created to test "cint < cint2
+ val attrInt2 = AttributeReference("cint2", IntegerType)()
+ val colStatInt2 = ColumnStat(distinctCount = 10, min = Some(7), max = Some(16),
+ nullCount = 0, avgLen = 4, maxLen = 4)
+
+ // column cint3 has values: 30, 31, 32, 33, 34, 35, 36, 37, 38, 39
+ // Hence, distinctCount:10, min:30, max:39, nullCount:0, avgLen:4, maxLen:4
+ // This column is created to test "cint = cint3 without overlap at all.
+ val attrInt3 = AttributeReference("cint3", IntegerType)()
+ val colStatInt3 = ColumnStat(distinctCount = 10, min = Some(30), max = Some(39),
+ nullCount = 0, avgLen = 4, maxLen = 4)
+
+ // column cint4 has values in the range from 1 to 10
+ // distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4
+ // This column is created to test complete overlap
+ val attrInt4 = AttributeReference("cint4", IntegerType)()
+ val colStatInt4 = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4)
+
val attributeMap = AttributeMap(Seq(
attrInt -> colStatInt,
attrBool -> colStatBool,
attrDate -> colStatDate,
attrDecimal -> colStatDecimal,
attrDouble -> colStatDouble,
- attrString -> colStatString))
+ attrString -> colStatString,
+ attrInt2 -> colStatInt2,
+ attrInt3 -> colStatInt3,
+ attrInt4 -> colStatInt4
+ ))
test("true") {
validateEstimatedStats(
@@ -450,6 +475,89 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
}
}
+ test("cint = cint2") {
+ // partial overlap case
+ validateEstimatedStats(
+ Filter(EqualTo(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)),
+ Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attrInt2 -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 4)
+ }
+
+ test("cint > cint2") {
+ // partial overlap case
+ validateEstimatedStats(
+ Filter(GreaterThan(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)),
+ Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attrInt2 -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 4)
+ }
+
+ test("cint < cint2") {
+ // partial overlap case
+ validateEstimatedStats(
+ Filter(LessThan(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)),
+ Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attrInt2 -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(16),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 4)
+ }
+
+ test("cint = cint4") {
+ // complete overlap case
+ validateEstimatedStats(
+ Filter(EqualTo(attrInt, attrInt4), childStatsTestPlan(Seq(attrInt, attrInt4), 10L)),
+ Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attrInt4 -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 10)
+ }
+
+ test("cint < cint4") {
+ // partial overlap case
+ validateEstimatedStats(
+ Filter(LessThan(attrInt, attrInt4), childStatsTestPlan(Seq(attrInt, attrInt4), 10L)),
+ Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attrInt4 -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 4)
+ }
+
+ test("cint = cint3") {
+ // no records qualify due to no overlap
+ val emptyColStats = Seq[(Attribute, ColumnStat)]()
+ validateEstimatedStats(
+ Filter(EqualTo(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)),
+ Nil, // set to empty
+ expectedRowCount = 0)
+ }
+
+ test("cint < cint3") {
+ // all table records qualify.
+ validateEstimatedStats(
+ Filter(LessThan(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)),
+ Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attrInt3 -> ColumnStat(distinctCount = 10, min = Some(30), max = Some(39),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 10)
+ }
+
+ test("cint > cint3") {
+ // no records qualify due to no overlap
+ validateEstimatedStats(
+ Filter(GreaterThan(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)),
+ Nil, // set to empty
+ expectedRowCount = 0)
+ }
+
private def childStatsTestPlan(outList: Seq[Attribute], tableRowCount: BigInt): StatsTestPlan = {
StatsTestPlan(
outputList = outList,
@@ -491,7 +599,23 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
sizeInBytes = getOutputSize(filter.output, expectedRowCount, expectedAttributeMap),
rowCount = Some(expectedRowCount),
attributeStats = expectedAttributeMap)
- assert(filter.stats(conf) == expectedStats)
+
+ val filterStats = filter.stats(conf)
+ assert(filterStats.sizeInBytes == expectedStats.sizeInBytes)
+ assert(filterStats.rowCount == expectedStats.rowCount)
+ val rowCountValue = filterStats.rowCount.getOrElse(0)
+ // check the output column stats if the row count is > 0.
+ // When row count is 0, the output is set to empty.
+ if (rowCountValue != 0) {
+ // Need to check attributeStats one by one because we may have multiple output columns.
+ // Due to update operation, the output columns may be in different order.
+ assert(expectedColStats.size == filterStats.attributeStats.size)
+ expectedColStats.foreach { kv =>
+ val filterColumnStat = filterStats.attributeStats.get(kv._1).get
+ assert(filterColumnStat == kv._2)
+ }
+ }
}
}
+
}