aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorWenchen Fan <wenchen@databricks.com>2017-01-23 13:31:26 +0800
committerWenchen Fan <wenchen@databricks.com>2017-01-23 13:31:26 +0800
commitde6ad3dfa7f4fdc8bb049f31142df9e5c01e6d13 (patch)
tree083bf69d6876048d7afd5d04e81e87cd70646081 /sql
parent772035e771a75593f031a8e78080bb58b8218e04 (diff)
downloadspark-de6ad3dfa7f4fdc8bb049f31142df9e5c01e6d13.tar.gz
spark-de6ad3dfa7f4fdc8bb049f31142df9e5c01e6d13.tar.bz2
spark-de6ad3dfa7f4fdc8bb049f31142df9e5c01e6d13.zip
[SPARK-19309][SQL] disable common subexpression elimination for conditional expressions
## What changes were proposed in this pull request? As I pointed out in https://github.com/apache/spark/pull/15807#issuecomment-259143655 , the current subexpression elimination framework has a problem, it always evaluates all common subexpressions at the beginning, even they are inside conditional expressions and may not be accessed. Ideally we should implement it like scala lazy val, so we only evaluate it when it gets accessed at lease once. https://github.com/apache/spark/issues/15837 tries this approach, but it seems too complicated and may introduce performance regression. This PR simply stops common subexpression elimination for conditional expressions, with some cleanup. ## How was this patch tested? regression test Author: Wenchen Fan <wenchen@databricks.com> Closes #16659 from cloud-fan/codegen.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala40
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala92
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala21
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala15
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala53
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala32
7 files changed, 84 insertions, 171 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
index 6c246a5663..f8644c2cd6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -67,28 +67,34 @@ class EquivalentExpressions {
/**
* Adds the expression to this data structure recursively. Stops if a matching expression
* is found. That is, if `expr` has already been added, its children are not added.
- * If ignoreLeaf is true, leaf nodes are ignored.
*/
- def addExprTree(
- root: Expression,
- ignoreLeaf: Boolean = true,
- skipReferenceToExpressions: Boolean = true): Unit = {
- val skip = (root.isInstanceOf[LeafExpression] && ignoreLeaf) ||
+ def addExprTree(expr: Expression): Unit = {
+ val skip = expr.isInstanceOf[LeafExpression] ||
// `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the
// loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning.
- root.find(_.isInstanceOf[LambdaVariable]).isDefined
- // There are some special expressions that we should not recurse into children.
+ expr.find(_.isInstanceOf[LambdaVariable]).isDefined
+
+ // There are some special expressions that we should not recurse into all of its children.
// 1. CodegenFallback: it's children will not be used to generate code (call eval() instead)
- // 2. ReferenceToExpressions: it's kind of an explicit sub-expression elimination.
- val shouldRecurse = root match {
- // TODO: some expressions implements `CodegenFallback` but can still do codegen,
- // e.g. `CaseWhen`, we should support them.
- case _: CodegenFallback => false
- case _: ReferenceToExpressions if skipReferenceToExpressions => false
- case _ => true
+ // 2. If: common subexpressions will always be evaluated at the beginning, but the true and
+ // false expressions in `If` may not get accessed, according to the predicate
+ // expression. We should only recurse into the predicate expression.
+ // 3. CaseWhen: like `If`, the children of `CaseWhen` only get accessed in a certain
+ // condition. We should only recurse into the first condition expression as it
+ // will always get accessed.
+ // 4. Coalesce: it's also a conditional expression, we should only recurse into the first
+ // children, because others may not get accessed.
+ def childrenToRecurse: Seq[Expression] = expr match {
+ case _: CodegenFallback => Nil
+ case i: If => i.predicate :: Nil
+ // `CaseWhen` implements `CodegenFallback`, we only need to handle `CaseWhenCodegen` here.
+ case c: CaseWhenCodegen => c.children.head :: Nil
+ case c: Coalesce => c.children.head :: Nil
+ case other => other.children
}
- if (!skip && !addExpr(root) && shouldRecurse) {
- root.children.foreach(addExprTree(_, ignoreLeaf))
+
+ if (!skip && !addExpr(expr)) {
+ childrenToRecurse.foreach(addExprTree)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 476e37e6a9..7c57025f99 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -117,7 +117,7 @@ object UnsafeProjection {
* Returns an UnsafeProjection for given Array of DataTypes.
*/
def create(fields: Array[DataType]): UnsafeProjection = {
- create(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)))
+ create(fields.zipWithIndex.map(x => BoundReference(x._2, x._1, true)))
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
deleted file mode 100644
index 2ca77e8394..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
+++ /dev/null
@@ -1,92 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
-import org.apache.spark.sql.types.DataType
-
-/**
- * A special expression that evaluates [[BoundReference]]s by given expressions instead of the
- * input row.
- *
- * @param result The expression that contains [[BoundReference]] and produces the final output.
- * @param children The expressions that used as input values for [[BoundReference]].
- */
-case class ReferenceToExpressions(result: Expression, children: Seq[Expression])
- extends Expression {
-
- override def nullable: Boolean = result.nullable
- override def dataType: DataType = result.dataType
-
- override def checkInputDataTypes(): TypeCheckResult = {
- if (result.references.nonEmpty) {
- return TypeCheckFailure("The result expression cannot reference to any attributes.")
- }
-
- var maxOrdinal = -1
- result foreach {
- case b: BoundReference if b.ordinal > maxOrdinal => maxOrdinal = b.ordinal
- case _ =>
- }
- if (maxOrdinal > children.length) {
- return TypeCheckFailure(s"The result expression need $maxOrdinal input expressions, but " +
- s"there are only ${children.length} inputs.")
- }
-
- TypeCheckSuccess
- }
-
- private lazy val projection = UnsafeProjection.create(children)
-
- override def eval(input: InternalRow): Any = {
- result.eval(projection(input))
- }
-
- override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val childrenGen = children.map(_.genCode(ctx))
- val (classChildrenVars, initClassChildrenVars) = childrenGen.zip(children).map {
- case (childGen, child) =>
- // SPARK-18125: The children vars are local variables. If the result expression uses
- // splitExpression, those variables cannot be accessed so compilation fails.
- // To fix it, we use class variables to hold those local variables.
- val classChildVarName = ctx.freshName("classChildVar")
- val classChildVarIsNull = ctx.freshName("classChildVarIsNull")
- ctx.addMutableState(ctx.javaType(child.dataType), classChildVarName, "")
- ctx.addMutableState("boolean", classChildVarIsNull, "")
-
- val classChildVar =
- LambdaVariable(classChildVarName, classChildVarIsNull, child.dataType, child.nullable)
-
- val initCode = s"${classChildVar.value} = ${childGen.value};\n" +
- s"${classChildVar.isNull} = ${childGen.isNull};"
-
- (classChildVar, initCode)
- }.unzip
-
- val resultGen = result.transform {
- case b: BoundReference => classChildrenVars(b.ordinal)
- }.genCode(ctx)
-
- ExprCode(code = childrenGen.map(_.code).mkString("\n") + initClassChildrenVars.mkString("\n") +
- resultGen.code, isNull = resultGen.isNull, value = resultGen.value)
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index f8f868b59b..04b812e79e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -726,7 +726,7 @@ class CodegenContext {
val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState]
// Add each expression tree and compute the common subexpressions.
- expressions.foreach(equivalentExpressions.addExprTree(_, true, false))
+ expressions.foreach(equivalentExpressions.addExprTree)
// Get all the expressions that appear at least twice and set up the state for subexpression
// elimination.
@@ -734,10 +734,10 @@ class CodegenContext {
val codes = commonExprs.map { e =>
val expr = e.head
// Generate the code for this expression tree.
- val code = expr.genCode(this)
- val state = SubExprEliminationState(code.isNull, code.value)
+ val eval = expr.genCode(this)
+ val state = SubExprEliminationState(eval.isNull, eval.value)
e.foreach(subExprEliminationExprs.put(_, state))
- code.code.trim
+ eval.code.trim
}
SubExprCodes(codes, subExprEliminationExprs.toMap)
}
@@ -747,7 +747,7 @@ class CodegenContext {
* common subexpressions, generates the functions that evaluate those expressions and populates
* the mapping of common subexpressions to the generated functions.
*/
- private def subexpressionElimination(expressions: Seq[Expression]) = {
+ private def subexpressionElimination(expressions: Seq[Expression]): Unit = {
// Add each expression tree and compute the common subexpressions.
expressions.foreach(equivalentExpressions.addExprTree(_))
@@ -761,13 +761,13 @@ class CodegenContext {
val value = s"${fnName}Value"
// Generate the code for this expression tree and wrap it in a function.
- val code = expr.genCode(this)
+ val eval = expr.genCode(this)
val fn =
s"""
|private void $fnName(InternalRow $INPUT_ROW) {
- | ${code.code.trim}
- | $isNull = ${code.isNull};
- | $value = ${code.value};
+ | ${eval.code.trim}
+ | $isNull = ${eval.isNull};
+ | $value = ${eval.value};
|}
""".stripMargin
@@ -780,9 +780,6 @@ class CodegenContext {
// The cost of doing subexpression elimination is:
// 1. Extra function call, although this is probably *good* as the JIT can decide to
// inline or not.
- // 2. Extra branch to check isLoaded. This branch is likely to be predicted correctly
- // very often. The reason it is not loaded is because of a prior branch.
- // 3. Extra store into isLoaded.
// The benefit doing subexpression elimination is:
// 1. Running the expression logic. Even for a simple expression, it is likely more than 3
// above.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 587022f0a2..7ea0bec145 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -25,8 +25,8 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.expressions.objects.{CreateExternalRow, GetExternalRowField, ValidateExternalType}
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils}
+import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, GetExternalRowField, ValidateExternalType}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.ThreadUtils
@@ -313,4 +313,15 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
test("SPARK-17160: field names are properly escaped by AssertTrue") {
GenerateUnsafeProjection.generate(AssertTrue(Cast(Literal("\""), BooleanType)) :: Nil)
}
+
+ test("should not apply common subexpression elimination on conditional expressions") {
+ val row = InternalRow(null)
+ val bound = BoundReference(0, IntegerType, true)
+ val assertNotNull = AssertNotNull(bound, Nil)
+ val expr = If(IsNull(bound), Literal(1), Add(assertNotNull, assertNotNull))
+ val projection = GenerateUnsafeProjection.generate(
+ Seq(expr), subexpressionEliminationEnabled = true)
+ // should not throw exception
+ projection(row)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
index 2db2a043e5..c48730bd9d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
@@ -97,9 +97,9 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
val add2 = Add(add, add)
var equivalence = new EquivalentExpressions
- equivalence.addExprTree(add, true)
- equivalence.addExprTree(abs, true)
- equivalence.addExprTree(add2, true)
+ equivalence.addExprTree(add)
+ equivalence.addExprTree(abs)
+ equivalence.addExprTree(add2)
// Should only have one equivalence for `one + two`
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 1)
@@ -115,10 +115,10 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
val mul2 = Multiply(mul, mul)
val sqrt = Sqrt(mul2)
val sum = Add(mul2, sqrt)
- equivalence.addExprTree(mul, true)
- equivalence.addExprTree(mul2, true)
- equivalence.addExprTree(sqrt, true)
- equivalence.addExprTree(sum, true)
+ equivalence.addExprTree(mul)
+ equivalence.addExprTree(mul2)
+ equivalence.addExprTree(sqrt)
+ equivalence.addExprTree(sum)
// (one * two), (one * two) * (one * two) and sqrt( (one * two) * (one * two) ) should be found
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 3)
@@ -126,30 +126,6 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
assert(equivalence.getEquivalentExprs(mul2).size == 3)
assert(equivalence.getEquivalentExprs(sqrt).size == 2)
assert(equivalence.getEquivalentExprs(sum).size == 1)
-
- // Some expressions inspired by TPCH-Q1
- // sum(l_quantity) as sum_qty,
- // sum(l_extendedprice) as sum_base_price,
- // sum(l_extendedprice * (1 - l_discount)) as sum_disc_price,
- // sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge,
- // avg(l_extendedprice) as avg_price,
- // avg(l_discount) as avg_disc
- equivalence = new EquivalentExpressions
- val quantity = Literal(1)
- val price = Literal(1.1)
- val discount = Literal(.24)
- val tax = Literal(0.1)
- equivalence.addExprTree(quantity, false)
- equivalence.addExprTree(price, false)
- equivalence.addExprTree(Multiply(price, Subtract(Literal(1), discount)), false)
- equivalence.addExprTree(
- Multiply(
- Multiply(price, Subtract(Literal(1), discount)),
- Add(Literal(1), tax)), false)
- equivalence.addExprTree(price, false)
- equivalence.addExprTree(discount, false)
- // quantity, price, discount and (price * (1 - discount))
- assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 4)
}
test("Expression equivalence - non deterministic") {
@@ -167,11 +143,24 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
val add = Add(two, fallback)
val equivalence = new EquivalentExpressions
- equivalence.addExprTree(add, true)
+ equivalence.addExprTree(add)
// the `two` inside `fallback` should not be added
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)
assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3) // add, two, explode
}
+
+ test("Children of conditional expressions") {
+ val condition = And(Literal(true), Literal(false))
+ val add = Add(Literal(1), Literal(2))
+ val ifExpr = If(condition, add, add)
+
+ val equivalence = new EquivalentExpressions
+ equivalence.addExprTree(ifExpr)
+ // the `add` inside `If` should not be added
+ assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)
+ // only ifExpr and its predicate expression
+ assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 2)
+ }
}
case class CodegenFallbackExpression(child: Expression)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index 4146bf3269..717758fdf7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -143,9 +143,15 @@ case class SimpleTypedAggregateExpression(
override lazy val aggBufferAttributes: Seq[AttributeReference] =
bufferSerializer.map(_.toAttribute.asInstanceOf[AttributeReference])
+ private def serializeToBuffer(expr: Expression): Seq[Expression] = {
+ bufferSerializer.map(_.transform {
+ case _: BoundReference => expr
+ })
+ }
+
override lazy val initialValues: Seq[Expression] = {
val zero = Literal.fromObject(aggregator.zero, bufferExternalType)
- bufferSerializer.map(ReferenceToExpressions(_, zero :: Nil))
+ serializeToBuffer(zero)
}
override lazy val updateExpressions: Seq[Expression] = {
@@ -154,8 +160,7 @@ case class SimpleTypedAggregateExpression(
"reduce",
bufferExternalType,
bufferDeserializer :: inputDeserializer.get :: Nil)
-
- bufferSerializer.map(ReferenceToExpressions(_, reduced :: Nil))
+ serializeToBuffer(reduced)
}
override lazy val mergeExpressions: Seq[Expression] = {
@@ -170,8 +175,7 @@ case class SimpleTypedAggregateExpression(
"merge",
bufferExternalType,
leftBuffer :: rightBuffer :: Nil)
-
- bufferSerializer.map(ReferenceToExpressions(_, merged :: Nil))
+ serializeToBuffer(merged)
}
override lazy val evaluateExpression: Expression = {
@@ -181,19 +185,17 @@ case class SimpleTypedAggregateExpression(
outputExternalType,
bufferDeserializer :: Nil)
+ val outputSerializeExprs = outputSerializer.map(_.transform {
+ case _: BoundReference => resultObj
+ })
+
dataType match {
- case s: StructType =>
+ case _: StructType =>
val objRef = outputSerializer.head.find(_.isInstanceOf[BoundReference]).get
- val struct = If(
- IsNull(objRef),
- Literal.create(null, dataType),
- CreateStruct(outputSerializer))
- ReferenceToExpressions(struct, resultObj :: Nil)
+ If(IsNull(objRef), Literal.create(null, dataType), CreateStruct(outputSerializeExprs))
case _ =>
- assert(outputSerializer.length == 1)
- outputSerializer.head transform {
- case b: BoundReference => resultObj
- }
+ assert(outputSerializeExprs.length == 1)
+ outputSerializeExprs.head
}
}