aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorNong Li <nong@databricks.com>2015-11-10 11:28:53 -0800
committerMichael Armbrust <michael@databricks.com>2015-11-10 11:28:53 -0800
commit87aedc48c01dffbd880e6ca84076ed47c68f88d0 (patch)
tree03427b95d0f7032722373fdd05b5f07b6361d7e0 /sql
parent53600854c270d4c953fe95fbae528740b5cf6603 (diff)
downloadspark-87aedc48c01dffbd880e6ca84076ed47c68f88d0.tar.gz
spark-87aedc48c01dffbd880e6ca84076ed47c68f88d0.tar.bz2
spark-87aedc48c01dffbd880e6ca84076ed47c68f88d0.zip
[SPARK-10371][SQL] Implement subexpr elimination for UnsafeProjections
This patch adds the building blocks for codegening subexpr elimination and implements it end to end for UnsafeProjection. The building blocks can be used to do the same thing for other operators. It introduces some utilities to compute common sub expressions. Expressions can be added to this data structure. The expr and its children will be recursively matched against existing expressions (ones previously added) and grouped into common groups. This is built using the existing `semanticEquals`. It does not understand things like commutative or associative expressions. This can be done as future work. After building this data structure, the codegen process takes advantage of it by: 1. Generating a helper function in the generated class that computes the common subexpression. This is done for all common subexpressions that have at least two occurrences and the expression tree is sufficiently complex. 2. When generating the apply() function, if the helper function exists, call that instead of regenerating the expression tree. Repeated calls to the helper function shortcircuit the evaluation logic. Author: Nong Li <nong@databricks.com> Author: Nong Li <nongli@gmail.com> This patch had conflicts when merged, resolved by Committer: Michael Armbrust <michael@databricks.com> Closes #9480 from nongli/spark-10371.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala106
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala50
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala16
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala110
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala36
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala4
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala153
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala8
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala5
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala3
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala48
11 files changed, 523 insertions, 16 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
new file mode 100644
index 0000000000..e7380d21f9
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import scala.collection.mutable
+
+/**
+ * This class is used to compute equality of (sub)expression trees. Expressions can be added
+ * to this class and they subsequently query for expression equality. Expression trees are
+ * considered equal if for the same input(s), the same result is produced.
+ */
+class EquivalentExpressions {
+ /**
+ * Wrapper around an Expression that provides semantic equality.
+ */
+ case class Expr(e: Expression) {
+ val hash = e.semanticHash()
+ override def equals(o: Any): Boolean = o match {
+ case other: Expr => e.semanticEquals(other.e)
+ case _ => false
+ }
+ override def hashCode: Int = hash
+ }
+
+ // For each expression, the set of equivalent expressions.
+ private val equivalenceMap: mutable.HashMap[Expr, mutable.MutableList[Expression]] =
+ new mutable.HashMap[Expr, mutable.MutableList[Expression]]
+
+ /**
+ * Adds each expression to this data structure, grouping them with existing equivalent
+ * expressions. Non-recursive.
+ * Returns if there was already a matching expression.
+ */
+ def addExpr(expr: Expression): Boolean = {
+ if (expr.deterministic) {
+ val e: Expr = Expr(expr)
+ val f = equivalenceMap.get(e)
+ if (f.isDefined) {
+ f.get.+= (expr)
+ true
+ } else {
+ equivalenceMap.put(e, mutable.MutableList(expr))
+ false
+ }
+ } else {
+ false
+ }
+ }
+
+ /**
+ * Adds the expression to this datastructure recursively. Stops if a matching expression
+ * is found. That is, if `expr` has already been added, its children are not added.
+ * If ignoreLeaf is true, leaf nodes are ignored.
+ */
+ def addExprTree(root: Expression, ignoreLeaf: Boolean = true): Unit = {
+ val skip = root.isInstanceOf[LeafExpression] && ignoreLeaf
+ if (!skip && root.deterministic && !addExpr(root)) {
+ root.children.foreach(addExprTree(_, ignoreLeaf))
+ }
+ }
+
+ /**
+ * Returns all fo the expression trees that are equivalent to `e`. Returns
+ * an empty collection if there are none.
+ */
+ def getEquivalentExprs(e: Expression): Seq[Expression] = {
+ equivalenceMap.get(Expr(e)).getOrElse(mutable.MutableList())
+ }
+
+ /**
+ * Returns all the equivalent sets of expressions.
+ */
+ def getAllEquivalentExprs: Seq[Seq[Expression]] = {
+ equivalenceMap.values.map(_.toSeq).toSeq
+ }
+
+ /**
+ * Returns the state of the datastructure as a string. If all is false, skips sets of equivalent
+ * expressions with cardinality 1.
+ */
+ def debugString(all: Boolean = false): String = {
+ val sb: mutable.StringBuilder = new StringBuilder()
+ sb.append("Equivalent expressions:\n")
+ equivalenceMap.foreach { case (k, v) => {
+ if (all || v.length > 1) {
+ sb.append(" " + v.mkString(", ")).append("\n")
+ }
+ }}
+ sb.toString()
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 96fcc799e5..7d5741eefc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -92,12 +92,24 @@ abstract class Expression extends TreeNode[Expression] {
* @return [[GeneratedExpressionCode]]
*/
def gen(ctx: CodeGenContext): GeneratedExpressionCode = {
- val isNull = ctx.freshName("isNull")
- val primitive = ctx.freshName("primitive")
- val ve = GeneratedExpressionCode("", isNull, primitive)
- ve.code = genCode(ctx, ve)
- // Add `this` in the comment.
- ve.copy(s"/* $this */\n" + ve.code)
+ val subExprState = ctx.subExprEliminationExprs.get(this)
+ if (subExprState.isDefined) {
+ // This expression is repeated meaning the code to evaluated has already been added
+ // as a function, `subExprState.fnName`. Just call that.
+ val code =
+ s"""
+ |/* $this */
+ |${subExprState.get.fnName}(${ctx.INPUT_ROW});
+ |""".stripMargin.trim
+ GeneratedExpressionCode(code, subExprState.get.code.isNull, subExprState.get.code.value)
+ } else {
+ val isNull = ctx.freshName("isNull")
+ val primitive = ctx.freshName("primitive")
+ val ve = GeneratedExpressionCode("", isNull, primitive)
+ ve.code = genCode(ctx, ve)
+ // Add `this` in the comment.
+ ve.copy(s"/* $this */\n" + ve.code.trim)
+ }
}
/**
@@ -145,12 +157,38 @@ abstract class Expression extends TreeNode[Expression] {
case (i1, i2) => i1 == i2
}
}
+ // Non-determinstic expressions cannot be equal
+ if (!deterministic || !other.deterministic) return false
val elements1 = this.productIterator.toSeq
val elements2 = other.asInstanceOf[Product].productIterator.toSeq
checkSemantic(elements1, elements2)
}
/**
+ * Returns the hash for this expression. Expressions that compute the same result, even if
+ * they differ cosmetically should return the same hash.
+ */
+ def semanticHash() : Int = {
+ def computeHash(e: Seq[Any]): Int = {
+ // See http://stackoverflow.com/questions/113511/hash-code-implementation
+ var hash: Int = 17
+ e.foreach(i => {
+ val h: Int = i match {
+ case (e: Expression) => e.semanticHash()
+ case (Some(e: Expression)) => e.semanticHash()
+ case (t: Traversable[_]) => computeHash(t.toSeq)
+ case null => 0
+ case (o) => o.hashCode()
+ }
+ hash = hash * 37 + h
+ })
+ hash
+ }
+
+ computeHash(this.productIterator.toSeq)
+ }
+
+ /**
* Checks the input data types, returns `TypeCheckResult.success` if it's valid,
* or returns a `TypeCheckResult` with an error message if invalid.
* Note: it's not valid to call this method until `childrenResolved == true`.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 79dabe8e92..9f0b7821ae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -144,6 +144,22 @@ object UnsafeProjection {
def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = {
create(exprs.map(BindReferences.bindReference(_, inputSchema)))
}
+
+ /**
+ * Same as other create()'s but allowing enabling/disabling subexpression elimination.
+ * TODO: refactor the plumbing and clean this up.
+ */
+ def create(
+ exprs: Seq[Expression],
+ inputSchema: Seq[Attribute],
+ subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
+ val e = exprs.map(BindReferences.bindReference(_, inputSchema))
+ .map(_ transform {
+ case CreateStruct(children) => CreateStructUnsafe(children)
+ case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
+ })
+ GenerateUnsafeProjection.generate(e, subexpressionEliminationEnabled)
+ }
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index f0f7a6cf0c..60a3d60184 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -92,6 +92,33 @@ class CodeGenContext {
addedFunctions += ((funcName, funcCode))
}
+ /**
+ * Holds expressions that are equivalent. Used to perform subexpression elimination
+ * during codegen.
+ *
+ * For expressions that appear more than once, generate additional code to prevent
+ * recomputing the value.
+ *
+ * For example, consider two exprsesion generated from this SQL statement:
+ * SELECT (col1 + col2), (col1 + col2) / col3.
+ *
+ * equivalentExpressions will match the tree containing `col1 + col2` and it will only
+ * be evaluated once.
+ */
+ val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions
+
+ // State used for subexpression elimination.
+ case class SubExprEliminationState(
+ val isLoaded: String, code: GeneratedExpressionCode, val fnName: String)
+
+ // Foreach expression that is participating in subexpression elimination, the state to use.
+ val subExprEliminationExprs: mutable.HashMap[Expression, SubExprEliminationState] =
+ mutable.HashMap[Expression, SubExprEliminationState]()
+
+ // The collection of isLoaded variables that need to be reset on each row.
+ val subExprIsLoadedVariables: mutable.ArrayBuffer[String] =
+ mutable.ArrayBuffer.empty[String]
+
final val JAVA_BOOLEAN = "boolean"
final val JAVA_BYTE = "byte"
final val JAVA_SHORT = "short"
@@ -317,6 +344,87 @@ class CodeGenContext {
functions.map(name => s"$name($row);").mkString("\n")
}
}
+
+ /**
+ * Checks and sets up the state and codegen for subexpression elimination. This finds the
+ * common subexpresses, generates the functions that evaluate those expressions and populates
+ * the mapping of common subexpressions to the generated functions.
+ */
+ private def subexpressionElimination(expressions: Seq[Expression]) = {
+ // Add each expression tree and compute the common subexpressions.
+ expressions.foreach(equivalentExpressions.addExprTree(_))
+
+ // Get all the exprs that appear at least twice and set up the state for subexpression
+ // elimination.
+ val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
+ commonExprs.foreach(e => {
+ val expr = e.head
+ val isLoaded = freshName("isLoaded")
+ val isNull = freshName("isNull")
+ val primitive = freshName("primitive")
+ val fnName = freshName("evalExpr")
+
+ // Generate the code for this expression tree and wrap it in a function.
+ val code = expr.gen(this)
+ val fn =
+ s"""
+ |private void $fnName(InternalRow ${INPUT_ROW}) {
+ | if (!$isLoaded) {
+ | ${code.code.trim}
+ | $isLoaded = true;
+ | $isNull = ${code.isNull};
+ | $primitive = ${code.value};
+ | }
+ |}
+ """.stripMargin
+ code.code = fn
+ code.isNull = isNull
+ code.value = primitive
+
+ addNewFunction(fnName, fn)
+
+ // Add a state and a mapping of the common subexpressions that are associate with this
+ // state. Adding this expression to subExprEliminationExprMap means it will call `fn`
+ // when it is code generated. This decision should be a cost based one.
+ //
+ // The cost of doing subexpression elimination is:
+ // 1. Extra function call, although this is probably *good* as the JIT can decide to
+ // inline or not.
+ // 2. Extra branch to check isLoaded. This branch is likely to be predicted correctly
+ // very often. The reason it is not loaded is because of a prior branch.
+ // 3. Extra store into isLoaded.
+ // The benefit doing subexpression elimination is:
+ // 1. Running the expression logic. Even for a simple expression, it is likely more than 3
+ // above.
+ // 2. Less code.
+ // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with
+ // at least two nodes) as the cost of doing it is expected to be low.
+
+ // Maintain the loaded value and isNull as member variables. This is necessary if the codegen
+ // function is split across multiple functions.
+ // TODO: maintaining this as a local variable probably allows the compiler to do better
+ // optimizations.
+ addMutableState("boolean", isLoaded, s"$isLoaded = false;")
+ addMutableState("boolean", isNull, s"$isNull = false;")
+ addMutableState(javaType(expr.dataType), primitive,
+ s"$primitive = ${defaultValue(expr.dataType)};")
+ subExprIsLoadedVariables += isLoaded
+
+ val state = SubExprEliminationState(isLoaded, code, fnName)
+ e.foreach(subExprEliminationExprs.put(_, state))
+ })
+ }
+
+ /**
+ * Generates code for expressions. If doSubexpressionElimination is true, subexpression
+ * elimination will be performed. Subexpression elimination assumes that the code will for each
+ * expression will be combined in the `expressions` order.
+ */
+ def generateExpressions(expressions: Seq[Expression],
+ doSubexpressionElimination: Boolean = false): Seq[GeneratedExpressionCode] = {
+ if (doSubexpressionElimination) subexpressionElimination(expressions)
+ expressions.map(e => e.gen(this))
+ }
}
/**
@@ -349,7 +457,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
}
protected def declareAddedFunctions(ctx: CodeGenContext): String = {
- ctx.addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n")
+ ctx.addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n").trim
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 2136f82ba4..9ef2261414 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -139,9 +139,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
s"""
${input.code}
if (${input.isNull}) {
- $setNull
+ ${setNull.trim}
} else {
- $writeField
+ ${writeField.trim}
}
"""
}
@@ -149,7 +149,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
s"""
$rowWriter.initialize($bufferHolder, ${inputs.length});
${ctx.splitExpressions(row, writeFields)}
- """
+ """.trim
}
// TODO: if the nullability of array element is correct, we can use it to save null check.
@@ -275,8 +275,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
"""
}
- def createCode(ctx: CodeGenContext, expressions: Seq[Expression]): GeneratedExpressionCode = {
- val exprEvals = expressions.map(e => e.gen(ctx))
+ def createCode(
+ ctx: CodeGenContext,
+ expressions: Seq[Expression],
+ useSubexprElimination: Boolean = false): GeneratedExpressionCode = {
+ val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination)
val exprTypes = expressions.map(_.dataType)
val result = ctx.freshName("result")
@@ -285,10 +288,15 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val holderClass = classOf[BufferHolder].getName
ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new $holderClass();")
+ // Reset the isLoaded flag for each row.
+ val subexprReset = ctx.subExprIsLoadedVariables.map { v => s"${v} = false;" }.mkString("\n")
+
val code =
s"""
$bufferHolder.reset();
+ $subexprReset
${writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, bufferHolder)}
+
$result.pointTo($bufferHolder.buffer, ${expressions.length}, $bufferHolder.totalSize());
"""
GeneratedExpressionCode(code, "false", result)
@@ -300,10 +308,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
in.map(BindReferences.bindReference(_, inputSchema))
+ def generate(
+ expressions: Seq[Expression],
+ subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
+ create(canonicalize(expressions), subexpressionEliminationEnabled)
+ }
+
protected def create(expressions: Seq[Expression]): UnsafeProjection = {
- val ctx = newCodeGenContext()
+ create(expressions, false)
+ }
- val eval = createCode(ctx, expressions)
+ private def create(
+ expressions: Seq[Expression],
+ subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
+ val ctx = newCodeGenContext()
+ val eval = createCode(ctx, expressions, subexpressionEliminationEnabled)
val code = s"""
public Object generate($exprType[] exprs) {
@@ -315,6 +334,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
private $exprType[] expressions;
${declareMutableStates(ctx)}
+
${declareAddedFunctions(ctx)}
public SpecificUnsafeProjection($exprType[] expressions) {
@@ -328,7 +348,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
}
public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) {
- ${eval.code}
+ ${eval.code.trim}
return ${eval.value};
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 9ab5c299d0..f80bcfcb0b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -203,6 +203,10 @@ case class AttributeReference(
case _ => false
}
+ override def semanticHash(): Int = {
+ this.exprId.hashCode()
+ }
+
override def hashCode: Int = {
// See http://stackoverflow.com/questions/113511/hash-code-implementation
var h = 17
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
new file mode 100644
index 0000000000..9de066e99d
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types.IntegerType
+
+class SubexpressionEliminationSuite extends SparkFunSuite {
+ test("Semantic equals and hash") {
+ val id = ExprId(1)
+ val a: AttributeReference = AttributeReference("name", IntegerType)()
+ val b1 = a.withName("name2").withExprId(id)
+ val b2 = a.withExprId(id)
+
+ assert(b1 != b2)
+ assert(a != b1)
+ assert(b1.semanticEquals(b2))
+ assert(!b1.semanticEquals(a))
+ assert(a.hashCode != b1.hashCode)
+ assert(b1.hashCode == b2.hashCode)
+ assert(b1.semanticHash() == b2.semanticHash())
+ }
+
+ test("Expression Equivalence - basic") {
+ val equivalence = new EquivalentExpressions
+ assert(equivalence.getAllEquivalentExprs.isEmpty)
+
+ val oneA = Literal(1)
+ val oneB = Literal(1)
+ val twoA = Literal(2)
+ var twoB = Literal(2)
+
+ assert(equivalence.getEquivalentExprs(oneA).isEmpty)
+ assert(equivalence.getEquivalentExprs(twoA).isEmpty)
+
+ // Add oneA and test if it is returned. Since it is a group of one, it does not.
+ assert(!equivalence.addExpr(oneA))
+ assert(equivalence.getEquivalentExprs(oneA).size == 1)
+ assert(equivalence.getEquivalentExprs(twoA).isEmpty)
+ assert(equivalence.addExpr((oneA)))
+ assert(equivalence.getEquivalentExprs(oneA).size == 2)
+
+ // Add B and make sure they can see each other.
+ assert(equivalence.addExpr(oneB))
+ // Use exists and reference equality because of how equals is defined.
+ assert(equivalence.getEquivalentExprs(oneA).exists(_ eq oneB))
+ assert(equivalence.getEquivalentExprs(oneA).exists(_ eq oneA))
+ assert(equivalence.getEquivalentExprs(oneB).exists(_ eq oneA))
+ assert(equivalence.getEquivalentExprs(oneB).exists(_ eq oneB))
+ assert(equivalence.getEquivalentExprs(twoA).isEmpty)
+ assert(equivalence.getAllEquivalentExprs.size == 1)
+ assert(equivalence.getAllEquivalentExprs.head.size == 3)
+ assert(equivalence.getAllEquivalentExprs.head.contains(oneA))
+ assert(equivalence.getAllEquivalentExprs.head.contains(oneB))
+
+ val add1 = Add(oneA, oneB)
+ val add2 = Add(oneA, oneB)
+
+ equivalence.addExpr(add1)
+ equivalence.addExpr(add2)
+
+ assert(equivalence.getAllEquivalentExprs.size == 2)
+ assert(equivalence.getEquivalentExprs(add2).exists(_ eq add1))
+ assert(equivalence.getEquivalentExprs(add2).size == 2)
+ assert(equivalence.getEquivalentExprs(add1).exists(_ eq add2))
+ }
+
+ test("Expression Equivalence - Trees") {
+ val one = Literal(1)
+ val two = Literal(2)
+
+ val add = Add(one, two)
+ val abs = Abs(add)
+ val add2 = Add(add, add)
+
+ var equivalence = new EquivalentExpressions
+ equivalence.addExprTree(add, true)
+ equivalence.addExprTree(abs, true)
+ equivalence.addExprTree(add2, true)
+
+ // Should only have one equivalence for `one + two`
+ assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).size == 1)
+ assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).head.size == 4)
+
+ // Set up the expressions
+ // one * two,
+ // (one * two) * (one * two)
+ // sqrt( (one * two) * (one * two) )
+ // (one * two) + sqrt( (one * two) * (one * two) )
+ equivalence = new EquivalentExpressions
+ val mul = Multiply(one, two)
+ val mul2 = Multiply(mul, mul)
+ val sqrt = Sqrt(mul2)
+ val sum = Add(mul2, sqrt)
+ equivalence.addExprTree(mul, true)
+ equivalence.addExprTree(mul2, true)
+ equivalence.addExprTree(sqrt, true)
+ equivalence.addExprTree(sum, true)
+
+ // (one * two), (one * two) * (one * two) and sqrt( (one * two) * (one * two) ) should be found
+ assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).size == 3)
+ assert(equivalence.getEquivalentExprs(mul).size == 3)
+ assert(equivalence.getEquivalentExprs(mul2).size == 3)
+ assert(equivalence.getEquivalentExprs(sqrt).size == 2)
+ assert(equivalence.getEquivalentExprs(sum).size == 1)
+
+ // Some expressions inspired by TPCH-Q1
+ // sum(l_quantity) as sum_qty,
+ // sum(l_extendedprice) as sum_base_price,
+ // sum(l_extendedprice * (1 - l_discount)) as sum_disc_price,
+ // sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge,
+ // avg(l_extendedprice) as avg_price,
+ // avg(l_discount) as avg_disc
+ equivalence = new EquivalentExpressions
+ val quantity = Literal(1)
+ val price = Literal(1.1)
+ val discount = Literal(.24)
+ val tax = Literal(0.1)
+ equivalence.addExprTree(quantity, false)
+ equivalence.addExprTree(price, false)
+ equivalence.addExprTree(Multiply(price, Subtract(Literal(1), discount)), false)
+ equivalence.addExprTree(
+ Multiply(
+ Multiply(price, Subtract(Literal(1), discount)),
+ Add(Literal(1), tax)), false)
+ equivalence.addExprTree(price, false)
+ equivalence.addExprTree(discount, false)
+ // quantity, price, discount and (price * (1 - discount))
+ assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).size == 4)
+ }
+
+ test("Expression equivalence - non deterministic") {
+ val sum = Add(Rand(0), Rand(0))
+ val equivalence = new EquivalentExpressions
+ equivalence.addExpr(sum)
+ equivalence.addExpr(sum)
+ assert(equivalence.getAllEquivalentExprs.isEmpty)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index b7314189b5..89e196c066 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -268,6 +268,11 @@ private[spark] object SQLConf {
doc = "When true, use the new optimized Tungsten physical execution backend.",
isPublic = false)
+ val SUBEXPRESSION_ELIMINATION_ENABLED = booleanConf("spark.sql.subexpressionElimination.enabled",
+ defaultValue = Some(true), // use CODEGEN_ENABLED as default
+ doc = "When true, common subexpressions will be eliminated.",
+ isPublic = false)
+
val DIALECT = stringConf(
"spark.sql.dialect",
defaultValue = Some("sql"),
@@ -541,6 +546,9 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED, getConf(TUNGSTEN_ENABLED))
+ private[spark] def subexpressionEliminationEnabled: Boolean =
+ getConf(SUBEXPRESSION_ELIMINATION_ENABLED, codegenEnabled)
+
private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD)
private[spark] def defaultSizeInBytes: Long =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 8bb293ae87..8650ac500b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -66,6 +66,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
} else {
false
}
+ val subexpressionEliminationEnabled: Boolean = if (sqlContext != null) {
+ sqlContext.conf.subexpressionEliminationEnabled
+ } else {
+ false
+ }
/**
* Whether the "prepare" method is called.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 145de0db9e..303d636164 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -70,7 +70,8 @@ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan)
protected override def doExecute(): RDD[InternalRow] = {
val numRows = longMetric("numRows")
child.execute().mapPartitions { iter =>
- val project = UnsafeProjection.create(projectList, child.output)
+ val project = UnsafeProjection.create(projectList, child.output,
+ subexpressionEliminationEnabled)
iter.map { row =>
numRows += 1
project(row)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 441a0c6d0e..19e850a46f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1970,4 +1970,52 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil)
}
}
+
+ test("Common subexpression elimination") {
+ // select from a table to prevent constant folding.
+ val df = sql("SELECT a, b from testData2 limit 1")
+ checkAnswer(df, Row(1, 1))
+
+ checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2))
+ checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3))
+
+ // This does not work because the expressions get grouped like (a + a) + 1
+ checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3))
+ checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3))
+
+ // Identity udf that tracks the number of times it is called.
+ val countAcc = sparkContext.accumulator(0, "CallCount")
+ sqlContext.udf.register("testUdf", (x: Int) => {
+ countAcc.++=(1)
+ x
+ })
+
+ // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value
+ // is correct.
+ def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = {
+ countAcc.setValue(0)
+ checkAnswer(df, expectedResult)
+ assert(countAcc.value == expectedCount)
+ }
+
+ verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1)
+ verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
+ verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1)
+ verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2)
+ verifyCallCount(
+ df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1)
+
+ verifyCallCount(
+ df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2)
+
+ // Would be nice if semantic equals for `+` understood commutative
+ verifyCallCount(
+ df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2)
+
+ // Try disabling it via configuration.
+ sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false")
+ verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2)
+ sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true")
+ verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
+ }
}