aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala40
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala92
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala21
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala15
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala53
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala32
7 files changed, 84 insertions, 171 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
index 6c246a5663..f8644c2cd6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -67,28 +67,34 @@ class EquivalentExpressions {
/**
* Adds the expression to this data structure recursively. Stops if a matching expression
* is found. That is, if `expr` has already been added, its children are not added.
- * If ignoreLeaf is true, leaf nodes are ignored.
*/
- def addExprTree(
- root: Expression,
- ignoreLeaf: Boolean = true,
- skipReferenceToExpressions: Boolean = true): Unit = {
- val skip = (root.isInstanceOf[LeafExpression] && ignoreLeaf) ||
+ def addExprTree(expr: Expression): Unit = {
+ val skip = expr.isInstanceOf[LeafExpression] ||
// `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the
// loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning.
- root.find(_.isInstanceOf[LambdaVariable]).isDefined
- // There are some special expressions that we should not recurse into children.
+ expr.find(_.isInstanceOf[LambdaVariable]).isDefined
+
+ // There are some special expressions that we should not recurse into all of its children.
// 1. CodegenFallback: it's children will not be used to generate code (call eval() instead)
- // 2. ReferenceToExpressions: it's kind of an explicit sub-expression elimination.
- val shouldRecurse = root match {
- // TODO: some expressions implements `CodegenFallback` but can still do codegen,
- // e.g. `CaseWhen`, we should support them.
- case _: CodegenFallback => false
- case _: ReferenceToExpressions if skipReferenceToExpressions => false
- case _ => true
+ // 2. If: common subexpressions will always be evaluated at the beginning, but the true and
+ // false expressions in `If` may not get accessed, according to the predicate
+ // expression. We should only recurse into the predicate expression.
+ // 3. CaseWhen: like `If`, the children of `CaseWhen` only get accessed in a certain
+ // condition. We should only recurse into the first condition expression as it
+ // will always get accessed.
+ // 4. Coalesce: it's also a conditional expression, we should only recurse into the first
+ // children, because others may not get accessed.
+ def childrenToRecurse: Seq[Expression] = expr match {
+ case _: CodegenFallback => Nil
+ case i: If => i.predicate :: Nil
+ // `CaseWhen` implements `CodegenFallback`, we only need to handle `CaseWhenCodegen` here.
+ case c: CaseWhenCodegen => c.children.head :: Nil
+ case c: Coalesce => c.children.head :: Nil
+ case other => other.children
}
- if (!skip && !addExpr(root) && shouldRecurse) {
- root.children.foreach(addExprTree(_, ignoreLeaf))
+
+ if (!skip && !addExpr(expr)) {
+ childrenToRecurse.foreach(addExprTree)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 476e37e6a9..7c57025f99 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -117,7 +117,7 @@ object UnsafeProjection {
* Returns an UnsafeProjection for given Array of DataTypes.
*/
def create(fields: Array[DataType]): UnsafeProjection = {
- create(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)))
+ create(fields.zipWithIndex.map(x => BoundReference(x._2, x._1, true)))
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
deleted file mode 100644
index 2ca77e8394..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
+++ /dev/null
@@ -1,92 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
-import org.apache.spark.sql.types.DataType
-
-/**
- * A special expression that evaluates [[BoundReference]]s by given expressions instead of the
- * input row.
- *
- * @param result The expression that contains [[BoundReference]] and produces the final output.
- * @param children The expressions that used as input values for [[BoundReference]].
- */
-case class ReferenceToExpressions(result: Expression, children: Seq[Expression])
- extends Expression {
-
- override def nullable: Boolean = result.nullable
- override def dataType: DataType = result.dataType
-
- override def checkInputDataTypes(): TypeCheckResult = {
- if (result.references.nonEmpty) {
- return TypeCheckFailure("The result expression cannot reference to any attributes.")
- }
-
- var maxOrdinal = -1
- result foreach {
- case b: BoundReference if b.ordinal > maxOrdinal => maxOrdinal = b.ordinal
- case _ =>
- }
- if (maxOrdinal > children.length) {
- return TypeCheckFailure(s"The result expression need $maxOrdinal input expressions, but " +
- s"there are only ${children.length} inputs.")
- }
-
- TypeCheckSuccess
- }
-
- private lazy val projection = UnsafeProjection.create(children)
-
- override def eval(input: InternalRow): Any = {
- result.eval(projection(input))
- }
-
- override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val childrenGen = children.map(_.genCode(ctx))
- val (classChildrenVars, initClassChildrenVars) = childrenGen.zip(children).map {
- case (childGen, child) =>
- // SPARK-18125: The children vars are local variables. If the result expression uses
- // splitExpression, those variables cannot be accessed so compilation fails.
- // To fix it, we use class variables to hold those local variables.
- val classChildVarName = ctx.freshName("classChildVar")
- val classChildVarIsNull = ctx.freshName("classChildVarIsNull")
- ctx.addMutableState(ctx.javaType(child.dataType), classChildVarName, "")
- ctx.addMutableState("boolean", classChildVarIsNull, "")
-
- val classChildVar =
- LambdaVariable(classChildVarName, classChildVarIsNull, child.dataType, child.nullable)
-
- val initCode = s"${classChildVar.value} = ${childGen.value};\n" +
- s"${classChildVar.isNull} = ${childGen.isNull};"
-
- (classChildVar, initCode)
- }.unzip
-
- val resultGen = result.transform {
- case b: BoundReference => classChildrenVars(b.ordinal)
- }.genCode(ctx)
-
- ExprCode(code = childrenGen.map(_.code).mkString("\n") + initClassChildrenVars.mkString("\n") +
- resultGen.code, isNull = resultGen.isNull, value = resultGen.value)
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index f8f868b59b..04b812e79e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -726,7 +726,7 @@ class CodegenContext {
val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState]
// Add each expression tree and compute the common subexpressions.
- expressions.foreach(equivalentExpressions.addExprTree(_, true, false))
+ expressions.foreach(equivalentExpressions.addExprTree)
// Get all the expressions that appear at least twice and set up the state for subexpression
// elimination.
@@ -734,10 +734,10 @@ class CodegenContext {
val codes = commonExprs.map { e =>
val expr = e.head
// Generate the code for this expression tree.
- val code = expr.genCode(this)
- val state = SubExprEliminationState(code.isNull, code.value)
+ val eval = expr.genCode(this)
+ val state = SubExprEliminationState(eval.isNull, eval.value)
e.foreach(subExprEliminationExprs.put(_, state))
- code.code.trim
+ eval.code.trim
}
SubExprCodes(codes, subExprEliminationExprs.toMap)
}
@@ -747,7 +747,7 @@ class CodegenContext {
* common subexpressions, generates the functions that evaluate those expressions and populates
* the mapping of common subexpressions to the generated functions.
*/
- private def subexpressionElimination(expressions: Seq[Expression]) = {
+ private def subexpressionElimination(expressions: Seq[Expression]): Unit = {
// Add each expression tree and compute the common subexpressions.
expressions.foreach(equivalentExpressions.addExprTree(_))
@@ -761,13 +761,13 @@ class CodegenContext {
val value = s"${fnName}Value"
// Generate the code for this expression tree and wrap it in a function.
- val code = expr.genCode(this)
+ val eval = expr.genCode(this)
val fn =
s"""
|private void $fnName(InternalRow $INPUT_ROW) {
- | ${code.code.trim}
- | $isNull = ${code.isNull};
- | $value = ${code.value};
+ | ${eval.code.trim}
+ | $isNull = ${eval.isNull};
+ | $value = ${eval.value};
|}
""".stripMargin
@@ -780,9 +780,6 @@ class CodegenContext {
// The cost of doing subexpression elimination is:
// 1. Extra function call, although this is probably *good* as the JIT can decide to
// inline or not.
- // 2. Extra branch to check isLoaded. This branch is likely to be predicted correctly
- // very often. The reason it is not loaded is because of a prior branch.
- // 3. Extra store into isLoaded.
// The benefit doing subexpression elimination is:
// 1. Running the expression logic. Even for a simple expression, it is likely more than 3
// above.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 587022f0a2..7ea0bec145 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -25,8 +25,8 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.expressions.objects.{CreateExternalRow, GetExternalRowField, ValidateExternalType}
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils}
+import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, GetExternalRowField, ValidateExternalType}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.ThreadUtils
@@ -313,4 +313,15 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
test("SPARK-17160: field names are properly escaped by AssertTrue") {
GenerateUnsafeProjection.generate(AssertTrue(Cast(Literal("\""), BooleanType)) :: Nil)
}
+
+ test("should not apply common subexpression elimination on conditional expressions") {
+ val row = InternalRow(null)
+ val bound = BoundReference(0, IntegerType, true)
+ val assertNotNull = AssertNotNull(bound, Nil)
+ val expr = If(IsNull(bound), Literal(1), Add(assertNotNull, assertNotNull))
+ val projection = GenerateUnsafeProjection.generate(
+ Seq(expr), subexpressionEliminationEnabled = true)
+ // should not throw exception
+ projection(row)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
index 2db2a043e5..c48730bd9d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
@@ -97,9 +97,9 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
val add2 = Add(add, add)
var equivalence = new EquivalentExpressions
- equivalence.addExprTree(add, true)
- equivalence.addExprTree(abs, true)
- equivalence.addExprTree(add2, true)
+ equivalence.addExprTree(add)
+ equivalence.addExprTree(abs)
+ equivalence.addExprTree(add2)
// Should only have one equivalence for `one + two`
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 1)
@@ -115,10 +115,10 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
val mul2 = Multiply(mul, mul)
val sqrt = Sqrt(mul2)
val sum = Add(mul2, sqrt)
- equivalence.addExprTree(mul, true)
- equivalence.addExprTree(mul2, true)
- equivalence.addExprTree(sqrt, true)
- equivalence.addExprTree(sum, true)
+ equivalence.addExprTree(mul)
+ equivalence.addExprTree(mul2)
+ equivalence.addExprTree(sqrt)
+ equivalence.addExprTree(sum)
// (one * two), (one * two) * (one * two) and sqrt( (one * two) * (one * two) ) should be found
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 3)
@@ -126,30 +126,6 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
assert(equivalence.getEquivalentExprs(mul2).size == 3)
assert(equivalence.getEquivalentExprs(sqrt).size == 2)
assert(equivalence.getEquivalentExprs(sum).size == 1)
-
- // Some expressions inspired by TPCH-Q1
- // sum(l_quantity) as sum_qty,
- // sum(l_extendedprice) as sum_base_price,
- // sum(l_extendedprice * (1 - l_discount)) as sum_disc_price,
- // sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge,
- // avg(l_extendedprice) as avg_price,
- // avg(l_discount) as avg_disc
- equivalence = new EquivalentExpressions
- val quantity = Literal(1)
- val price = Literal(1.1)
- val discount = Literal(.24)
- val tax = Literal(0.1)
- equivalence.addExprTree(quantity, false)
- equivalence.addExprTree(price, false)
- equivalence.addExprTree(Multiply(price, Subtract(Literal(1), discount)), false)
- equivalence.addExprTree(
- Multiply(
- Multiply(price, Subtract(Literal(1), discount)),
- Add(Literal(1), tax)), false)
- equivalence.addExprTree(price, false)
- equivalence.addExprTree(discount, false)
- // quantity, price, discount and (price * (1 - discount))
- assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 4)
}
test("Expression equivalence - non deterministic") {
@@ -167,11 +143,24 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
val add = Add(two, fallback)
val equivalence = new EquivalentExpressions
- equivalence.addExprTree(add, true)
+ equivalence.addExprTree(add)
// the `two` inside `fallback` should not be added
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)
assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3) // add, two, explode
}
+
+ test("Children of conditional expressions") {
+ val condition = And(Literal(true), Literal(false))
+ val add = Add(Literal(1), Literal(2))
+ val ifExpr = If(condition, add, add)
+
+ val equivalence = new EquivalentExpressions
+ equivalence.addExprTree(ifExpr)
+ // the `add` inside `If` should not be added
+ assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)
+ // only ifExpr and its predicate expression
+ assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 2)
+ }
}
case class CodegenFallbackExpression(child: Expression)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index 4146bf3269..717758fdf7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -143,9 +143,15 @@ case class SimpleTypedAggregateExpression(
override lazy val aggBufferAttributes: Seq[AttributeReference] =
bufferSerializer.map(_.toAttribute.asInstanceOf[AttributeReference])
+ private def serializeToBuffer(expr: Expression): Seq[Expression] = {
+ bufferSerializer.map(_.transform {
+ case _: BoundReference => expr
+ })
+ }
+
override lazy val initialValues: Seq[Expression] = {
val zero = Literal.fromObject(aggregator.zero, bufferExternalType)
- bufferSerializer.map(ReferenceToExpressions(_, zero :: Nil))
+ serializeToBuffer(zero)
}
override lazy val updateExpressions: Seq[Expression] = {
@@ -154,8 +160,7 @@ case class SimpleTypedAggregateExpression(
"reduce",
bufferExternalType,
bufferDeserializer :: inputDeserializer.get :: Nil)
-
- bufferSerializer.map(ReferenceToExpressions(_, reduced :: Nil))
+ serializeToBuffer(reduced)
}
override lazy val mergeExpressions: Seq[Expression] = {
@@ -170,8 +175,7 @@ case class SimpleTypedAggregateExpression(
"merge",
bufferExternalType,
leftBuffer :: rightBuffer :: Nil)
-
- bufferSerializer.map(ReferenceToExpressions(_, merged :: Nil))
+ serializeToBuffer(merged)
}
override lazy val evaluateExpression: Expression = {
@@ -181,19 +185,17 @@ case class SimpleTypedAggregateExpression(
outputExternalType,
bufferDeserializer :: Nil)
+ val outputSerializeExprs = outputSerializer.map(_.transform {
+ case _: BoundReference => resultObj
+ })
+
dataType match {
- case s: StructType =>
+ case _: StructType =>
val objRef = outputSerializer.head.find(_.isInstanceOf[BoundReference]).get
- val struct = If(
- IsNull(objRef),
- Literal.create(null, dataType),
- CreateStruct(outputSerializer))
- ReferenceToExpressions(struct, resultObj :: Nil)
+ If(IsNull(objRef), Literal.create(null, dataType), CreateStruct(outputSerializeExprs))
case _ =>
- assert(outputSerializer.length == 1)
- outputSerializer.head transform {
- case b: BoundReference => resultObj
- }
+ assert(outputSerializeExprs.length == 1)
+ outputSerializeExprs.head
}
}