diff options
author | Herman van Hovell <hvanhovell@databricks.com> | 2016-06-12 15:06:37 -0700 |
---|---|---|
committer | Herman van Hovell <hvanhovell@databricks.com> | 2016-06-12 15:06:37 -0700 |
commit | 20b8f2c32af696c3856221c4c4fcd12c3f068af2 (patch) | |
tree | a5b4563acae57a60b1175c51533818526969e1ad | |
parent | e3554605b36bdce63ac180cc66dbdee5c1528ec7 (diff) | |
download | spark-20b8f2c32af696c3856221c4c4fcd12c3f068af2.tar.gz spark-20b8f2c32af696c3856221c4c4fcd12c3f068af2.tar.bz2 spark-20b8f2c32af696c3856221c4c4fcd12c3f068af2.zip |
[SPARK-15370][SQL] Revert PR "Update RewriteCorrelatedSuquery rule"
This reverts commit 9770f6ee60f6834e4e1200234109120427a5cc0d.
Author: Herman van Hovell <hvanhovell@databricks.com>
Closes #13626 from hvanhovell/SPARK-15370-revert.
3 files changed, 6 insertions, 280 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index a3b098afe5..8a6cf53782 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -69,11 +69,8 @@ trait PredicateHelper { protected def replaceAlias( condition: Expression, aliases: AttributeMap[Expression]): Expression = { - // Use transformUp to prevent infinite recursion when the replacement expression - // redefines the same ExprId, - condition.transformUp { - case a: Attribute => - aliases.getOrElse(a, a) + condition.transform { + case a: Attribute => aliases.getOrElse(a, a) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d115274b2f..a12c2ef0fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -528,8 +528,7 @@ object CollapseProject extends Rule[LogicalPlan] { // Substitute any attributes that are produced by the lower projection, so that we safely // eliminate it. // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...' - // Use transformUp to prevent infinite recursion. - val rewrittenUpper = upper.map(_.transformUp { + val rewrittenUpper = upper.map(_.transform { case a: Attribute => aliases.getOrElse(a, a) }) // collapse upper and lower Projects may introduce unnecessary Aliases, trim them here. @@ -1784,128 +1783,6 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { } /** - * Statically evaluate an expression containing zero or more placeholders, given a set - * of bindings for placeholder values. - */ - private def evalExpr(expr: Expression, bindings: Map[ExprId, Option[Any]]) : Option[Any] = { - val rewrittenExpr = expr transform { - case r @ AttributeReference(_, dataType, _, _) => - bindings(r.exprId) match { - case Some(v) => Literal.create(v, dataType) - case None => Literal.default(NullType) - } - } - Option(rewrittenExpr.eval()) - } - - /** - * Statically evaluate an expression containing one or more aggregates on an empty input. - */ - private def evalAggOnZeroTups(expr: Expression) : Option[Any] = { - // AggregateExpressions are Unevaluable, so we need to replace all aggregates - // in the expression with the value they would return for zero input tuples. - // Also replace attribute refs (for example, for grouping columns) with NULL. - val rewrittenExpr = expr transform { - case a @ AggregateExpression(aggFunc, _, _, resultId) => - aggFunc.defaultResult.getOrElse(Literal.default(NullType)) - - case AttributeReference(_, _, _, _) => Literal.default(NullType) - } - Option(rewrittenExpr.eval()) - } - - /** - * Statically evaluate a scalar subquery on an empty input. - * - * <b>WARNING:</b> This method only covers subqueries that pass the checks under - * [[org.apache.spark.sql.catalyst.analysis.CheckAnalysis]]. If the checks in - * CheckAnalysis become less restrictive, this method will need to change. - */ - private def evalSubqueryOnZeroTups(plan: LogicalPlan) : Option[Any] = { - // Inputs to this method will start with a chain of zero or more SubqueryAlias - // and Project operators, followed by an optional Filter, followed by an - // Aggregate. Traverse the operators recursively. - def evalPlan(lp : LogicalPlan) : Map[ExprId, Option[Any]] = { - lp match { - case SubqueryAlias(_, child) => evalPlan(child) - case Filter(condition, child) => - val bindings = evalPlan(child) - if (bindings.isEmpty) bindings - else { - val exprResult = evalExpr(condition, bindings).getOrElse(false) - .asInstanceOf[Boolean] - if (exprResult) bindings else Map.empty - } - - case Project(projectList, child) => - val bindings = evalPlan(child) - if (bindings.isEmpty) { - bindings - } else { - projectList.map(ne => (ne.exprId, evalExpr(ne, bindings))).toMap - } - - case Aggregate(_, aggExprs, _) => - // Some of the expressions under the Aggregate node are the join columns - // for joining with the outer query block. Fill those expressions in with - // nulls and statically evaluate the remainder. - aggExprs.map(ne => ne match { - case AttributeReference(_, _, _, _) => (ne.exprId, None) - case Alias(AttributeReference(_, _, _, _), _) => (ne.exprId, None) - case _ => (ne.exprId, evalAggOnZeroTups(ne)) - }).toMap - - case _ => sys.error(s"Unexpected operator in scalar subquery: $lp") - } - } - - val resultMap = evalPlan(plan) - - // By convention, the scalar subquery result is the leftmost field. - resultMap(plan.output.head.exprId) - } - - /** - * Split the plan for a scalar subquery into the parts above the innermost query block - * (first part of returned value), the HAVING clause of the innermost query block - * (optional second part) and the parts below the HAVING CLAUSE (third part). - */ - private def splitSubquery(plan: LogicalPlan) : (Seq[LogicalPlan], Option[Filter], Aggregate) = { - val topPart = ArrayBuffer.empty[LogicalPlan] - var bottomPart : LogicalPlan = plan - while (true) { - bottomPart match { - case havingPart@Filter(_, aggPart@Aggregate(_, _, _)) => - return (topPart, Option(havingPart), aggPart.asInstanceOf[Aggregate]) - - case aggPart@Aggregate(_, _, _) => - // No HAVING clause - return (topPart, None, aggPart) - - case p@Project(_, child) => - topPart += p - bottomPart = child - - case s@SubqueryAlias(_, child) => - topPart += s - bottomPart = child - - case Filter(_, op@_) => - sys.error(s"Correlated subquery has unexpected operator $op below filter") - - case op@_ => sys.error(s"Unexpected operator $op in correlated subquery") - } - } - - sys.error("This line should be unreachable") - } - - - - // Name of generated column used in rewrite below - val ALWAYS_TRUE_COLNAME = "alwaysTrue" - - /** * Construct a new child plan by left joining the given subqueries to a base plan. */ private def constructLeftJoins( @@ -1913,76 +1790,9 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { subqueries: ArrayBuffer[ScalarSubquery]): LogicalPlan = { subqueries.foldLeft(child) { case (currentChild, ScalarSubquery(query, conditions, _)) => - val origOutput = query.output.head - - val resultWithZeroTups = evalSubqueryOnZeroTups(query) - if (resultWithZeroTups.isEmpty) { - // CASE 1: Subquery guaranteed not to have the COUNT bug - Project( - currentChild.output :+ origOutput, - Join(currentChild, query, LeftOuter, conditions.reduceOption(And))) - } else { - // Subquery might have the COUNT bug. Add appropriate corrections. - val (topPart, havingNode, aggNode) = splitSubquery(query) - - // The next two cases add a leading column to the outer join input to make it - // possible to distinguish between the case when no tuples join and the case - // when the tuple that joins contains null values. - // The leading column always has the value TRUE. - val alwaysTrueExprId = NamedExpression.newExprId - val alwaysTrueExpr = Alias(Literal.TrueLiteral, - ALWAYS_TRUE_COLNAME)(exprId = alwaysTrueExprId) - val alwaysTrueRef = AttributeReference(ALWAYS_TRUE_COLNAME, - BooleanType)(exprId = alwaysTrueExprId) - - val aggValRef = query.output.head - - if (!havingNode.isDefined) { - // CASE 2: Subquery with no HAVING clause - Project( - currentChild.output :+ - Alias( - If(IsNull(alwaysTrueRef), - Literal(resultWithZeroTups.get, origOutput.dataType), - aggValRef), origOutput.name)(exprId = origOutput.exprId), - Join(currentChild, - Project(query.output :+ alwaysTrueExpr, query), - LeftOuter, conditions.reduceOption(And))) - - } else { - // CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join. - // Need to modify any operators below the join to pass through all columns - // referenced in the HAVING clause. - var subqueryRoot : UnaryNode = aggNode - val havingInputs : Seq[NamedExpression] = aggNode.output - - topPart.reverse.foreach( - _ match { - case Project(projList, _) => - subqueryRoot = Project(projList ++ havingInputs, subqueryRoot) - case s@SubqueryAlias(alias, _) => subqueryRoot = SubqueryAlias(alias, subqueryRoot) - case op@_ => sys.error(s"Unexpected operator $op in corelated subquery") - } - ) - - // CASE WHEN alwayTrue IS NULL THEN resultOnZeroTups - // WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggVal>) - // ELSE (aggregate value) END AS (original column name) - val caseExpr = Alias(CaseWhen( - Seq[(Expression, Expression)] ( - (IsNull(alwaysTrueRef), Literal(resultWithZeroTups.get, origOutput.dataType)), - (Not(havingNode.get.condition), Literal(null, aggValRef.dataType)) - ), aggValRef - ), origOutput.name) (exprId = origOutput.exprId) - - Project( - currentChild.output :+ caseExpr, - Join(currentChild, - Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot), - LeftOuter, conditions.reduceOption(And))) - - } - } + Project( + currentChild.output :+ query.output.head, + Join(currentChild, query, LeftOuter, conditions.reduceOption(And))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 1d9ff21dbf..1a99fb683e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -490,85 +490,4 @@ class SubquerySuite extends QueryTest with SharedSQLContext { """.stripMargin), Row(3) :: Nil) } - - test("SPARK-15370: COUNT bug in WHERE clause (Filter)") { - // Case 1: Canonical example of the COUNT bug - checkAnswer( - sql("select l.a from l where (select count(*) from r where l.a = r.c) < l.a"), - Row(1) :: Row(1) :: Row(3) :: Row(6) :: Nil) - // Case 2: count(*) = 0; could be rewritten to NOT EXISTS but currently uses - // a rewrite that is vulnerable to the COUNT bug - checkAnswer( - sql("select l.a from l where (select count(*) from r where l.a = r.c) = 0"), - Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil) - // Case 3: COUNT bug without a COUNT aggregate - checkAnswer( - sql("select l.a from l where (select sum(r.d) is null from r where l.a = r.c)"), - Row(1) :: Row(1) ::Row(null) :: Row(null) :: Row(6) :: Nil) - } - - test("SPARK-15370: COUNT bug in SELECT clause (Project)") { - checkAnswer( - sql("select a, (select count(*) from r where l.a = r.c) as cnt from l"), - Row(1, 0) :: Row(1, 0) :: Row(2, 2) :: Row(2, 2) :: Row(3, 1) :: Row(null, 0) - :: Row(null, 0) :: Row(6, 1) :: Nil) - } - - test("SPARK-15370: COUNT bug in HAVING clause (Filter)") { - checkAnswer( - sql("select l.a as grp_a from l group by l.a " + - "having (select count(*) from r where grp_a = r.c) = 0 " + - "order by grp_a"), - Row(null) :: Row(1) :: Nil) - } - - test("SPARK-15370: COUNT bug in Aggregate") { - checkAnswer( - sql("select l.a as aval, sum((select count(*) from r where l.a = r.c)) as cnt " + - "from l group by l.a order by aval"), - Row(null, 0) :: Row(1, 0) :: Row(2, 4) :: Row(3, 1) :: Row(6, 1) :: Nil) - } - - test("SPARK-15370: COUNT bug negative examples") { - // Case 1: Potential COUNT bug case that was working correctly prior to the fix - checkAnswer( - sql("select l.a from l where (select sum(r.d) from r where l.a = r.c) is null"), - Row(1) :: Row(1) :: Row(null) :: Row(null) :: Row(6) :: Nil) - // Case 2: COUNT aggregate but no COUNT bug due to > 0 test. - checkAnswer( - sql("select l.a from l where (select count(*) from r where l.a = r.c) > 0"), - Row(2) :: Row(2) :: Row(3) :: Row(6) :: Nil) - // Case 3: COUNT inside aggregate expression but no COUNT bug. - checkAnswer( - sql("select l.a from l where (select count(*) + sum(r.d) from r where l.a = r.c) = 0"), - Nil) - } - - test("SPARK-15370: COUNT bug in subquery in subquery in subquery") { - checkAnswer( - sql("""select l.a from l - |where ( - | select cntPlusOne + 1 as cntPlusTwo from ( - | select cnt + 1 as cntPlusOne from ( - | select sum(r.c) s, count(*) cnt from r where l.a = r.c having cnt = 0 - | ) - | ) - |) = 2""".stripMargin), - Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil) - } - - test("SPARK-15370: COUNT bug with nasty predicate expr") { - checkAnswer( - sql("select l.a from l where " + - "(select case when count(*) = 1 then null else count(*) end as cnt " + - "from r where l.a = r.c) = 0"), - Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil) - } - - test("SPARK-15370: COUNT bug with attribute ref in subquery input and output ") { - checkAnswer( - sql("select l.b, (select (r.c + count(*)) is null from r where l.a = r.c) from l"), - Row(1.0, false) :: Row(1.0, false) :: Row(2.0, true) :: Row(2.0, true) :: - Row(3.0, false) :: Row(5.0, true) :: Row(null, false) :: Row(null, true) :: Nil) - } } |