From e7877fd4728ed41e440d7c4d8b6b02bd0d9e873e Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Mon, 3 Apr 2017 17:27:12 -0700 Subject: [SPARK-19408][SQL] filter estimation on two columns of same table ## What changes were proposed in this pull request? In SQL queries, we also see predicate expressions involving two columns such as "column-1 (op) column-2" where column-1 and column-2 belong to same table. Note that, if column-1 and column-2 belong to different tables, then it is a join operator's work, NOT a filter operator's work. This PR estimates filter selectivity on two columns of same table. For example, multiple tpc-h queries have this predicate "WHERE l_commitdate < l_receiptdate" ## How was this patch tested? We added 6 new test cases to test various logical predicates involving two columns of same table. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Ron Hu Author: U-CHINA\r00754707 Closes #17415 from ron8hu/filterTwoColumns. --- .../logical/statsEstimation/FilterEstimation.scala | 233 ++++++++++++++++++++- .../statsEstimation/FilterEstimationSuite.scala | 140 ++++++++++++- 2 files changed, 363 insertions(+), 10 deletions(-) mode change 100644 => 100755 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala mode change 100644 => 100755 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala old mode 100644 new mode 100755 index b32374c574..03c76cd41d --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -201,6 +201,21 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo case IsNotNull(ar: Attribute) if plan.child.isInstanceOf[LeafNode] => evaluateNullCheck(ar, isNull = false, update) + case op @ Equality(attrLeft: Attribute, attrRight: Attribute) => + evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update) + + case op @ LessThan(attrLeft: Attribute, attrRight: Attribute) => + evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update) + + case op @ LessThanOrEqual(attrLeft: Attribute, attrRight: Attribute) => + evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update) + + case op @ GreaterThan(attrLeft: Attribute, attrRight: Attribute) => + evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update) + + case op @ GreaterThanOrEqual(attrLeft: Attribute, attrRight: Attribute) => + evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update) + case _ => // TODO: it's difficult to support string operators without advanced statistics. // Hence, these string operators Like(_, _) | Contains(_, _) | StartsWith(_, _) @@ -257,7 +272,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo /** * Returns a percentage of rows meeting a binary comparison expression. * - * @param op a binary comparison operator uch as =, <, <=, >, >= + * @param op a binary comparison operator such as =, <, <=, >, >= * @param attr an Attribute (or a column) * @param literal a literal value (or constant) * @param update a boolean flag to specify if we need to update ColumnStat of a given column @@ -448,7 +463,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * Returns a percentage of rows meeting a binary comparison expression. * This method evaluate expression for Numeric/Date/Timestamp/Boolean columns. 
* - * @param op a binary comparison operator uch as =, <, <=, >, >= + * @param op a binary comparison operator such as =, <, <=, >, >= * @param attr an Attribute (or a column) * @param literal a literal value (or constant) * @param update a boolean flag to specify if we need to update ColumnStat of a given column @@ -550,6 +565,220 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo Some(percent.toDouble) } + /** + * Returns a percentage of rows meeting a binary comparison expression containing two columns. + * In SQL queries, we also see predicate expressions involving two columns + * such as "column-1 (op) column-2" where column-1 and column-2 belong to same table. + * Note that, if column-1 and column-2 belong to different tables, then it is a join + * operator's work, NOT a filter operator's work. + * + * @param op a binary comparison operator, including =, <=>, <, <=, >, >= + * @param attrLeft the left Attribute (or a column) + * @param attrRight the right Attribute (or a column) + * @param update a boolean flag to specify if we need to update ColumnStat of the given columns + * for subsequent conditions + * @return an optional double value to show the percentage of rows meeting a given condition + */ + def evaluateBinaryForTwoColumns( + op: BinaryComparison, + attrLeft: Attribute, + attrRight: Attribute, + update: Boolean): Option[Double] = { + + if (!colStatsMap.contains(attrLeft)) { + logDebug("[CBO] No statistics for " + attrLeft) + return None + } + if (!colStatsMap.contains(attrRight)) { + logDebug("[CBO] No statistics for " + attrRight) + return None + } + + attrLeft.dataType match { + case StringType | BinaryType => + // TODO: It is difficult to support other binary comparisons for String/Binary + // type without min/max and advanced statistics like histogram. 
+ logDebug("[CBO] No range comparison statistics for String/Binary type " + attrLeft) + return None + case _ => + } + + val colStatLeft = colStatsMap(attrLeft) + val statsRangeLeft = Range(colStatLeft.min, colStatLeft.max, attrLeft.dataType) + .asInstanceOf[NumericRange] + val maxLeft = BigDecimal(statsRangeLeft.max) + val minLeft = BigDecimal(statsRangeLeft.min) + + val colStatRight = colStatsMap(attrRight) + val statsRangeRight = Range(colStatRight.min, colStatRight.max, attrRight.dataType) + .asInstanceOf[NumericRange] + val maxRight = BigDecimal(statsRangeRight.max) + val minRight = BigDecimal(statsRangeRight.min) + + // determine the overlapping degree between the left column's range and the right column's range + val allNotNull = (colStatLeft.nullCount == 0) && (colStatRight.nullCount == 0) + val (noOverlap: Boolean, completeOverlap: Boolean) = op match { + // Left < Right or Left <= Right + // - no overlap: + // minRight maxRight minLeft maxLeft + // --------+------------------+------------+-------------+-------> + // - complete overlap: (If null values exist, we set it to partial overlap.) + // minLeft maxLeft minRight maxRight + // --------+------------------+------------+-------------+-------> + case _: LessThan => + (minLeft >= maxRight, (maxLeft < minRight) && allNotNull) + case _: LessThanOrEqual => + (minLeft > maxRight, (maxLeft <= minRight) && allNotNull) + + // Left > Right or Left >= Right + // - no overlap: + // minLeft maxLeft minRight maxRight + // --------+------------------+------------+-------------+-------> + // - complete overlap: (If null values exist, we set it to partial overlap.) 
+ // minRight maxRight minLeft maxLeft + // --------+------------------+------------+-------------+-------> + case _: GreaterThan => + (maxLeft <= minRight, (minLeft > maxRight) && allNotNull) + case _: GreaterThanOrEqual => + (maxLeft < minRight, (minLeft >= maxRight) && allNotNull) + + // Left = Right or Left <=> Right + // - no overlap: + // minLeft maxLeft minRight maxRight + // --------+------------------+------------+-------------+-------> + // minRight maxRight minLeft maxLeft + // --------+------------------+------------+-------------+-------> + // - complete overlap: + // minLeft maxLeft + // minRight maxRight + // --------+------------------+-------> + case _: EqualTo => + ((maxLeft < minRight) || (maxRight < minLeft), + (minLeft == minRight) && (maxLeft == maxRight) && allNotNull + && (colStatLeft.distinctCount == colStatRight.distinctCount) + ) + case _: EqualNullSafe => + // For null-safe equality, we use a very restrictive condition to evaluate its overlap. + // If null values exist, we set it to partial overlap. + (((maxLeft < minRight) || (maxRight < minLeft)) && allNotNull, + (minLeft == minRight) && (maxLeft == maxRight) && allNotNull + && (colStatLeft.distinctCount == colStatRight.distinctCount) + ) + } + + var percent = BigDecimal(1.0) + if (noOverlap) { + percent = 0.0 + } else if (completeOverlap) { + percent = 1.0 + } else { + // For partial overlap, we use an empirical value 1/3 as suggested by the book + // "Database Systems: The Complete Book". 
+ percent = 1.0 / 3.0 + + if (update) { + // Need to adjust new min/max after the filter condition is applied + + val ndvLeft = BigDecimal(colStatLeft.distinctCount) + var newNdvLeft = (ndvLeft * percent).setScale(0, RoundingMode.HALF_UP).toBigInt() + if (newNdvLeft < 1) newNdvLeft = 1 + val ndvRight = BigDecimal(colStatRight.distinctCount) + var newNdvRight = (ndvRight * percent).setScale(0, RoundingMode.HALF_UP).toBigInt() + if (newNdvRight < 1) newNdvRight = 1 + + var newMaxLeft = colStatLeft.max + var newMinLeft = colStatLeft.min + var newMaxRight = colStatRight.max + var newMinRight = colStatRight.min + + op match { + case _: LessThan | _: LessThanOrEqual => + // the left side should be less than the right side. + // If not, we need to adjust it to narrow the range. + // Left < Right or Left <= Right + // minRight < minLeft + // --------+******************+-------> + // filtered ^ + // | + // newMinRight + // + // maxRight < maxLeft + // --------+******************+-------> + // ^ filtered + // | + // newMaxLeft + if (minLeft > minRight) newMinRight = colStatLeft.min + if (maxLeft > maxRight) newMaxLeft = colStatRight.max + + case _: GreaterThan | _: GreaterThanOrEqual => + // the left side should be greater than the right side. + // If not, we need to adjust it to narrow the range. + // Left > Right or Left >= Right + // minLeft < minRight + // --------+******************+-------> + // filtered ^ + // | + // newMinLeft + // + // maxLeft < maxRight + // --------+******************+-------> + // ^ filtered + // | + // newMaxRight + if (minLeft < minRight) newMinLeft = colStatRight.min + if (maxLeft < maxRight) newMaxRight = colStatLeft.max + + case _: EqualTo | _: EqualNullSafe => + // need to set new min to the larger min value, and + // set the new max to the smaller max value. 
+ // Left = Right or Left <=> Right + // minLeft < minRight + // --------+******************+-------> + // filtered ^ + // | + // newMinLeft + // + // minRight <= minLeft + // --------+******************+-------> + // filtered ^ + // | + // newMinRight + // + // maxLeft < maxRight + // --------+******************+-------> + // ^ filtered + // | + // newMaxRight + // + // maxRight <= maxLeft + // --------+******************+-------> + // ^ filtered + // | + // newMaxLeft + if (minLeft < minRight) { + newMinLeft = colStatRight.min + } else { + newMinRight = colStatLeft.min + } + if (maxLeft < maxRight) { + newMaxRight = colStatLeft.max + } else { + newMaxLeft = colStatRight.max + } + } + + val newStatsLeft = colStatLeft.copy(distinctCount = newNdvLeft, min = newMinLeft, + max = newMaxLeft) + colStatsMap(attrLeft) = newStatsLeft + val newStatsRight = colStatRight.copy(distinctCount = newNdvRight, min = newMinRight, + max = newMaxRight) + colStatsMap(attrRight) = newStatsRight + } + } + + Some(percent.toDouble) + } + } class ColumnStatsMap { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala old mode 100644 new mode 100755 index 1966c96c05..cffb0d8739 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -33,49 +33,74 @@ import org.apache.spark.sql.types._ class FilterEstimationSuite extends StatsEstimationTestBase { // Suppose our test table has 10 rows and 6 columns. 
- // First column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 + // column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 // Hence, distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4 val attrInt = AttributeReference("cint", IntegerType)() val colStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4) - // only 2 values + // column cbool has only 2 distinct values val attrBool = AttributeReference("cbool", BooleanType)() val colStatBool = ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), nullCount = 0, avgLen = 1, maxLen = 1) - // Second column cdate has 10 values from 2017-01-01 through 2017-01-10. + // column cdate has 10 values from 2017-01-01 through 2017-01-10. val dMin = Date.valueOf("2017-01-01") val dMax = Date.valueOf("2017-01-10") val attrDate = AttributeReference("cdate", DateType)() val colStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax), nullCount = 0, avgLen = 4, maxLen = 4) - // Fourth column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20. + // column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20. 
 val decMin = new java.math.BigDecimal("0.200000000000000000") val decMax = new java.math.BigDecimal("0.800000000000000000") val attrDecimal = AttributeReference("cdecimal", DecimalType(18, 18))() val colStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax), nullCount = 0, avgLen = 8, maxLen = 8) - // Fifth column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 + // column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 val attrDouble = AttributeReference("cdouble", DoubleType)() val colStatDouble = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0), nullCount = 0, avgLen = 8, maxLen = 8) - // Sixth column cstring has 10 String values: + // column cstring has 10 String values: // "A0", "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9" val attrString = AttributeReference("cstring", StringType)() val colStatString = ColumnStat(distinctCount = 10, min = None, max = None, nullCount = 0, avgLen = 2, maxLen = 2) + // column cint2 has values: 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 + // Hence, distinctCount:10, min:7, max:16, nullCount:0, avgLen:4, maxLen:4 + // This column is created to test "cint < cint2" + val attrInt2 = AttributeReference("cint2", IntegerType)() + val colStatInt2 = ColumnStat(distinctCount = 10, min = Some(7), max = Some(16), + nullCount = 0, avgLen = 4, maxLen = 4) + + // column cint3 has values: 30, 31, 32, 33, 34, 35, 36, 37, 38, 39 + // Hence, distinctCount:10, min:30, max:39, nullCount:0, avgLen:4, maxLen:4 + // This column is created to test "cint = cint3" without overlap at all. 
