author     Davies Liu <davies@databricks.com>   2016-01-20 10:02:40 -0800
committer  Davies Liu <davies.liu@gmail.com>    2016-01-20 10:02:40 -0800
commit     8e4f894e986ccd943df9ddf55fc853eb0558886f (patch)
tree       3c8f46759f12afbfe2a5253a75e136074f5febc1
parent     753b1945115245800898959e3ab249a94a1935e9 (diff)
[SPARK-12881] [SQL] subexpression elimination in mutable projection
Author: Davies Liu <davies@databricks.com>
Closes #10814 from davies/mutable_subexpr.
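For context, the patch lets the code-generated MutableProjection reuse the common-subexpression machinery that GenerateUnsafeProjection already had, so a repeated subexpression is evaluated once per input row instead of once per occurrence. A minimal sketch of the kind of query that benefits, mirroring the test added to SQLQuerySuite at the end of this patch (the UDF, DataFrame and column names are hypothetical; assumes a SQLContext named `sqlContext` and a DataFrame `df` with an integer column `b`):

    import org.apache.spark.sql.functions.{sum, udf}
    import sqlContext.implicits._

    // `expensive` stands in for any deterministic, costly UDF. Before this patch
    // the mutable projection behind the aggregate evaluated it three times per
    // row; with subexpression elimination it runs once per row.
    val expensive = udf((x: Int) => x * 2)
    val result = df.groupBy().agg(sum(expensive($"b") + expensive($"b") + expensive($"b")))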
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala          |  5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala                     |  4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala          |  8
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala | 43
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala  |  6
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala     | 13
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala                                     |  6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala                                        |  8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala                  |  3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala                   |  3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala                                           |  8
11 files changed, 80 insertions(+), 27 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
index f7162e420d..affd1bdb32 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions
import scala.collection.mutable
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+
/**
* This class is used to compute equality of (sub)expression trees. Expressions can be added
* to this class and they subsequently query for expression equality. Expression trees are
@@ -67,7 +69,8 @@ class EquivalentExpressions {
*/
def addExprTree(root: Expression, ignoreLeaf: Boolean = true): Unit = {
val skip = root.isInstanceOf[LeafExpression] && ignoreLeaf
- if (!skip && !addExpr(root)) {
+ // the children of a CodegenFallback expression are not used to generate code (eval() is called instead)
+ if (!skip && !addExpr(root) && !root.isInstanceOf[CodegenFallback]) {
root.children.foreach(addExprTree(_, ignoreLeaf))
}
}
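A self-contained sketch of the traversal rule this hunk introduces, using stand-in types (Expr, FallbackMarker, AddExpr, ExplodeExpr) in place of Catalyst's Expression, CodegenFallback, Add and Explode; it illustrates the idea rather than reproducing the real class:

    import scala.collection.mutable

    sealed trait Expr { def children: Seq[Expr] }
    trait FallbackMarker extends Expr                 // stands in for CodegenFallback
    case class LeafExpr(v: Int) extends Expr { val children: Seq[Expr] = Nil }
    case class AddExpr(l: Expr, r: Expr) extends Expr { val children: Seq[Expr] = Seq(l, r) }
    case class ExplodeExpr(child: Expr) extends Expr with FallbackMarker {
      val children: Seq[Expr] = Seq(child)
    }

    class Equivalences {
      private val groups = mutable.LinkedHashMap.empty[Expr, mutable.Buffer[Expr]]

      // Returns true if a structurally equal expression was already recorded.
      def addExpr(e: Expr): Boolean = {
        val group = groups.getOrElseUpdate(e, mutable.Buffer.empty[Expr])
        group += e
        group.size > 1
      }

      def addExprTree(root: Expr, ignoreLeaf: Boolean = true): Unit = {
        val skip = root.isInstanceOf[LeafExpr] && ignoreLeaf
        // Children of a fallback node are evaluated through eval(), never through
        // generated code, so recording them cannot save any generated work.
        if (!skip && !addExpr(root) && !root.isInstanceOf[FallbackMarker]) {
          root.children.foreach(addExprTree(_, ignoreLeaf))
        }
      }

      def allGroups: Seq[Seq[Expr]] = groups.values.map(_.toSeq).toSeq
    }

    // Mirrors the new test case further down: the `two` nested under ExplodeExpr
    // is not registered, so no group grows beyond a single occurrence.
    val two = AddExpr(LeafExpr(1), LeafExpr(1))
    val tree = AddExpr(two, ExplodeExpr(two))
    val eq = new Equivalences
    eq.addExprTree(tree)
    assert(eq.allGroups.count(_.size > 1) == 0)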
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 25cf210c4b..db17ba7c84 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -100,8 +100,8 @@ abstract class Expression extends TreeNode[Expression] {
ExprCode(code, subExprState.isNull, subExprState.value)
}.getOrElse {
val isNull = ctx.freshName("isNull")
- val primitive = ctx.freshName("primitive")
- val ve = ExprCode("", isNull, primitive)
+ val value = ctx.freshName("value")
+ val ve = ExprCode("", isNull, value)
ve.code = genCode(ctx, ve)
// Add `this` in the comment.
ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 683029ff14..2747c315ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -125,7 +125,7 @@ class CodegenContext {
val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState]
// The collection of sub-expression result resetting methods that need to be called on each row.
- val subExprResetVariables = mutable.ArrayBuffer.empty[String]
+ val subexprFunctions = mutable.ArrayBuffer.empty[String]
def declareAddedFunctions(): String = {
addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n")
@@ -424,9 +424,9 @@ class CodegenContext {
val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
commonExprs.foreach(e => {
val expr = e.head
- val isNull = freshName("isNull")
- val value = freshName("value")
val fnName = freshName("evalExpr")
+ val isNull = s"${fnName}IsNull"
+ val value = s"${fnName}Value"
// Generate the code for this expression tree and wrap it in a function.
val code = expr.gen(this)
@@ -461,7 +461,7 @@ class CodegenContext {
addMutableState(javaType(expr.dataType), value,
s"$value = ${defaultValue(expr.dataType)};")
- subExprResetVariables += s"$fnName($INPUT_ROW);"
+ subexprFunctions += s"$fnName($INPUT_ROW);"
val state = SubExprEliminationState(isNull, value)
e.foreach(subExprEliminationExprs.put(_, state))
})
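The two renames above reflect a change in meaning: the per-row snippets are no longer just "reset variables" but calls to generated functions that actually evaluate each shared subexpression, and the isNull/value fields are now derived from the function name, which ties each field to the function that fills it. A rough, hedged sketch of one such generated function, written in the Java-in-a-Scala-string style the code generator itself uses (the body is illustrative, not real generator output):

    val fnName = "evalExpr3"          // produced by freshName("evalExpr")
    val generatedSketch =
      s"""
         |private void $fnName(InternalRow i) {
         |  // evaluate the shared subexpression exactly once for this row
         |  ${fnName}IsNull = false;
         |  ${fnName}Value = 42;
         |}
       """.stripMargin
    // ctx.subexprFunctions then collects calls like "evalExpr3(i);", which the
    // projection replays once per row before any per-column code runs.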
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index 59ef0f5836..d9fe76133c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -38,12 +38,29 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
in.map(BindReferences.bindReference(_, inputSchema))
+ def generate(
+ expressions: Seq[Expression],
+ inputSchema: Seq[Attribute],
+ useSubexprElimination: Boolean): (() => MutableProjection) = {
+ create(canonicalize(bind(expressions, inputSchema)), useSubexprElimination)
+ }
+
protected def create(expressions: Seq[Expression]): (() => MutableProjection) = {
+ create(expressions, false)
+ }
+
+ private def create(
+ expressions: Seq[Expression],
+ useSubexprElimination: Boolean): (() => MutableProjection) = {
val ctx = newCodeGenContext()
- val projectionCodes = expressions.zipWithIndex.map {
- case (NoOp, _) => ""
- case (e, i) =>
- val evaluationCode = e.gen(ctx)
+ val (validExpr, index) = expressions.zipWithIndex.filter {
+ case (NoOp, _) => false
+ case _ => true
+ }.unzip
+ val exprVals = ctx.generateExpressions(validExpr, useSubexprElimination)
+ val projectionCodes = exprVals.zip(index).map {
+ case (ev, i) =>
+ val e = expressions(i)
if (e.nullable) {
val isNull = s"isNull_$i"
val value = s"value_$i"
@@ -51,22 +68,25 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
ctx.addMutableState(ctx.javaType(e.dataType), value,
s"this.$value = ${ctx.defaultValue(e.dataType)};")
s"""
- ${evaluationCode.code}
- this.$isNull = ${evaluationCode.isNull};
- this.$value = ${evaluationCode.value};
+ ${ev.code}
+ this.$isNull = ${ev.isNull};
+ this.$value = ${ev.value};
"""
} else {
val value = s"value_$i"
ctx.addMutableState(ctx.javaType(e.dataType), value,
s"this.$value = ${ctx.defaultValue(e.dataType)};")
s"""
- ${evaluationCode.code}
- this.$value = ${evaluationCode.value};
+ ${ev.code}
+ this.$value = ${ev.value};
"""
}
}
- val updates = expressions.zipWithIndex.map {
- case (NoOp, _) => ""
+
+ // Evaluate all the subexpressions.
+ val evalSubexpr = ctx.subexprFunctions.mkString("\n")
+
+ val updates = validExpr.zip(index).map {
case (e, i) =>
if (e.nullable) {
if (e.dataType.isInstanceOf[DecimalType]) {
@@ -128,6 +148,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
public java.lang.Object apply(java.lang.Object _i) {
InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i;
+ $evalSubexpr
$allProjections
// copy all the results into MutableRow
$allUpdates
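Taken together, the new create() builds the projection in three steps: NoOp slots (placeholder expressions that aggregation fills in imperatively) are filtered out before any code is generated, while zipWithIndex preserves the original column indices; the surviving expressions are generated with optional subexpression elimination; and the calls collected in ctx.subexprFunctions are emitted once at the top of apply(). A hedged sketch of the apply() body this produces, with the template variables spelled out as comments (illustrative only, not real generated output):

    val applySketch =
      """
        |public java.lang.Object apply(java.lang.Object _i) {
        |  InternalRow i = (InternalRow) _i;
        |  evalExpr3(i);   // $evalSubexpr: every shared subexpression runs exactly once
        |  /* $allProjections: per-column evaluation code, reading the cached
        |     evalExpr*IsNull / evalExpr*Value fields where a subexpression repeats */
        |  /* $allUpdates: copy all the results into the MutableRow */
        |  return mutableRow;
        |}
      """.stripMargin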
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 61e7469ee4..72bf39a039 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -294,13 +294,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val holderClass = classOf[BufferHolder].getName
ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new $holderClass();")
- // Reset the subexpression values for each row.
- val subexprReset = ctx.subExprResetVariables.mkString("\n")
+ // Evaluate all the subexpressions.
+ val evalSubexpr = ctx.subexprFunctions.mkString("\n")
val code =
s"""
$bufferHolder.reset();
- $subexprReset
+ $evalSubexpr
${writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, bufferHolder)}
$result.pointTo($bufferHolder.buffer, $bufferHolder.totalSize());
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
index a61297b2c0..43a3eb9dec 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
@@ -154,4 +154,17 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
equivalence.addExpr(sum)
assert(equivalence.getAllEquivalentExprs.isEmpty)
}
+
+ test("Children of CodegenFallback") {
+ val one = Literal(1)
+ val two = Add(one, one)
+ val explode = Explode(two)
+ val add = Add(two, explode)
+
+ var equivalence = new EquivalentExpressions
+ equivalence.addExprTree(add, true)
+ // the `two` inside `explode` should not be added
+ assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).size == 0)
+ assert(equivalence.getAllEquivalentExprs.filter(_.size == 1).size == 3) // add, two, explode
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 75101ea0fc..b19b772409 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -196,10 +196,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
private[this] def isTesting: Boolean = sys.props.contains("spark.testing")
protected def newMutableProjection(
- expressions: Seq[Expression], inputSchema: Seq[Attribute]): () => MutableProjection = {
+ expressions: Seq[Expression],
+ inputSchema: Seq[Attribute],
+ useSubexprElimination: Boolean = false): () => MutableProjection = {
log.debug(s"Creating MutableProj: $expressions, inputSchema: $inputSchema")
try {
- GenerateMutableProjection.generate(expressions, inputSchema)
+ GenerateMutableProjection.generate(expressions, inputSchema, useSubexprElimination)
} catch {
case e: Exception =>
if (isTesting) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
index 168b5ab031..26a7340f1a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -194,7 +194,11 @@ case class Window(
val functions = functionSeq.toArray
// Construct an aggregate processor if we need one.
- def processor = AggregateProcessor(functions, ordinal, child.output, newMutableProjection)
+ def processor = AggregateProcessor(
+ functions,
+ ordinal,
+ child.output,
+ (expressions, schema) => newMutableProjection(expressions, schema))
// Create the factory
val factory = key match {
@@ -206,7 +210,7 @@ case class Window(
ordinal,
functions,
child.output,
- newMutableProjection,
+ (expressions, schema) => newMutableProjection(expressions, schema),
offset)
// Growing Frame.
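Why Window now passes explicit lambdas instead of the bare newMutableProjection reference: once the method gains a third parameter with a default value, it no longer eta-expands to the two-argument function type the aggregate processor and frame factories expect, because default arguments are only filled in at call sites. A minimal sketch of the mismatch, with stand-in types:

    // Stand-ins: String for Expression/Attribute, Unit for MutableProjection.
    def newMutableProjection(
        expressions: Seq[String],
        inputSchema: Seq[String],
        useSubexprElimination: Boolean = false): () => Unit = () => ()

    type Builder = (Seq[String], Seq[String]) => (() => Unit)

    // val broken: Builder = newMutableProjection   // does not compile: the default
    //                                              // is not applied during eta-expansion
    val adapter: Builder = (exprs, schema) => newMutableProjection(exprs, schema)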
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
index 1d56592c40..06a3991459 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
@@ -87,7 +87,8 @@ case class SortBasedAggregate(
aggregateAttributes,
initialInputBufferOffset,
resultExpressions,
- newMutableProjection,
+ (expressions, inputSchema) =>
+ newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled),
numInputRows,
numOutputRows)
if (!hasInput && groupingExpressions.isEmpty) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index a9cf04388d..8dcbab4c8c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -94,7 +94,8 @@ case class TungstenAggregate(
aggregateAttributes,
initialInputBufferOffset,
resultExpressions,
- newMutableProjection,
+ (expressions, inputSchema) =>
+ newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled),
child.output,
iter,
testFallbackStartsAt,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index d7f182352b..b159346bed 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.parser.ParserConf
import org.apache.spark.sql.execution.{aggregate, SparkQl}
import org.apache.spark.sql.execution.joins.{CartesianProduct, SortMergeJoin}
+import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext}
import org.apache.spark.sql.test.SQLTestData._
@@ -1968,6 +1969,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
verifyCallCount(
df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2)
+ val testUdf = functions.udf((x: Int) => {
+ countAcc.++=(1)
+ x
+ })
+ verifyCallCount(
+ df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1)
+
// Would be nice if semantic equals for `+` understood commutative
verifyCallCount(
df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2)
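The verifyCallCount helper used throughout this test is defined earlier in the suite and is not part of the diff; a plausible reconstruction of its shape, for readers following along (hypothetical code, assuming the suite's SparkContext `sc` and using a simplified stand-in for QueryTest's checkAnswer):

    import org.apache.spark.sql.{DataFrame, Row}

    // Hypothetical reconstruction, not the suite's actual helper.
    val countAcc = sc.accumulator(0, "CallCount")

    def checkAnswer(df: DataFrame, expected: Row): Unit =
      assert(df.collect().toSeq == Seq(expected))

    def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = {
      countAcc.setValue(0)          // reset before running the query
      checkAnswer(df, expectedResult)
      assert(countAcc.value == expectedCount)
    }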