diff options
author | Wenchen Fan <wenchen@databricks.com> | 2017-01-23 13:31:26 +0800 |
---|---|---|
committer | Wenchen Fan <wenchen@databricks.com> | 2017-01-23 13:31:26 +0800 |
commit | de6ad3dfa7f4fdc8bb049f31142df9e5c01e6d13 (patch) | |
tree | 083bf69d6876048d7afd5d04e81e87cd70646081 /sql | |
parent | 772035e771a75593f031a8e78080bb58b8218e04 (diff) | |
download | spark-de6ad3dfa7f4fdc8bb049f31142df9e5c01e6d13.tar.gz spark-de6ad3dfa7f4fdc8bb049f31142df9e5c01e6d13.tar.bz2 spark-de6ad3dfa7f4fdc8bb049f31142df9e5c01e6d13.zip |
[SPARK-19309][SQL] disable common subexpression elimination for conditional expressions
## What changes were proposed in this pull request?
As I pointed out in https://github.com/apache/spark/pull/15807#issuecomment-259143655 , the current subexpression elimination framework has a problem: it always evaluates all common subexpressions at the beginning, even if they are inside conditional expressions and may not be accessed.
Ideally we should implement it like a Scala lazy val, so that we only evaluate a common subexpression when it gets accessed at least once. https://github.com/apache/spark/issues/15837 tries this approach, but it seems too complicated and may introduce a performance regression.
This PR simply stops common subexpression elimination for conditional expressions, with some cleanup.
## How was this patch tested?
regression test
Author: Wenchen Fan <wenchen@databricks.com>
Closes #16659 from cloud-fan/codegen.
Diffstat (limited to 'sql')
7 files changed, 84 insertions, 171 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index 6c246a5663..f8644c2cd6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -67,28 +67,34 @@ class EquivalentExpressions { /** * Adds the expression to this data structure recursively. Stops if a matching expression * is found. That is, if `expr` has already been added, its children are not added. - * If ignoreLeaf is true, leaf nodes are ignored. */ - def addExprTree( - root: Expression, - ignoreLeaf: Boolean = true, - skipReferenceToExpressions: Boolean = true): Unit = { - val skip = (root.isInstanceOf[LeafExpression] && ignoreLeaf) || + def addExprTree(expr: Expression): Unit = { + val skip = expr.isInstanceOf[LeafExpression] || // `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the // loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning. - root.find(_.isInstanceOf[LambdaVariable]).isDefined - // There are some special expressions that we should not recurse into children. + expr.find(_.isInstanceOf[LambdaVariable]).isDefined + + // There are some special expressions that we should not recurse into all of its children. // 1. CodegenFallback: it's children will not be used to generate code (call eval() instead) - // 2. ReferenceToExpressions: it's kind of an explicit sub-expression elimination. - val shouldRecurse = root match { - // TODO: some expressions implements `CodegenFallback` but can still do codegen, - // e.g. `CaseWhen`, we should support them. - case _: CodegenFallback => false - case _: ReferenceToExpressions if skipReferenceToExpressions => false - case _ => true + // 2. 
If: common subexpressions will always be evaluated at the beginning, but the true and + // false expressions in `If` may not get accessed, according to the predicate + // expression. We should only recurse into the predicate expression. + // 3. CaseWhen: like `If`, the children of `CaseWhen` only get accessed in a certain + // condition. We should only recurse into the first condition expression as it + // will always get accessed. + // 4. Coalesce: it's also a conditional expression, we should only recurse into the first + // children, because others may not get accessed. + def childrenToRecurse: Seq[Expression] = expr match { + case _: CodegenFallback => Nil + case i: If => i.predicate :: Nil + // `CaseWhen` implements `CodegenFallback`, we only need to handle `CaseWhenCodegen` here. + case c: CaseWhenCodegen => c.children.head :: Nil + case c: Coalesce => c.children.head :: Nil + case other => other.children } - if (!skip && !addExpr(root) && shouldRecurse) { - root.children.foreach(addExprTree(_, ignoreLeaf)) + + if (!skip && !addExpr(expr)) { + childrenToRecurse.foreach(addExprTree) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 476e37e6a9..7c57025f99 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -117,7 +117,7 @@ object UnsafeProjection { * Returns an UnsafeProjection for given Array of DataTypes. 
*/ def create(fields: Array[DataType]): UnsafeProjection = { - create(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true))) + create(fields.zipWithIndex.map(x => BoundReference(x._2, x._1, true))) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala deleted file mode 100644 index 2ca77e8394..0000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable -import org.apache.spark.sql.types.DataType - -/** - * A special expression that evaluates [[BoundReference]]s by given expressions instead of the - * input row. - * - * @param result The expression that contains [[BoundReference]] and produces the final output. - * @param children The expressions that used as input values for [[BoundReference]]. - */ -case class ReferenceToExpressions(result: Expression, children: Seq[Expression]) - extends Expression { - - override def nullable: Boolean = result.nullable - override def dataType: DataType = result.dataType - - override def checkInputDataTypes(): TypeCheckResult = { - if (result.references.nonEmpty) { - return TypeCheckFailure("The result expression cannot reference to any attributes.") - } - - var maxOrdinal = -1 - result foreach { - case b: BoundReference if b.ordinal > maxOrdinal => maxOrdinal = b.ordinal - case _ => - } - if (maxOrdinal > children.length) { - return TypeCheckFailure(s"The result expression need $maxOrdinal input expressions, but " + - s"there are only ${children.length} inputs.") - } - - TypeCheckSuccess - } - - private lazy val projection = UnsafeProjection.create(children) - - override def eval(input: InternalRow): Any = { - result.eval(projection(input)) - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val childrenGen = children.map(_.genCode(ctx)) - val (classChildrenVars, initClassChildrenVars) = childrenGen.zip(children).map { - case (childGen, child) => - // SPARK-18125: The children vars are local variables. 
If the result expression uses - // splitExpression, those variables cannot be accessed so compilation fails. - // To fix it, we use class variables to hold those local variables. - val classChildVarName = ctx.freshName("classChildVar") - val classChildVarIsNull = ctx.freshName("classChildVarIsNull") - ctx.addMutableState(ctx.javaType(child.dataType), classChildVarName, "") - ctx.addMutableState("boolean", classChildVarIsNull, "") - - val classChildVar = - LambdaVariable(classChildVarName, classChildVarIsNull, child.dataType, child.nullable) - - val initCode = s"${classChildVar.value} = ${childGen.value};\n" + - s"${classChildVar.isNull} = ${childGen.isNull};" - - (classChildVar, initCode) - }.unzip - - val resultGen = result.transform { - case b: BoundReference => classChildrenVars(b.ordinal) - }.genCode(ctx) - - ExprCode(code = childrenGen.map(_.code).mkString("\n") + initClassChildrenVars.mkString("\n") + - resultGen.code, isNull = resultGen.isNull, value = resultGen.value) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index f8f868b59b..04b812e79e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -726,7 +726,7 @@ class CodegenContext { val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] // Add each expression tree and compute the common subexpressions. - expressions.foreach(equivalentExpressions.addExprTree(_, true, false)) + expressions.foreach(equivalentExpressions.addExprTree) // Get all the expressions that appear at least twice and set up the state for subexpression // elimination. 
@@ -734,10 +734,10 @@ class CodegenContext { val codes = commonExprs.map { e => val expr = e.head // Generate the code for this expression tree. - val code = expr.genCode(this) - val state = SubExprEliminationState(code.isNull, code.value) + val eval = expr.genCode(this) + val state = SubExprEliminationState(eval.isNull, eval.value) e.foreach(subExprEliminationExprs.put(_, state)) - code.code.trim + eval.code.trim } SubExprCodes(codes, subExprEliminationExprs.toMap) } @@ -747,7 +747,7 @@ class CodegenContext { * common subexpressions, generates the functions that evaluate those expressions and populates * the mapping of common subexpressions to the generated functions. */ - private def subexpressionElimination(expressions: Seq[Expression]) = { + private def subexpressionElimination(expressions: Seq[Expression]): Unit = { // Add each expression tree and compute the common subexpressions. expressions.foreach(equivalentExpressions.addExprTree(_)) @@ -761,13 +761,13 @@ class CodegenContext { val value = s"${fnName}Value" // Generate the code for this expression tree and wrap it in a function. - val code = expr.genCode(this) + val eval = expr.genCode(this) val fn = s""" |private void $fnName(InternalRow $INPUT_ROW) { - | ${code.code.trim} - | $isNull = ${code.isNull}; - | $value = ${code.value}; + | ${eval.code.trim} + | $isNull = ${eval.isNull}; + | $value = ${eval.value}; |} """.stripMargin @@ -780,9 +780,6 @@ class CodegenContext { // The cost of doing subexpression elimination is: // 1. Extra function call, although this is probably *good* as the JIT can decide to // inline or not. - // 2. Extra branch to check isLoaded. This branch is likely to be predicted correctly - // very often. The reason it is not loaded is because of a prior branch. - // 3. Extra store into isLoaded. // The benefit doing subexpression elimination is: // 1. Running the expression logic. Even for a simple expression, it is likely more than 3 // above. 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 587022f0a2..7ea0bec145 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -25,8 +25,8 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.expressions.objects.{CreateExternalRow, GetExternalRowField, ValidateExternalType} -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils} +import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, GetExternalRowField, ValidateExternalType} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ThreadUtils @@ -313,4 +313,15 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-17160: field names are properly escaped by AssertTrue") { GenerateUnsafeProjection.generate(AssertTrue(Cast(Literal("\""), BooleanType)) :: Nil) } + + test("should not apply common subexpression elimination on conditional expressions") { + val row = InternalRow(null) + val bound = BoundReference(0, IntegerType, true) + val assertNotNull = AssertNotNull(bound, Nil) + val expr = If(IsNull(bound), Literal(1), Add(assertNotNull, assertNotNull)) + val projection = GenerateUnsafeProjection.generate( + Seq(expr), subexpressionEliminationEnabled = true) + // should not throw exception + projection(row) + } } diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index 2db2a043e5..c48730bd9d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -97,9 +97,9 @@ class SubexpressionEliminationSuite extends SparkFunSuite { val add2 = Add(add, add) var equivalence = new EquivalentExpressions - equivalence.addExprTree(add, true) - equivalence.addExprTree(abs, true) - equivalence.addExprTree(add2, true) + equivalence.addExprTree(add) + equivalence.addExprTree(abs) + equivalence.addExprTree(add2) // Should only have one equivalence for `one + two` assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 1) @@ -115,10 +115,10 @@ class SubexpressionEliminationSuite extends SparkFunSuite { val mul2 = Multiply(mul, mul) val sqrt = Sqrt(mul2) val sum = Add(mul2, sqrt) - equivalence.addExprTree(mul, true) - equivalence.addExprTree(mul2, true) - equivalence.addExprTree(sqrt, true) - equivalence.addExprTree(sum, true) + equivalence.addExprTree(mul) + equivalence.addExprTree(mul2) + equivalence.addExprTree(sqrt) + equivalence.addExprTree(sum) // (one * two), (one * two) * (one * two) and sqrt( (one * two) * (one * two) ) should be found assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 3) @@ -126,30 +126,6 @@ class SubexpressionEliminationSuite extends SparkFunSuite { assert(equivalence.getEquivalentExprs(mul2).size == 3) assert(equivalence.getEquivalentExprs(sqrt).size == 2) assert(equivalence.getEquivalentExprs(sum).size == 1) - - // Some expressions inspired by TPCH-Q1 - // sum(l_quantity) as sum_qty, - // sum(l_extendedprice) as sum_base_price, - // sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, - // sum(l_extendedprice 
* (1 - l_discount) * (1 + l_tax)) as sum_charge, - // avg(l_extendedprice) as avg_price, - // avg(l_discount) as avg_disc - equivalence = new EquivalentExpressions - val quantity = Literal(1) - val price = Literal(1.1) - val discount = Literal(.24) - val tax = Literal(0.1) - equivalence.addExprTree(quantity, false) - equivalence.addExprTree(price, false) - equivalence.addExprTree(Multiply(price, Subtract(Literal(1), discount)), false) - equivalence.addExprTree( - Multiply( - Multiply(price, Subtract(Literal(1), discount)), - Add(Literal(1), tax)), false) - equivalence.addExprTree(price, false) - equivalence.addExprTree(discount, false) - // quantity, price, discount and (price * (1 - discount)) - assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 4) } test("Expression equivalence - non deterministic") { @@ -167,11 +143,24 @@ class SubexpressionEliminationSuite extends SparkFunSuite { val add = Add(two, fallback) val equivalence = new EquivalentExpressions - equivalence.addExprTree(add, true) + equivalence.addExprTree(add) // the `two` inside `fallback` should not be added assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0) assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3) // add, two, explode } + + test("Children of conditional expressions") { + val condition = And(Literal(true), Literal(false)) + val add = Add(Literal(1), Literal(2)) + val ifExpr = If(condition, add, add) + + val equivalence = new EquivalentExpressions + equivalence.addExprTree(ifExpr) + // the `add` inside `If` should not be added + assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0) + // only ifExpr and its predicate expression + assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 2) + } } case class CodegenFallbackExpression(child: Expression) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 4146bf3269..717758fdf7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -143,9 +143,15 @@ case class SimpleTypedAggregateExpression( override lazy val aggBufferAttributes: Seq[AttributeReference] = bufferSerializer.map(_.toAttribute.asInstanceOf[AttributeReference]) + private def serializeToBuffer(expr: Expression): Seq[Expression] = { + bufferSerializer.map(_.transform { + case _: BoundReference => expr + }) + } + override lazy val initialValues: Seq[Expression] = { val zero = Literal.fromObject(aggregator.zero, bufferExternalType) - bufferSerializer.map(ReferenceToExpressions(_, zero :: Nil)) + serializeToBuffer(zero) } override lazy val updateExpressions: Seq[Expression] = { @@ -154,8 +160,7 @@ case class SimpleTypedAggregateExpression( "reduce", bufferExternalType, bufferDeserializer :: inputDeserializer.get :: Nil) - - bufferSerializer.map(ReferenceToExpressions(_, reduced :: Nil)) + serializeToBuffer(reduced) } override lazy val mergeExpressions: Seq[Expression] = { @@ -170,8 +175,7 @@ case class SimpleTypedAggregateExpression( "merge", bufferExternalType, leftBuffer :: rightBuffer :: Nil) - - bufferSerializer.map(ReferenceToExpressions(_, merged :: Nil)) + serializeToBuffer(merged) } override lazy val evaluateExpression: Expression = { @@ -181,19 +185,17 @@ case class SimpleTypedAggregateExpression( outputExternalType, bufferDeserializer :: Nil) + val outputSerializeExprs = outputSerializer.map(_.transform { + case _: BoundReference => resultObj + }) + dataType match { - case s: StructType => + case _: StructType => val objRef = outputSerializer.head.find(_.isInstanceOf[BoundReference]).get - val struct = If( - IsNull(objRef), - Literal.create(null, dataType), - CreateStruct(outputSerializer)) 
- ReferenceToExpressions(struct, resultObj :: Nil) + If(IsNull(objRef), Literal.create(null, dataType), CreateStruct(outputSerializeExprs)) case _ => - assert(outputSerializer.length == 1) - outputSerializer.head transform { - case b: BoundReference => resultObj - } + assert(outputSerializeExprs.length == 1) + outputSerializeExprs.head } } |