aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHerman van Hövell tot Westerflier <hvanhovell@questtec.nl>2016-06-12 21:30:32 -0700
committerReynold Xin <rxin@databricks.com>2016-06-12 21:30:32 -0700
commit1f8f2b5c2a33e63367ea4881b5918f6bc0a6f52f (patch)
tree5d35fcdd61d1fc2eb2554d55db291b9d5248707f
parentf5d38c39255cc75325c6639561bfec1bc051f788 (diff)
downloadspark-1f8f2b5c2a33e63367ea4881b5918f6bc0a6f52f.tar.gz
spark-1f8f2b5c2a33e63367ea4881b5918f6bc0a6f52f.tar.bz2
spark-1f8f2b5c2a33e63367ea4881b5918f6bc0a6f52f.zip
[SPARK-15370][SQL] Fix count bug
# What changes were proposed in this pull request? This pull request fixes the COUNT bug in the `RewriteCorrelatedScalarSubquery` rule. After this change, the rule tests the expression at the root of the correlated subquery to determine whether the expression returns `NULL` on empty input. If the expression does not return `NULL`, the rule generates additional logic in the `Project` operator above the rewritten subquery. This additional logic intercepts `NULL` values coming from the outer join and replaces them with the value that the subquery's expression would return on empty input. This PR takes over https://github.com/apache/spark/pull/13155. It only fixes an issue with `Literal` construction and style issues. All credit should go to frreiss. # How was this patch tested? Added regression tests to cover all branches of the updated rule (see changes to `SubquerySuite`). Ran all existing automated regression tests after merging with latest trunk. Author: frreiss <frreiss@us.ibm.com> Author: Herman van Hovell <hvanhovell@databricks.com> Closes #13629 from hvanhovell/SPARK-15370-cleanup.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala7
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala221
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala81
3 files changed, 287 insertions, 22 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 8a6cf53782..a3b098afe5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -69,8 +69,11 @@ trait PredicateHelper {
protected def replaceAlias(
condition: Expression,
aliases: AttributeMap[Expression]): Expression = {
- condition.transform {
- case a: Attribute => aliases.getOrElse(a, a)
+ // Use transformUp to prevent infinite recursion when the replacement expression
+ // redefines the same ExprId.
+ condition.transformUp {
+ case a: Attribute =>
+ aliases.getOrElse(a, a)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index a12c2ef0fd..7b9b21f416 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -528,7 +528,8 @@ object CollapseProject extends Rule[LogicalPlan] {
// Substitute any attributes that are produced by the lower projection, so that we safely
// eliminate it.
// e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...'
- val rewrittenUpper = upper.map(_.transform {
+ // Use transformUp to prevent infinite recursion.
+ val rewrittenUpper = upper.map(_.transformUp {
case a: Attribute => aliases.getOrElse(a, a)
})
// collapse upper and lower Projects may introduce unnecessary Aliases, trim them here.
@@ -1715,10 +1716,10 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
// Filter the plan by applying left semi and left anti joins.
withSubquery.foldLeft(newFilter) {
case (p, PredicateSubquery(sub, conditions, _, _)) =>
- val (joinCond, outerPlan) = rewriteExistentialExpr(conditions.reduceOption(And), p)
+ val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
Join(outerPlan, sub, LeftSemi, joinCond)
case (p, Not(PredicateSubquery(sub, conditions, false, _))) =>
- val (joinCond, outerPlan) = rewriteExistentialExpr(conditions.reduceOption(And), p)
+ val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
Join(outerPlan, sub, LeftAnti, joinCond)
case (p, Not(PredicateSubquery(sub, conditions, true, _))) =>
// This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
@@ -1727,11 +1728,11 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
// Note that this will almost certainly be planned as a Broadcast Nested Loop join.
// Use EXISTS if performance matters to you.
- val (joinCond, outerPlan) = rewriteExistentialExpr(conditions.reduceLeftOption(And), p)
+ val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
val anyNull = splitConjunctivePredicates(joinCond.get).map(IsNull).reduceLeft(Or)
Join(outerPlan, sub, LeftAnti, Option(Or(anyNull, joinCond.get)))
case (p, predicate) =>
- val (newCond, inputPlan) = rewriteExistentialExpr(Option(predicate), p)
+ val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p)
Project(p.output, Filter(newCond.get, inputPlan))
}
}
@@ -1744,22 +1745,19 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
* are blocked in the Analyzer.
*/
private def rewriteExistentialExpr(
- expr: Option[Expression],
+ exprs: Seq[Expression],
plan: LogicalPlan): (Option[Expression], LogicalPlan) = {
var newPlan = plan
- expr match {
- case Some(e) =>
- val newExpr = e transformUp {
- case PredicateSubquery(sub, conditions, nullAware, _) =>
- // TODO: support null-aware join
- val exists = AttributeReference("exists", BooleanType, nullable = false)()
- newPlan = Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))
- exists
+ val newExprs = exprs.map { e =>
+ e transformUp {
+ case PredicateSubquery(sub, conditions, nullAware, _) =>
+ // TODO: support null-aware join
+ val exists = AttributeReference("exists", BooleanType, nullable = false)()
+ newPlan = Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))
+ exists
}
- (Option(newExpr), newPlan)
- case None =>
- (expr, plan)
}
+ (newExprs.reduceOption(And), newPlan)
}
}
@@ -1783,6 +1781,124 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
}
/**
+ * Statically evaluate an expression containing zero or more placeholders, given a set
+ * of bindings for placeholder values.
+ */
+ private def evalExpr(expr: Expression, bindings: Map[ExprId, Option[Any]]) : Option[Any] = {
+ val rewrittenExpr = expr transform {
+ case r: AttributeReference =>
+ bindings(r.exprId) match {
+ case Some(v) => Literal.create(v, r.dataType)
+ case None => Literal.default(NullType)
+ }
+ }
+ Option(rewrittenExpr.eval())
+ }
+
+ /**
+ * Statically evaluate an expression containing one or more aggregates on an empty input.
+ */
+ private def evalAggOnZeroTups(expr: Expression) : Option[Any] = {
+ // AggregateExpressions are Unevaluable, so we need to replace all aggregates
+ // in the expression with the value they would return for zero input tuples.
+ // Also replace attribute refs (for example, for grouping columns) with NULL.
+ val rewrittenExpr = expr transform {
+ case a @ AggregateExpression(aggFunc, _, _, resultId) =>
+ aggFunc.defaultResult.getOrElse(Literal.default(NullType))
+
+ case _: AttributeReference => Literal.default(NullType)
+ }
+ Option(rewrittenExpr.eval())
+ }
+
+ /**
+ * Statically evaluate a scalar subquery on an empty input.
+ *
+ * <b>WARNING:</b> This method only covers subqueries that pass the checks under
+ * [[org.apache.spark.sql.catalyst.analysis.CheckAnalysis]]. If the checks in
+ * CheckAnalysis become less restrictive, this method will need to change.
+ */
+ private def evalSubqueryOnZeroTups(plan: LogicalPlan) : Option[Any] = {
+ // Inputs to this method will start with a chain of zero or more SubqueryAlias
+ // and Project operators, followed by an optional Filter, followed by an
+ // Aggregate. Traverse the operators recursively.
+ def evalPlan(lp : LogicalPlan) : Map[ExprId, Option[Any]] = lp match {
+ case SubqueryAlias(_, child) => evalPlan(child)
+ case Filter(condition, child) =>
+ val bindings = evalPlan(child)
+ if (bindings.isEmpty) bindings
+ else {
+ val exprResult = evalExpr(condition, bindings).getOrElse(false)
+ .asInstanceOf[Boolean]
+ if (exprResult) bindings else Map.empty
+ }
+
+ case Project(projectList, child) =>
+ val bindings = evalPlan(child)
+ if (bindings.isEmpty) {
+ bindings
+ } else {
+ projectList.map(ne => (ne.exprId, evalExpr(ne, bindings))).toMap
+ }
+
+ case Aggregate(_, aggExprs, _) =>
+ // Some of the expressions under the Aggregate node are the join columns
+ // for joining with the outer query block. Fill those expressions in with
+ // nulls and statically evaluate the remainder.
+ aggExprs.map {
+ case ref: AttributeReference => (ref.exprId, None)
+ case alias @ Alias(_: AttributeReference, _) => (alias.exprId, None)
+ case ne => (ne.exprId, evalAggOnZeroTups(ne))
+ }.toMap
+
+ case _ => sys.error(s"Unexpected operator in scalar subquery: $lp")
+ }
+
+ val resultMap = evalPlan(plan)
+
+ // By convention, the scalar subquery result is the leftmost field.
+ resultMap(plan.output.head.exprId)
+ }
+
+ /**
+ * Split the plan for a scalar subquery into the parts above the innermost query block
+ * (first part of returned value), the HAVING clause of the innermost query block
+ * (optional second part) and the parts below the HAVING CLAUSE (third part).
+ */
+ private def splitSubquery(plan: LogicalPlan) : (Seq[LogicalPlan], Option[Filter], Aggregate) = {
+ val topPart = ArrayBuffer.empty[LogicalPlan]
+ var bottomPart: LogicalPlan = plan
+ while (true) {
+ bottomPart match {
+ case havingPart @ Filter(_, aggPart: Aggregate) =>
+ return (topPart, Option(havingPart), aggPart)
+
+ case aggPart: Aggregate =>
+ // No HAVING clause
+ return (topPart, None, aggPart)
+
+ case p @ Project(_, child) =>
+ topPart += p
+ bottomPart = child
+
+ case s @ SubqueryAlias(_, child) =>
+ topPart += s
+ bottomPart = child
+
+ case Filter(_, op) =>
+ sys.error(s"Correlated subquery has unexpected operator $op below filter")
+
+ case op @ _ => sys.error(s"Unexpected operator $op in correlated subquery")
+ }
+ }
+
+ sys.error("This line should be unreachable")
+ }
+
+ // Name of generated column used in rewrite below
+ val ALWAYS_TRUE_COLNAME = "alwaysTrue"
+
+ /**
* Construct a new child plan by left joining the given subqueries to a base plan.
*/
private def constructLeftJoins(
@@ -1790,9 +1906,74 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
subqueries: ArrayBuffer[ScalarSubquery]): LogicalPlan = {
subqueries.foldLeft(child) {
case (currentChild, ScalarSubquery(query, conditions, _)) =>
- Project(
- currentChild.output :+ query.output.head,
- Join(currentChild, query, LeftOuter, conditions.reduceOption(And)))
+ val origOutput = query.output.head
+
+ val resultWithZeroTups = evalSubqueryOnZeroTups(query)
+ if (resultWithZeroTups.isEmpty) {
+ // CASE 1: Subquery guaranteed not to have the COUNT bug
+ Project(
+ currentChild.output :+ origOutput,
+ Join(currentChild, query, LeftOuter, conditions.reduceOption(And)))
+ } else {
+ // Subquery might have the COUNT bug. Add appropriate corrections.
+ val (topPart, havingNode, aggNode) = splitSubquery(query)
+
+ // The next two cases add a leading column to the outer join input to make it
+ // possible to distinguish between the case when no tuples join and the case
+ // when the tuple that joins contains null values.
+ // The leading column always has the value TRUE.
+ val alwaysTrueExprId = NamedExpression.newExprId
+ val alwaysTrueExpr = Alias(Literal.TrueLiteral,
+ ALWAYS_TRUE_COLNAME)(exprId = alwaysTrueExprId)
+ val alwaysTrueRef = AttributeReference(ALWAYS_TRUE_COLNAME,
+ BooleanType)(exprId = alwaysTrueExprId)
+
+ val aggValRef = query.output.head
+
+ if (havingNode.isEmpty) {
+ // CASE 2: Subquery with no HAVING clause
+ Project(
+ currentChild.output :+
+ Alias(
+ If(IsNull(alwaysTrueRef),
+ Literal.create(resultWithZeroTups.get, origOutput.dataType),
+ aggValRef), origOutput.name)(exprId = origOutput.exprId),
+ Join(currentChild,
+ Project(query.output :+ alwaysTrueExpr, query),
+ LeftOuter, conditions.reduceOption(And)))
+
+ } else {
+ // CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join.
+ // Need to modify any operators below the join to pass through all columns
+ // referenced in the HAVING clause.
+ var subqueryRoot: UnaryNode = aggNode
+ val havingInputs: Seq[NamedExpression] = aggNode.output
+
+ topPart.reverse.foreach {
+ case Project(projList, _) =>
+ subqueryRoot = Project(projList ++ havingInputs, subqueryRoot)
+ case s @ SubqueryAlias(alias, _) =>
+ subqueryRoot = SubqueryAlias(alias, subqueryRoot)
+ case op => sys.error(s"Unexpected operator $op in corelated subquery")
+ }
+
+ // CASE WHEN alwaysTrue IS NULL THEN resultOnZeroTups
+ // WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggVal>)
+ // ELSE (aggregate value) END AS (original column name)
+ val caseExpr = Alias(CaseWhen(Seq(
+ (IsNull(alwaysTrueRef), Literal.create(resultWithZeroTups.get, origOutput.dataType)),
+ (Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))),
+ aggValRef),
+ origOutput.name)(exprId = origOutput.exprId)
+
+ Project(
+ currentChild.output :+ caseExpr,
+ Join(currentChild,
+ Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot),
+ LeftOuter, conditions.reduceOption(And)))
+
+ }
+ }
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 1a99fb683e..1d9ff21dbf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -490,4 +490,85 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
""".stripMargin),
Row(3) :: Nil)
}
+
+ test("SPARK-15370: COUNT bug in WHERE clause (Filter)") {
+ // Case 1: Canonical example of the COUNT bug
+ checkAnswer(
+ sql("select l.a from l where (select count(*) from r where l.a = r.c) < l.a"),
+ Row(1) :: Row(1) :: Row(3) :: Row(6) :: Nil)
+ // Case 2: count(*) = 0; could be rewritten to NOT EXISTS but currently uses
+ // a rewrite that is vulnerable to the COUNT bug
+ checkAnswer(
+ sql("select l.a from l where (select count(*) from r where l.a = r.c) = 0"),
+ Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
+ // Case 3: COUNT bug without a COUNT aggregate
+ checkAnswer(
+ sql("select l.a from l where (select sum(r.d) is null from r where l.a = r.c)"),
+ Row(1) :: Row(1) ::Row(null) :: Row(null) :: Row(6) :: Nil)
+ }
+
+ test("SPARK-15370: COUNT bug in SELECT clause (Project)") {
+ checkAnswer(
+ sql("select a, (select count(*) from r where l.a = r.c) as cnt from l"),
+ Row(1, 0) :: Row(1, 0) :: Row(2, 2) :: Row(2, 2) :: Row(3, 1) :: Row(null, 0)
+ :: Row(null, 0) :: Row(6, 1) :: Nil)
+ }
+
+ test("SPARK-15370: COUNT bug in HAVING clause (Filter)") {
+ checkAnswer(
+ sql("select l.a as grp_a from l group by l.a " +
+ "having (select count(*) from r where grp_a = r.c) = 0 " +
+ "order by grp_a"),
+ Row(null) :: Row(1) :: Nil)
+ }
+
+ test("SPARK-15370: COUNT bug in Aggregate") {
+ checkAnswer(
+ sql("select l.a as aval, sum((select count(*) from r where l.a = r.c)) as cnt " +
+ "from l group by l.a order by aval"),
+ Row(null, 0) :: Row(1, 0) :: Row(2, 4) :: Row(3, 1) :: Row(6, 1) :: Nil)
+ }
+
+ test("SPARK-15370: COUNT bug negative examples") {
+ // Case 1: Potential COUNT bug case that was working correctly prior to the fix
+ checkAnswer(
+ sql("select l.a from l where (select sum(r.d) from r where l.a = r.c) is null"),
+ Row(1) :: Row(1) :: Row(null) :: Row(null) :: Row(6) :: Nil)
+ // Case 2: COUNT aggregate but no COUNT bug due to > 0 test.
+ checkAnswer(
+ sql("select l.a from l where (select count(*) from r where l.a = r.c) > 0"),
+ Row(2) :: Row(2) :: Row(3) :: Row(6) :: Nil)
+ // Case 3: COUNT inside aggregate expression but no COUNT bug.
+ checkAnswer(
+ sql("select l.a from l where (select count(*) + sum(r.d) from r where l.a = r.c) = 0"),
+ Nil)
+ }
+
+ test("SPARK-15370: COUNT bug in subquery in subquery in subquery") {
+ checkAnswer(
+ sql("""select l.a from l
+ |where (
+ | select cntPlusOne + 1 as cntPlusTwo from (
+ | select cnt + 1 as cntPlusOne from (
+ | select sum(r.c) s, count(*) cnt from r where l.a = r.c having cnt = 0
+ | )
+ | )
+ |) = 2""".stripMargin),
+ Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
+ }
+
+ test("SPARK-15370: COUNT bug with nasty predicate expr") {
+ checkAnswer(
+ sql("select l.a from l where " +
+ "(select case when count(*) = 1 then null else count(*) end as cnt " +
+ "from r where l.a = r.c) = 0"),
+ Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
+ }
+
+ test("SPARK-15370: COUNT bug with attribute ref in subquery input and output ") {
+ checkAnswer(
+ sql("select l.b, (select (r.c + count(*)) is null from r where l.a = r.c) from l"),
+ Row(1.0, false) :: Row(1.0, false) :: Row(2.0, true) :: Row(2.0, true) ::
+ Row(3.0, false) :: Row(5.0, true) :: Row(null, false) :: Row(null, true) :: Nil)
+ }
}