+ val attrInt3 = AttributeReference("cint3", IntegerType)() + val colStatInt3 = ColumnStat(distinctCount = 10, min = Some(30), max = Some(39), + nullCount = 0, avgLen = 4, maxLen = 4) + + // column cint4 has values in the range from 1 to 10 + // distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4 + // This column is created to test complete overlap + val attrInt4 = AttributeReference("cint4", IntegerType)() + val colStatInt4 = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4) + val attributeMap = AttributeMap(Seq( attrInt -> colStatInt, attrBool -> colStatBool, attrDate -> colStatDate, attrDecimal -> colStatDecimal, attrDouble -> colStatDouble, - attrString -> colStatString)) + attrString -> colStatString, + attrInt2 -> colStatInt2, + attrInt3 -> colStatInt3, + attrInt4 -> colStatInt4 + )) test("true") { validateEstimatedStats( @@ -450,6 +475,89 @@ class FilterEstimationSuite extends StatsEstimationTestBase { } } + test("cint = cint2") { + // partial overlap case + validateEstimatedStats( + Filter(EqualTo(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attrInt2 -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 4) + } + + test("cint > cint2") { + // partial overlap case + validateEstimatedStats( + Filter(GreaterThan(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attrInt2 -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 4) + } + + test("cint < cint2") { + // partial overlap case + validateEstimatedStats( + Filter(LessThan(attrInt, attrInt2), 
childStatsTestPlan(Seq(attrInt, attrInt2), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attrInt2 -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(16), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 4) + } + + test("cint = cint4") { + // complete overlap case + validateEstimatedStats( + Filter(EqualTo(attrInt, attrInt4), childStatsTestPlan(Seq(attrInt, attrInt4), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attrInt4 -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 10) + } + + test("cint < cint4") { + // partial overlap case + validateEstimatedStats( + Filter(LessThan(attrInt, attrInt4), childStatsTestPlan(Seq(attrInt, attrInt4), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attrInt4 -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 4) + } + + test("cint = cint3") { + // no records qualify due to no overlap + val emptyColStats = Seq[(Attribute, ColumnStat)]() + validateEstimatedStats( + Filter(EqualTo(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)), + Nil, // set to empty + expectedRowCount = 0) + } + + test("cint < cint3") { + // all table records qualify. 
+ validateEstimatedStats( + Filter(LessThan(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attrInt3 -> ColumnStat(distinctCount = 10, min = Some(30), max = Some(39), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 10) + } + + test("cint > cint3") { + // no records qualify due to no overlap + validateEstimatedStats( + Filter(GreaterThan(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)), + Nil, // set to empty + expectedRowCount = 0) + } + private def childStatsTestPlan(outList: Seq[Attribute], tableRowCount: BigInt): StatsTestPlan = { StatsTestPlan( outputList = outList, @@ -491,7 +599,23 @@ class FilterEstimationSuite extends StatsEstimationTestBase { sizeInBytes = getOutputSize(filter.output, expectedRowCount, expectedAttributeMap), rowCount = Some(expectedRowCount), attributeStats = expectedAttributeMap) - assert(filter.stats(conf) == expectedStats) + + val filterStats = filter.stats(conf) + assert(filterStats.sizeInBytes == expectedStats.sizeInBytes) + assert(filterStats.rowCount == expectedStats.rowCount) + val rowCountValue = filterStats.rowCount.getOrElse(0) + // check the output column stats if the row count is > 0. + // When row count is 0, the output is set to empty. + if (rowCountValue != 0) { + // Need to check attributeStats one by one because we may have multiple output columns. + // Due to update operation, the output columns may be in different order. + assert(expectedColStats.size == filterStats.attributeStats.size) + expectedColStats.foreach { kv => + val filterColumnStat = filterStats.attributeStats.get(kv._1).get + assert(filterColumnStat == kv._2) + } + } } } + } -- cgit v1.2.3