path: root/sql
diff options
Diffstat (limited to 'sql')
11 files changed, 523 insertions, 16 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
new file mode 100644
index 0000000000..e7380d21f9
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -0,0 +1,106 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+import scala.collection.mutable
+ * This class is used to compute equality of (sub)expression trees. Expressions can be added
+ * to this class and subsequently queried for expression equality. Expression trees are
+ * considered equal if, for the same input(s), the same result is produced.
+ */
+class EquivalentExpressions {
+  /**
+   * Key type for the equivalence map: wraps an Expression so that equality and hashing are
+   * semantic (`semanticEquals` / `semanticHash`) instead of the expression's own equals.
+   */
+  case class Expr(e: Expression) {
+    // Computed once when the wrapper is built; `hashCode` just returns it.
+    val hash = e.semanticHash()
+
+    override def hashCode: Int = hash
+
+    override def equals(o: Any): Boolean = o match {
+      case Expr(wrapped) => e.semanticEquals(wrapped)
+      case _ => false
+    }
+  }
+ // For each expression, the set of equivalent expressions.
+ private val equivalenceMap: mutable.HashMap[Expr, mutable.MutableList[Expression]] =
+ new mutable.HashMap[Expr, mutable.MutableList[Expression]]
+  /**
+   * Adds `expr` to this data structure, grouping it with any existing semantically
+   * equivalent expressions. Non-recursive: the children of `expr` are not visited.
+   *
+   * Non-deterministic expressions are never recorded.
+   *
+   * @return true if an equivalent expression had already been added; false if `expr` started
+   *         a new group or was skipped because it is non-deterministic.
+   */
+  def addExpr(expr: Expression): Boolean = {
+    if (expr.deterministic) {
+      val key = Expr(expr)
+      equivalenceMap.get(key) match {
+        case Some(group) =>
+          // Join the existing group of equivalent expressions.
+          group += expr
+          true
+        case None =>
+          // First occurrence: start a new group containing only `expr`.
+          equivalenceMap.put(key, mutable.MutableList(expr))
+          false
+      }
+    } else {
+      false
+    }
+  }
+  /**
+   * Recursively adds `root` and its descendants to this data structure. The walk does not
+   * descend below a node that `addExpr` reports as already present, nor below a
+   * non-deterministic node. When ignoreLeaf is true, leaf expressions are not added.
+   */
+  def addExprTree(root: Expression, ignoreLeaf: Boolean = true): Unit = {
+    val isIgnoredLeaf = ignoreLeaf && root.isInstanceOf[LeafExpression]
+    // `addExpr` is only reached (and `root` only recorded) for deterministic, non-skipped nodes.
+    val shouldDescend = !isIgnoredLeaf && root.deterministic && !addExpr(root)
+    if (shouldDescend) {
+      root.children.foreach(child => addExprTree(child, ignoreLeaf))
+    }
+  }
+  /**
+   * Returns all of the expression trees that have been added and are semantically equivalent
+   * to `e`. Returns an empty collection if there are none.
+   */
+  def getEquivalentExprs(e: Expression): Seq[Expression] = {
+    // A miss yields a shared empty Seq instead of allocating a fresh mutable list.
+    equivalenceMap.getOrElse(Expr(e), Seq.empty[Expression])
+  }
+  /**
+   * Returns every group of equivalent expressions currently held, one inner sequence per group.
+   */
+  def getAllEquivalentExprs: Seq[Seq[Expression]] = {
+    val groups = for (group <- equivalenceMap.values) yield group.toSeq
+    groups.toSeq
+  }
+  /**
+   * Renders the current groups as a string, one group per line. Unless `all` is true, groups
+   * holding a single expression are left out.
+   */
+  def debugString(all: Boolean = false): String = {
+    val out: mutable.StringBuilder = new StringBuilder()
+    out.append("Equivalent expressions:\n")
+    for ((_, group) <- equivalenceMap if all || group.length > 1) {
+      out.append(" " + group.mkString(", ")).append("\n")
+    }
+    out.toString()
+  }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 96fcc799e5..7d5741eefc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -92,12 +92,24 @@ abstract class Expression extends TreeNode[Expression] {
* @return [[GeneratedExpressionCode]]
def gen(ctx: CodeGenContext): GeneratedExpressionCode = {
- val isNull = ctx.freshName("isNull")
- val primitive = ctx.freshName("primitive")
- val ve = GeneratedExpressionCode("", isNull, primitive)
- ve.code = genCode(ctx, ve)
- // Add `this` in the comment.
- ve.copy(s"/* $this */\n" + ve.code)
+ val subExprState = ctx.subExprEliminationExprs.get(this)
+ if (subExprState.isDefined) {
+ // This expression is repeated, meaning the code to evaluate it has already been added
+ // as a function, `subExprState.fnName`. Just call that.
+ val code =
+ s"""
+ |/* $this */
+ |${subExprState.get.fnName}(${ctx.INPUT_ROW});
+ |""".stripMargin.trim
+ GeneratedExpressionCode(code, subExprState.get.code.isNull, subExprState.get.code.value)
+ } else {
+ val isNull = ctx.freshName("isNull")
+ val primitive = ctx.freshName("primitive")
+ val ve = GeneratedExpressionCode("", isNull, primitive)
+ ve.code = genCode(ctx, ve)
+ // Add `this` in the comment.
+ ve.copy(s"/* $this */\n" + ve.code.trim)
+ }
@@ -145,12 +157,38 @@ abstract class Expression extends TreeNode[Expression] {
case (i1, i2) => i1 == i2
+ // Non-deterministic expressions cannot be equal
+ if (!deterministic || !other.deterministic) return false
val elements1 = this.productIterator.toSeq
val elements2 = other.asInstanceOf[Product].productIterator.toSeq
checkSemantic(elements1, elements2)
+ * Returns the hash for this expression. Expressions that compute the same result, even if
+ * they differ cosmetically should return the same hash.
+ */
+  def semanticHash(): Int = {
+    // Folds the per-element hashes together, starting from 17 and multiplying by 37.
+    // See http://stackoverflow.com/questions/113511/hash-code-implementation
+    def computeHash(elements: Seq[Any]): Int =
+      elements.foldLeft(17) { (acc, element) =>
+        val elementHash: Int = element match {
+          case child: Expression => child.semanticHash()
+          case Some(child: Expression) => child.semanticHash()
+          case nested: Traversable[_] => computeHash(nested.toSeq)
+          case null => 0
+          case other => other.hashCode()
+        }
+        acc * 37 + elementHash
+      }
+
+    computeHash(this.productIterator.toSeq)
+  }
+ /**
* Checks the input data types, returns `TypeCheckResult.success` if it's valid,
* or returns a `TypeCheckResult` with an error message if invalid.
* Note: it's not valid to call this method until `childrenResolved == true`.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 79dabe8e92..9f0b7821ae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -144,6 +144,22 @@ object UnsafeProjection {
def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = {
create(exprs.map(BindReferences.bindReference(_, inputSchema)))
+  /**
+   * Variant of the other create() overloads that lets the caller switch subexpression
+   * elimination on or off.
+   * TODO: refactor the plumbing and clean this up.
+   */
+  def create(
+      exprs: Seq[Expression],
+      inputSchema: Seq[Attribute],
+      subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
+    // Bind attribute references against the input schema first.
+    val bound = exprs.map(expr => BindReferences.bindReference(expr, inputSchema))
+    // Then swap struct constructors for their unsafe counterparts.
+    val unsafeStructs = bound.map { expr =>
+      expr.transform {
+        case CreateStruct(children) => CreateStructUnsafe(children)
+        case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
+      }
+    }
+    GenerateUnsafeProjection.generate(unsafeStructs, subexpressionEliminationEnabled)
+  }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index f0f7a6cf0c..60a3d60184 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -92,6 +92,33 @@ class CodeGenContext {
addedFunctions += ((funcName, funcCode))
+ /**
+ * Holds expressions that are equivalent. Used to perform subexpression elimination
+ * during codegen.
+ *
+ * For expressions that appear more than once, generate additional code to prevent
+ * recomputing the value.
+ *
+ * For example, consider two expressions generated from this SQL statement:
+ * SELECT (col1 + col2), (col1 + col2) / col3.
+ *
+ * equivalentExpressions will match the tree containing `col1 + col2` and it will only
+ * be evaluated once.
+ */
+ val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions
+ // State used for subexpression elimination.
+ case class SubExprEliminationState(
+ val isLoaded: String, code: GeneratedExpressionCode, val fnName: String)
+ // For each expression that is participating in subexpression elimination, the state to use.
+ val subExprEliminationExprs: mutable.HashMap[Expression, SubExprEliminationState] =
+ mutable.HashMap[Expression, SubExprEliminationState]()
+ // The collection of isLoaded variables that need to be reset on each row.
+ val subExprIsLoadedVariables: mutable.ArrayBuffer[String] =
+ mutable.ArrayBuffer.empty[String]
final val JAVA_BOOLEAN = "boolean"
final val JAVA_BYTE = "byte"
final val JAVA_SHORT = "short"
@@ -317,6 +344,87 @@ class CodeGenContext {
functions.map(name => s"$name($row);").mkString("\n")
+ /**
+ * Checks and sets up the state and codegen for subexpression elimination. This finds the
+ * common subexpressions, generates the functions that evaluate those expressions and populates
+ * the mapping of common subexpressions to the generated functions.
+ *
+ * Only groups with at least two equivalent expressions get a function; each group's isLoaded,
+ * isNull and value variables are registered as mutable state on this context.
+ */
+ private def subexpressionElimination(expressions: Seq[Expression]) = {
+ // Add each expression tree and compute the common subexpressions.
+ expressions.foreach(equivalentExpressions.addExprTree(_))
+ // Get all the exprs that appear at least twice and set up the state for subexpression
+ // elimination.
+ val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
+ commonExprs.foreach(e => {
+ val expr = e.head
+ val isLoaded = freshName("isLoaded")
+ val isNull = freshName("isNull")
+ val primitive = freshName("primitive")
+ val fnName = freshName("evalExpr")
+ // Generate the code for this expression tree and wrap it in a function.
+ val code = expr.gen(this)
+ val fn =
+ s"""
+ |private void $fnName(InternalRow ${INPUT_ROW}) {
+ | if (!$isLoaded) {
+ | ${code.code.trim}
+ | $isLoaded = true;
+ | $isNull = ${code.isNull};
+ | $primitive = ${code.value};
+ | }
+ |}
+ """.stripMargin
+ code.code = fn
+ code.isNull = isNull
+ code.value = primitive
+ addNewFunction(fnName, fn)
+ // Add a state and a mapping of the common subexpressions that are associated with this
+ // state. Adding this expression to subExprEliminationExprs means it will call `fn`
+ // when it is code generated. This decision should be a cost based one.
+ //
+ // The cost of doing subexpression elimination is:
+ // 1. Extra function call, although this is probably *good* as the JIT can decide to
+ // inline or not.
+ // 2. Extra branch to check isLoaded. This branch is likely to be predicted correctly
+ // very often. The reason it is not loaded is because of a prior branch.
+ // 3. Extra store into isLoaded.
+ // The benefit of doing subexpression elimination is:
+ // 1. Running the expression logic. Even for a simple expression, it is likely more than 3
+ // above.
+ // 2. Less code.
+ // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with
+ // at least two nodes) as the cost of doing it is expected to be low.
+ // Maintain the loaded value and isNull as member variables. This is necessary if the codegen
+ // function is split across multiple functions.
+ // TODO: maintaining this as a local variable probably allows the compiler to do better
+ // optimizations.
+ addMutableState("boolean", isLoaded, s"$isLoaded = false;")
+ addMutableState("boolean", isNull, s"$isNull = false;")
+ addMutableState(javaType(expr.dataType), primitive,
+ s"$primitive = ${defaultValue(expr.dataType)};")
+ subExprIsLoadedVariables += isLoaded
+ val state = SubExprEliminationState(isLoaded, code, fnName)
+ e.foreach(subExprEliminationExprs.put(_, state))
+ })
+ }
+  /**
+   * Generates code for each of `expressions`, in order. When doSubexpressionElimination is
+   * true, the common-subexpression state is set up first. Subexpression elimination assumes
+   * that the code for the expressions will be combined in the `expressions` order.
+   */
+  def generateExpressions(
+      expressions: Seq[Expression],
+      doSubexpressionElimination: Boolean = false): Seq[GeneratedExpressionCode] = {
+    if (doSubexpressionElimination) {
+      subexpressionElimination(expressions)
+    }
+    expressions.map(_.gen(this))
+  }
@@ -349,7 +457,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
protected def declareAddedFunctions(ctx: CodeGenContext): String = {
- ctx.addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n")
+ ctx.addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n").trim
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 2136f82ba4..9ef2261414 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -139,9 +139,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
if (${input.isNull}) {
- $setNull
+ ${setNull.trim}
} else {
- $writeField
+ ${writeField.trim}
@@ -149,7 +149,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
$rowWriter.initialize($bufferHolder, ${inputs.length});
${ctx.splitExpressions(row, writeFields)}
- """
+ """.trim
// TODO: if the nullability of array element is correct, we can use it to save null check.
@@ -275,8 +275,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
- def createCode(ctx: CodeGenContext, expressions: Seq[Expression]): GeneratedExpressionCode = {
- val exprEvals = expressions.map(e => e.gen(ctx))
+ def createCode(
+ ctx: CodeGenContext,
+ expressions: Seq[Expression],
+ useSubexprElimination: Boolean = false): GeneratedExpressionCode = {
+ val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination)
val exprTypes = expressions.map(_.dataType)
val result = ctx.freshName("result")
@@ -285,10 +288,15 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val holderClass = classOf[BufferHolder].getName
ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new $holderClass();")
+ // Reset the isLoaded flag for each row.
+ val subexprReset = ctx.subExprIsLoadedVariables.map { v => s"${v} = false;" }.mkString("\n")
val code =
+ $subexprReset
${writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, bufferHolder)}
$result.pointTo($bufferHolder.buffer, ${expressions.length}, $bufferHolder.totalSize());
GeneratedExpressionCode(code, "false", result)
@@ -300,10 +308,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
in.map(BindReferences.bindReference(_, inputSchema))
+ def generate(
+ expressions: Seq[Expression],
+ subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
+ create(canonicalize(expressions), subexpressionEliminationEnabled)
+ }
protected def create(expressions: Seq[Expression]): UnsafeProjection = {
- val ctx = newCodeGenContext()
+ create(expressions, false)
+ }
- val eval = createCode(ctx, expressions)
+ private def create(
+ expressions: Seq[Expression],
+ subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
+ val ctx = newCodeGenContext()
+ val eval = createCode(ctx, expressions, subexpressionEliminationEnabled)
val code = s"""
public Object generate($exprType[] exprs) {
@@ -315,6 +334,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
private $exprType[] expressions;
public SpecificUnsafeProjection($exprType[] expressions) {
@@ -328,7 +348,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) {
- ${eval.code}
+ ${eval.code.trim}
return ${eval.value};
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 9ab5c299d0..f80bcfcb0b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -203,6 +203,10 @@ case class AttributeReference(
case _ => false
+  // Only the expression id contributes to the semantic hash.
+  override def semanticHash(): Int = exprId.hashCode()
override def hashCode: Int = {
// See http://stackoverflow.com/questions/113511/hash-code-implementation
var h = 17
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
new file mode 100644
index 0000000000..9de066e99d
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
@@ -0,0 +1,153 @@
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types.IntegerType
+class SubexpressionEliminationSuite extends SparkFunSuite {
+  test("Semantic equals and hash") {
+    val sharedId = ExprId(1)
+    val a: AttributeReference = AttributeReference("name", IntegerType)()
+    // b1 and b2 share an expression id but differ in name; `a` keeps its own id.
+    val b1 = a.withName("name2").withExprId(sharedId)
+    val b2 = a.withExprId(sharedId)
+
+    // Plain equality tells all of them apart.
+    assert(b1 != b2)
+    assert(a != b1)
+    // Semantic equality groups b1 with b2 but not with a.
+    assert(b1.semanticEquals(b2))
+    assert(!b1.semanticEquals(a))
+    // Hash checks.
+    assert(a.hashCode != b1.hashCode)
+    assert(b1.hashCode == b2.hashCode)
+    assert(b1.semanticHash() == b2.semanticHash())
+  }
+  test("Expression Equivalence - basic") {
+    val equivalence = new EquivalentExpressions
+    assert(equivalence.getAllEquivalentExprs.isEmpty)
+
+    val oneA = Literal(1)
+    val oneB = Literal(1)
+    val twoA = Literal(2)
+
+    assert(equivalence.getEquivalentExprs(oneA).isEmpty)
+    assert(equivalence.getEquivalentExprs(twoA).isEmpty)
+
+    // First insertion of oneA: addExpr reports no prior match, and the group has one member.
+    assert(!equivalence.addExpr(oneA))
+    assert(equivalence.getEquivalentExprs(oneA).size == 1)
+    assert(equivalence.getEquivalentExprs(twoA).isEmpty)
+    // Adding the same instance again is reported as a match and grows the group.
+    assert(equivalence.addExpr(oneA))
+    assert(equivalence.getEquivalentExprs(oneA).size == 2)
+
+    // Add B and make sure they can see each other.
+    assert(equivalence.addExpr(oneB))
+    // Use exists and reference equality because of how equals is defined.
+    assert(equivalence.getEquivalentExprs(oneA).exists(_ eq oneB))
+    assert(equivalence.getEquivalentExprs(oneA).exists(_ eq oneA))
+    assert(equivalence.getEquivalentExprs(oneB).exists(_ eq oneA))
+    assert(equivalence.getEquivalentExprs(oneB).exists(_ eq oneB))
+    assert(equivalence.getEquivalentExprs(twoA).isEmpty)
+    assert(equivalence.getAllEquivalentExprs.size == 1)
+    assert(equivalence.getAllEquivalentExprs.head.size == 3)
+    assert(equivalence.getAllEquivalentExprs.head.contains(oneA))
+    assert(equivalence.getAllEquivalentExprs.head.contains(oneB))
+
+    // Two structurally identical non-leaf expressions end up in one new group.
+    val add1 = Add(oneA, oneB)
+    val add2 = Add(oneA, oneB)
+    equivalence.addExpr(add1)
+    equivalence.addExpr(add2)
+    assert(equivalence.getAllEquivalentExprs.size == 2)
+    assert(equivalence.getEquivalentExprs(add2).exists(_ eq add1))
+    assert(equivalence.getEquivalentExprs(add2).size == 2)
+    assert(equivalence.getEquivalentExprs(add1).exists(_ eq add2))
+  }
+ test("Expression Equivalence - Trees") {
+ val one = Literal(1)
+ val two = Literal(2)
+ val add = Add(one, two)
+ val abs = Abs(add)
+ val add2 = Add(add, add)
+ var equivalence = new EquivalentExpressions
+ equivalence.addExprTree(add, true)
+ equivalence.addExprTree(abs, true)
+ equivalence.addExprTree(add2, true)
+ // Should only have one equivalence for `one + two`
+ assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).size == 1)
+ assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).head.size == 4)
+ // Set up the expressions
+ // one * two,
+ // (one * two) * (one * two)
+ // sqrt( (one * two) * (one * two) )
+ // ((one * two) * (one * two)) + sqrt( (one * two) * (one * two) )
+ equivalence = new EquivalentExpressions
+ val mul = Multiply(one, two)
+ val mul2 = Multiply(mul, mul)
+ val sqrt = Sqrt(mul2)
+ val sum = Add(mul2, sqrt)
+ equivalence.addExprTree(mul, true)
+ equivalence.addExprTree(mul2, true)
+ equivalence.addExprTree(sqrt, true)
+ equivalence.addExprTree(sum, true)
+ // (one * two), (one * two) * (one * two) and sqrt( (one * two) * (one * two) ) should be found
+ assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).size == 3)
+ assert(equivalence.getEquivalentExprs(mul).size == 3)
+ assert(equivalence.getEquivalentExprs(mul2).size == 3)
+ assert(equivalence.getEquivalentExprs(sqrt).size == 2)
+ assert(equivalence.getEquivalentExprs(sum).size == 1)
+ // Some expressions inspired by TPCH-Q1
+ // sum(l_quantity) as sum_qty,
+ // sum(l_extendedprice) as sum_base_price,
+ // sum(l_extendedprice * (1 - l_discount)) as sum_disc_price,
+ // sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge,
+ // avg(l_extendedprice) as avg_price,
+ // avg(l_discount) as avg_disc
+ equivalence = new EquivalentExpressions
+ val quantity = Literal(1)
+ val price = Literal(1.1)
+ val discount = Literal(.24)
+ val tax = Literal(0.1)
+ equivalence.addExprTree(quantity, false)
+ equivalence.addExprTree(price, false)
+ equivalence.addExprTree(Multiply(price, Subtract(Literal(1), discount)), false)
+ equivalence.addExprTree(
+ Multiply(
+ Multiply(price, Subtract(Literal(1), discount)),
+ Add(Literal(1), tax)), false)
+ equivalence.addExprTree(price, false)
+ equivalence.addExprTree(discount, false)
+ // quantity, price, discount and (price * (1 - discount))
+ assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).size == 4)
+ }
+  test("Expression equivalence - non deterministic") {
+    val equivalence = new EquivalentExpressions
+    val randSum = Add(Rand(0), Rand(0))
+    // Offer the same non-deterministic tree twice; nothing may be recorded.
+    equivalence.addExpr(randSum)
+    equivalence.addExpr(randSum)
+    assert(equivalence.getAllEquivalentExprs.isEmpty)
+  }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index b7314189b5..89e196c066 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -268,6 +268,11 @@ private[spark] object SQLConf {
doc = "When true, use the new optimized Tungsten physical execution backend.",
isPublic = false)
+ val SUBEXPRESSION_ELIMINATION_ENABLED = booleanConf("spark.sql.subexpressionElimination.enabled",
+ defaultValue = Some(true), // enabled by default
+ doc = "When true, common subexpressions will be eliminated.",
+ isPublic = false)
val DIALECT = stringConf(
defaultValue = Some("sql"),
@@ -541,6 +546,9 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED, getConf(TUNGSTEN_ENABLED))
+  // Reads spark.sql.subexpressionElimination.enabled (SUBEXPRESSION_ELIMINATION_ENABLED).
+  private[spark] def subexpressionEliminationEnabled: Boolean =
+    getConf(SUBEXPRESSION_ELIMINATION_ENABLED)
private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD)
private[spark] def defaultSizeInBytes: Long =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 8bb293ae87..8650ac500b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -66,6 +66,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
} else {
+  // Falls back to false when there is no SQLContext to read the setting from.
+  val subexpressionEliminationEnabled: Boolean =
+    sqlContext != null && sqlContext.conf.subexpressionEliminationEnabled
* Whether the "prepare" method is called.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 145de0db9e..303d636164 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -70,7 +70,8 @@ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan)
protected override def doExecute(): RDD[InternalRow] = {
val numRows = longMetric("numRows")
child.execute().mapPartitions { iter =>
- val project = UnsafeProjection.create(projectList, child.output)
+ val project = UnsafeProjection.create(projectList, child.output,
+ subexpressionEliminationEnabled)
iter.map { row =>
numRows += 1
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 441a0c6d0e..19e850a46f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1970,4 +1970,52 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil)
+  test("Common subexpression elimination") {
+    // select from a table to prevent constant folding.
+    val df = sql("SELECT a, b from testData2 limit 1")
+    checkAnswer(df, Row(1, 1))
+
+    checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2))
+    checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3))
+
+    // Nothing is shared here because the expressions get grouped like (a + a) + 1;
+    // the results must still be correct.
+    checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3))
+    checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3))
+
+    // Identity udf that tracks the number of times it is called.
+    val countAcc = sparkContext.accumulator(0, "CallCount")
+    sqlContext.udf.register("testUdf", (x: Int) => {
+      countAcc.++=(1)
+      x
+    })
+
+    // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value
+    // is correct.
+    def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = {
+      countAcc.setValue(0)
+      checkAnswer(df, expectedResult)
+      assert(countAcc.value == expectedCount)
+    }
+
+    verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1)
+    verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
+    verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1)
+    verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2)
+    verifyCallCount(
+      df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1)
+
+    verifyCallCount(
+      df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2)
+
+    // Would be nice if semantic equals for `+` understood commutative
+    verifyCallCount(
+      df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2)
+
+    // Try disabling it via configuration. The setting is restored in `finally` so that a
+    // failed assertion cannot leave elimination switched off for whatever runs next.
+    val confKey = "spark.sql.subexpressionElimination.enabled"
+    sqlContext.setConf(confKey, "false")
+    try {
+      verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2)
+    } finally {
+      sqlContext.setConf(confKey, "true")
+    }
+    verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
+  }