Diffstat (limited to 'sql')
7 files changed, 84 insertions, 171 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
index 6c246a5663..f8644c2cd6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -67,28 +67,34 @@ class EquivalentExpressions {

   /**
    * Adds the expression to this data structure recursively. Stops if a matching expression
    * is found. That is, if `expr` has already been added, its children are not added.
-   * If ignoreLeaf is true, leaf nodes are ignored.
    */
-  def addExprTree(
-      root: Expression,
-      ignoreLeaf: Boolean = true,
-      skipReferenceToExpressions: Boolean = true): Unit = {
-    val skip = (root.isInstanceOf[LeafExpression] && ignoreLeaf) ||
+  def addExprTree(expr: Expression): Unit = {
+    val skip = expr.isInstanceOf[LeafExpression] ||
       // `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the
       // loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning.
-      root.find(_.isInstanceOf[LambdaVariable]).isDefined
-    // There are some special expressions that we should not recurse into children.
+      expr.find(_.isInstanceOf[LambdaVariable]).isDefined
+
+    // There are some special expressions that we should not recurse into all of its children.
     //   1. CodegenFallback: it's children will not be used to generate code (call eval() instead)
-    //   2. ReferenceToExpressions: it's kind of an explicit sub-expression elimination.
-    val shouldRecurse = root match {
-      // TODO: some expressions implements `CodegenFallback` but can still do codegen,
-      // e.g. `CaseWhen`, we should support them.
-      case _: CodegenFallback => false
-      case _: ReferenceToExpressions if skipReferenceToExpressions => false
-      case _ => true
+    //   2. If: common subexpressions will always be evaluated at the beginning, but the true and
+    //          false expressions in `If` may not get accessed, according to the predicate
+    //          expression. We should only recurse into the predicate expression.
+    //   3. CaseWhen: like `If`, the children of `CaseWhen` only get accessed in a certain
+    //                condition. We should only recurse into the first condition expression as it
+    //                will always get accessed.
+    //   4. Coalesce: it's also a conditional expression, we should only recurse into the first
+    //                children, because others may not get accessed.
+    def childrenToRecurse: Seq[Expression] = expr match {
+      case _: CodegenFallback => Nil
+      case i: If => i.predicate :: Nil
+      // `CaseWhen` implements `CodegenFallback`, we only need to handle `CaseWhenCodegen` here.
+      case c: CaseWhenCodegen => c.children.head :: Nil
+      case c: Coalesce => c.children.head :: Nil
+      case other => other.children
     }
-    if (!skip && !addExpr(root) && shouldRecurse) {
-      root.children.foreach(addExprTree(_, ignoreLeaf))
+
+    if (!skip && !addExpr(expr)) {
+      childrenToRecurse.foreach(addExprTree)
     }
   }
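As a side note, the behaviour introduced above can be summarised with a small sketch (not part of the commit; it only uses catalyst classes this change already touches). An expression that repeats only inside a branch of a conditional is no longer reported as a common subexpression, because only the always-evaluated child is recursed into:

    // Sketch only: mirrors the new addExprTree contract for conditional expressions.
    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types.IntegerType

    object ConditionalCseSketch extends App {
      val col = BoundReference(0, IntegerType, nullable = true)
      val repeated = Add(col, Literal(1))
      // `repeated` occurs twice, but only inside the false branch of `If`.
      val expr = If(IsNull(col), Literal(0), Add(repeated, repeated))

      val equivalence = new EquivalentExpressions
      equivalence.addExprTree(expr)
      // Only the predicate (IsNull(col)) is recursed into, so `repeated` is never
      // treated as a common subexpression and is not evaluated up front.
      assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)
    }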
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 476e37e6a9..7c57025f99 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -117,7 +117,7 @@ object UnsafeProjection {
    * Returns an UnsafeProjection for given Array of DataTypes.
    */
   def create(fields: Array[DataType]): UnsafeProjection = {
-    create(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)))
+    create(fields.zipWithIndex.map(x => BoundReference(x._2, x._1, true)))
   }

   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
deleted file mode 100644
index 2ca77e8394..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
+++ /dev/null
@@ -1,92 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
-import org.apache.spark.sql.types.DataType
-
-/**
- * A special expression that evaluates [[BoundReference]]s by given expressions instead of the
- * input row.
- *
- * @param result The expression that contains [[BoundReference]] and produces the final output.
- * @param children The expressions that used as input values for [[BoundReference]].
- */
-case class ReferenceToExpressions(result: Expression, children: Seq[Expression])
-  extends Expression {
-
-  override def nullable: Boolean = result.nullable
-  override def dataType: DataType = result.dataType
-
-  override def checkInputDataTypes(): TypeCheckResult = {
-    if (result.references.nonEmpty) {
-      return TypeCheckFailure("The result expression cannot reference to any attributes.")
-    }
-
-    var maxOrdinal = -1
-    result foreach {
-      case b: BoundReference if b.ordinal > maxOrdinal => maxOrdinal = b.ordinal
-      case _ =>
-    }
-    if (maxOrdinal > children.length) {
-      return TypeCheckFailure(s"The result expression need $maxOrdinal input expressions, but " +
-        s"there are only ${children.length} inputs.")
-    }
-
-    TypeCheckSuccess
-  }
-
-  private lazy val projection = UnsafeProjection.create(children)
-
-  override def eval(input: InternalRow): Any = {
-    result.eval(projection(input))
-  }
-
-  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val childrenGen = children.map(_.genCode(ctx))
-    val (classChildrenVars, initClassChildrenVars) = childrenGen.zip(children).map {
-      case (childGen, child) =>
-        // SPARK-18125: The children vars are local variables. If the result expression uses
-        // splitExpression, those variables cannot be accessed so compilation fails.
-        // To fix it, we use class variables to hold those local variables.
-        val classChildVarName = ctx.freshName("classChildVar")
-        val classChildVarIsNull = ctx.freshName("classChildVarIsNull")
-        ctx.addMutableState(ctx.javaType(child.dataType), classChildVarName, "")
-        ctx.addMutableState("boolean", classChildVarIsNull, "")
-
-        val classChildVar =
-          LambdaVariable(classChildVarName, classChildVarIsNull, child.dataType, child.nullable)
-
-        val initCode = s"${classChildVar.value} = ${childGen.value};\n" +
-          s"${classChildVar.isNull} = ${childGen.isNull};"
-
-        (classChildVar, initCode)
-    }.unzip
-
-    val resultGen = result.transform {
-      case b: BoundReference => classChildrenVars(b.ordinal)
-    }.genCode(ctx)
-
-    ExprCode(code = childrenGen.map(_.code).mkString("\n") + initClassChildrenVars.mkString("\n") +
-      resultGen.code, isNull = resultGen.isNull, value = resultGen.value)
-  }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index f8f868b59b..04b812e79e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -726,7 +726,7 @@ class CodegenContext {
     val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState]

     // Add each expression tree and compute the common subexpressions.
-    expressions.foreach(equivalentExpressions.addExprTree(_, true, false))
+    expressions.foreach(equivalentExpressions.addExprTree)

     // Get all the expressions that appear at least twice and set up the state for subexpression
     // elimination.
@@ -734,10 +734,10 @@ class CodegenContext {
     val codes = commonExprs.map { e =>
       val expr = e.head
       // Generate the code for this expression tree.
-      val code = expr.genCode(this)
-      val state = SubExprEliminationState(code.isNull, code.value)
+      val eval = expr.genCode(this)
+      val state = SubExprEliminationState(eval.isNull, eval.value)
       e.foreach(subExprEliminationExprs.put(_, state))
-      code.code.trim
+      eval.code.trim
     }
     SubExprCodes(codes, subExprEliminationExprs.toMap)
   }
@@ -747,7 +747,7 @@ class CodegenContext {
    * common subexpressions, generates the functions that evaluate those expressions and populates
    * the mapping of common subexpressions to the generated functions.
    */
-  private def subexpressionElimination(expressions: Seq[Expression]) = {
+  private def subexpressionElimination(expressions: Seq[Expression]): Unit = {
     // Add each expression tree and compute the common subexpressions.
     expressions.foreach(equivalentExpressions.addExprTree(_))

@@ -761,13 +761,13 @@ class CodegenContext {
       val value = s"${fnName}Value"

       // Generate the code for this expression tree and wrap it in a function.
-      val code = expr.genCode(this)
+      val eval = expr.genCode(this)
       val fn =
         s"""
           |private void $fnName(InternalRow $INPUT_ROW) {
-          |  ${code.code.trim}
-          |  $isNull = ${code.isNull};
-          |  $value = ${code.value};
+          |  ${eval.code.trim}
+          |  $isNull = ${eval.isNull};
+          |  $value = ${eval.value};
          |}
          """.stripMargin

@@ -780,9 +780,6 @@ class CodegenContext {
     // The cost of doing subexpression elimination is:
     //   1. Extra function call, although this is probably *good* as the JIT can decide to
     //      inline or not.
-    //   2. Extra branch to check isLoaded. This branch is likely to be predicted correctly
-    //      very often. The reason it is not loaded is because of a prior branch.
-    //   3. Extra store into isLoaded.
     // The benefit doing subexpression elimination is:
     //   1. Running the expression logic. Even for a simple expression, it is likely more than 3
     //      above.
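For orientation, a small usage sketch of the codegen path edited above (hypothetical driver code, not part of the commit): GenerateUnsafeProjection.generate with subexpressionEliminationEnabled = true is what feeds each bound expression to equivalentExpressions.addExprTree, and duplicates that sit outside any conditional are still hoisted into a generated helper function:

    // Sketch only: exercises CodegenContext's subexpression elimination end to end.
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
    import org.apache.spark.sql.types.IntegerType

    object SubexprCodegenSketch extends App {
      val a = BoundReference(0, IntegerType, nullable = false)
      val shared = Add(a, Literal(1))
      // `shared` appears twice unconditionally, so it remains a common subexpression.
      val projection = GenerateUnsafeProjection.generate(
        Seq(Multiply(shared, shared)), subexpressionEliminationEnabled = true)
      assert(projection(InternalRow(2)).getInt(0) == 9)  // (2 + 1) * (2 + 1)
    }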
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 587022f0a2..7ea0bec145 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -25,8 +25,8 @@ import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.expressions.objects.{CreateExternalRow, GetExternalRowField, ValidateExternalType}
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils}
+import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, GetExternalRowField, ValidateExternalType}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.ThreadUtils
@@ -313,4 +313,15 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
   test("SPARK-17160: field names are properly escaped by AssertTrue") {
     GenerateUnsafeProjection.generate(AssertTrue(Cast(Literal("\""), BooleanType)) :: Nil)
   }
+
+  test("should not apply common subexpression elimination on conditional expressions") {
+    val row = InternalRow(null)
+    val bound = BoundReference(0, IntegerType, true)
+    val assertNotNull = AssertNotNull(bound, Nil)
+    val expr = If(IsNull(bound), Literal(1), Add(assertNotNull, assertNotNull))
+    val projection = GenerateUnsafeProjection.generate(
+      Seq(expr), subexpressionEliminationEnabled = true)
+    // should not throw exception
+    projection(row)
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
index 2db2a043e5..c48730bd9d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
@@ -97,9 +97,9 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
     val add2 = Add(add, add)

     var equivalence = new EquivalentExpressions
-    equivalence.addExprTree(add, true)
-    equivalence.addExprTree(abs, true)
-    equivalence.addExprTree(add2, true)
+    equivalence.addExprTree(add)
+    equivalence.addExprTree(abs)
+    equivalence.addExprTree(add2)

     // Should only have one equivalence for `one + two`
     assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 1)
@@ -115,10 +115,10 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
     val mul2 = Multiply(mul, mul)
     val sqrt = Sqrt(mul2)
     val sum = Add(mul2, sqrt)
-    equivalence.addExprTree(mul, true)
-    equivalence.addExprTree(mul2, true)
-    equivalence.addExprTree(sqrt, true)
-    equivalence.addExprTree(sum, true)
+    equivalence.addExprTree(mul)
+    equivalence.addExprTree(mul2)
+    equivalence.addExprTree(sqrt)
+    equivalence.addExprTree(sum)

     // (one * two), (one * two) * (one * two) and sqrt( (one * two) * (one * two) ) should be found
     assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 3)
@@ -126,30 +126,6 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
     assert(equivalence.getEquivalentExprs(mul2).size == 3)
     assert(equivalence.getEquivalentExprs(sqrt).size == 2)
     assert(equivalence.getEquivalentExprs(sum).size == 1)
-
-    // Some expressions inspired by TPCH-Q1
-    // sum(l_quantity) as sum_qty,
-    // sum(l_extendedprice) as sum_base_price,
-    // sum(l_extendedprice * (1 - l_discount)) as sum_disc_price,
-    // sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge,
-    // avg(l_extendedprice) as avg_price,
-    // avg(l_discount) as avg_disc
-    equivalence = new EquivalentExpressions
-    val quantity = Literal(1)
-    val price = Literal(1.1)
-    val discount = Literal(.24)
-    val tax = Literal(0.1)
-    equivalence.addExprTree(quantity, false)
-    equivalence.addExprTree(price, false)
-    equivalence.addExprTree(Multiply(price, Subtract(Literal(1), discount)), false)
-    equivalence.addExprTree(
-      Multiply(
-        Multiply(price, Subtract(Literal(1), discount)),
-        Add(Literal(1), tax)), false)
-    equivalence.addExprTree(price, false)
-    equivalence.addExprTree(discount, false)
-    // quantity, price, discount and (price * (1 - discount))
-    assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 4)
   }

   test("Expression equivalence - non deterministic") {
@@ -167,11 +143,24 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
     val add = Add(two, fallback)

     val equivalence = new EquivalentExpressions
-    equivalence.addExprTree(add, true)
+    equivalence.addExprTree(add)
     // the `two` inside `fallback` should not be added
     assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)
     assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3)  // add, two, explode
   }
+
+  test("Children of conditional expressions") {
+    val condition = And(Literal(true), Literal(false))
+    val add = Add(Literal(1), Literal(2))
+    val ifExpr = If(condition, add, add)
+
+    val equivalence = new EquivalentExpressions
+    equivalence.addExprTree(ifExpr)
+
+    // the `add` inside `If` should not be added
+    assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)
+    // only ifExpr and its predicate expression
+    assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 2)
+  }
 }

 case class CodegenFallbackExpression(child: Expression)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index 4146bf3269..717758fdf7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -143,9 +143,15 @@ case class SimpleTypedAggregateExpression(
   override lazy val aggBufferAttributes: Seq[AttributeReference] =
     bufferSerializer.map(_.toAttribute.asInstanceOf[AttributeReference])

+  private def serializeToBuffer(expr: Expression): Seq[Expression] = {
+    bufferSerializer.map(_.transform {
+      case _: BoundReference => expr
+    })
+  }
+
   override lazy val initialValues: Seq[Expression] = {
     val zero = Literal.fromObject(aggregator.zero, bufferExternalType)
-    bufferSerializer.map(ReferenceToExpressions(_, zero :: Nil))
+    serializeToBuffer(zero)
   }

   override lazy val updateExpressions: Seq[Expression] = {
@@ -154,8 +160,7 @@ case class SimpleTypedAggregateExpression(
       "reduce",
       bufferExternalType,
       bufferDeserializer :: inputDeserializer.get :: Nil)
-
-    bufferSerializer.map(ReferenceToExpressions(_, reduced :: Nil))
+    serializeToBuffer(reduced)
   }

   override lazy val mergeExpressions: Seq[Expression] = {
@@ -170,8 +175,7 @@ case class SimpleTypedAggregateExpression(
       "merge",
       bufferExternalType,
       leftBuffer :: rightBuffer :: Nil)
-
-    bufferSerializer.map(ReferenceToExpressions(_, merged :: Nil))
+    serializeToBuffer(merged)
  }

   override lazy val evaluateExpression: Expression = {
@@ -181,19 +185,17 @@ case class SimpleTypedAggregateExpression(
       outputExternalType,
       bufferDeserializer :: Nil)

+    val outputSerializeExprs = outputSerializer.map(_.transform {
+      case _: BoundReference => resultObj
+    })
+
     dataType match {
-      case s: StructType =>
+      case _: StructType =>
         val objRef = outputSerializer.head.find(_.isInstanceOf[BoundReference]).get
-        val struct = If(
-          IsNull(objRef),
-          Literal.create(null, dataType),
-          CreateStruct(outputSerializer))
-        ReferenceToExpressions(struct, resultObj :: Nil)
+        If(IsNull(objRef), Literal.create(null, dataType), CreateStruct(outputSerializeExprs))
       case _ =>
-        assert(outputSerializer.length == 1)
-        outputSerializer.head transform {
-          case b: BoundReference => resultObj
-        }
+        assert(outputSerializeExprs.length == 1)
+        outputSerializeExprs.head
     }
   }
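The TypedAggregateExpression change follows the same pattern in miniature. A hedged sketch (the serializer expression below is made up; the real code uses bufferSerializer and outputSerializer): instead of wrapping a serializer that reads slot 0 in the now-deleted ReferenceToExpressions, the BoundReference placeholder is spliced out with transform, so the result is an ordinary expression tree that needs no special casing in codegen or subexpression elimination:

    // Sketch only: the transform pattern behind serializeToBuffer above.
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types.LongType

    object InlineBoundReferenceSketch extends App {
      // A made-up serializer that reads its input from slot 0 of the input row.
      val serializer: Expression = Add(BoundReference(0, LongType, nullable = false), Literal(1L))

      // Old style: ReferenceToExpressions(serializer, input :: Nil) projected `input` into a
      // temporary row. New style: substitute `input` directly for the bound slot.
      def bindTo(input: Expression): Expression = serializer.transform {
        case _: BoundReference => input
      }

      assert(bindTo(Literal(41L)).eval(InternalRow.empty) == 42L)
    }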