aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLiang-Chi Hsieh <simonh@tw.ibm.com>2016-05-04 10:54:51 -0700
committerDavies Liu <davies.liu@gmail.com>2016-05-04 10:54:51 -0700
commitb85d21fb9dc3d498d9a10e065d254abde797efb6 (patch)
tree50cf5f6c1b9e6252e010eec1bd7063bcbd24cbec
parentd864c55cf8c92466336e796d0c98d83230e330af (diff)
downloadspark-b85d21fb9dc3d498d9a10e065d254abde797efb6.tar.gz
spark-b85d21fb9dc3d498d9a10e065d254abde797efb6.tar.bz2
spark-b85d21fb9dc3d498d9a10e065d254abde797efb6.zip
[SPARK-14951] [SQL] Support subexpression elimination in TungstenAggregate
## What changes were proposed in this pull request? We can support subexpression elimination in TungstenAggregate by using current `EquivalentExpressions` which is already used in subexpression elimination for expression codegen. However, in wholestage codegen, we can't wrap the common expression's codes in functions as before, we simply generate the code snippets for common expressions. These code snippets are inserted before the common expressions are actually used in generated java codes. For multiple `TypedAggregateExpression` used in aggregation operator, since their input type should be the same. So their `inputDeserializer` will be the same too. This patch can also reduce redundant input deserialization. ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh <simonh@tw.ibm.com> Closes #12729 from viirya/subexpr-elimination-tungstenaggregate.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala7
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala74
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala31
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala38
4 files changed, 109 insertions, 41 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
index d0ad7a05a0..b8e2b67b2f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -68,7 +68,10 @@ class EquivalentExpressions {
* is found. That is, if `expr` has already been added, its children are not added.
* If ignoreLeaf is true, leaf nodes are ignored.
*/
- def addExprTree(root: Expression, ignoreLeaf: Boolean = true): Unit = {
+ def addExprTree(
+ root: Expression,
+ ignoreLeaf: Boolean = true,
+ skipReferenceToExpressions: Boolean = true): Unit = {
val skip = root.isInstanceOf[LeafExpression] && ignoreLeaf
// There are some special expressions that we should not recurse into children.
// 1. CodegenFallback: it's children will not be used to generate code (call eval() instead)
@@ -77,7 +80,7 @@ class EquivalentExpressions {
// TODO: some expressions implements `CodegenFallback` but can still do codegen,
// e.g. `CaseWhen`, we should support them.
case _: CodegenFallback => false
- case _: ReferenceToExpressions => false
+ case _: ReferenceToExpressions if skipReferenceToExpressions => false
case _ => true
}
if (!skip && !addExpr(root) && shouldRecurse) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index e4fa429b37..67f6719265 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -47,6 +47,25 @@ import org.apache.spark.util.Utils
case class ExprCode(var code: String, var isNull: String, var value: String)
/**
+ * State used for subexpression elimination.
+ *
+ * @param isNull A term that holds a boolean value representing whether the expression evaluated
+ * to null.
+ * @param value A term for a value of a common sub-expression. Not valid if `isNull`
+ * is set to `true`.
+ */
+case class SubExprEliminationState(isNull: String, value: String)
+
+/**
+ * Codes and common subexpressions mapping used for subexpression elimination.
+ *
+ * @param codes Strings representing the codes that evaluate common subexpressions.
+ * @param states Foreach expression that is participating in subexpression elimination,
+ * the state to use.
+ */
+case class SubExprCodes(codes: Seq[String], states: Map[Expression, SubExprEliminationState])
+
+/**
* A context for codegen, tracking a list of objects that could be passed into generated Java
* function.
*/
@@ -148,9 +167,6 @@ class CodegenContext {
*/
val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions
- // State used for subexpression elimination.
- case class SubExprEliminationState(isNull: String, value: String)
-
// Foreach expression that is participating in subexpression elimination, the state to use.
val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState]
@@ -572,6 +588,58 @@ class CodegenContext {
}
/**
+ * Perform a function which generates a sequence of ExprCodes with a given mapping between
+ * expressions and common expressions, instead of using the mapping in current context.
+ */
+ def withSubExprEliminationExprs(
+ newSubExprEliminationExprs: Map[Expression, SubExprEliminationState])(
+ f: => Seq[ExprCode]): Seq[ExprCode] = {
+ val oldsubExprEliminationExprs = subExprEliminationExprs
+ subExprEliminationExprs.clear
+ newSubExprEliminationExprs.foreach(subExprEliminationExprs += _)
+
+ val genCodes = f
+
+ // Restore previous subExprEliminationExprs
+ subExprEliminationExprs.clear
+ oldsubExprEliminationExprs.foreach(subExprEliminationExprs += _)
+ genCodes
+ }
+
+ /**
+ * Checks and sets up the state and codegen for subexpression elimination. This finds the
+ * common subexpressions, generates the code snippets that evaluate those expressions and
+ * populates the mapping of common subexpressions to the generated code snippets. The generated
+ * code snippets will be returned and should be inserted into generated codes before these
+ * common subexpressions actually are used first time.
+ */
+ def subexpressionEliminationForWholeStageCodegen(expressions: Seq[Expression]): SubExprCodes = {
+ // Create a clear EquivalentExpressions and SubExprEliminationState mapping
+ val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions
+ val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState]
+
+ // Add each expression tree and compute the common subexpressions.
+ expressions.foreach(equivalentExpressions.addExprTree(_, true, false))
+
+ // Get all the expressions that appear at least twice and set up the state for subexpression
+ // elimination.
+ val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
+ val codes = commonExprs.map { e =>
+ val expr = e.head
+ val fnName = freshName("evalExpr")
+ val isNull = s"${fnName}IsNull"
+ val value = s"${fnName}Value"
+
+ // Generate the code for this expression tree.
+ val code = expr.genCode(this)
+ val state = SubExprEliminationState(code.isNull, code.value)
+ e.foreach(subExprEliminationExprs.put(_, state))
+ code.code.trim
+ }
+ SubExprCodes(codes, subExprEliminationExprs.toMap)
+ }
+
+ /**
* Checks and sets up the state and codegen for subexpression elimination. This finds the
* common subexpressions, generates the functions that evaluate those expressions and populates
* the mapping of common subexpressions to the generated functions.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index d0ba37ee13..d2dc80a7e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -244,8 +244,12 @@ case class TungstenAggregate(
}
}
ctx.currentVars = bufVars ++ input
- // TODO: support subexpression elimination
- val aggVals = updateExpr.map(BindReferences.bindReference(_, inputAttrs).genCode(ctx))
+ val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs))
+ val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
+ val effectiveCodes = subExprs.codes.mkString("\n")
+ val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) {
+ boundUpdateExpr.map(_.genCode(ctx))
+ }
// aggregate buffer should be updated atomic
val updates = aggVals.zipWithIndex.map { case (ev, i) =>
s"""
@@ -255,6 +259,9 @@ case class TungstenAggregate(
}
s"""
| // do aggregate
+ | // common sub-expressions
+ | $effectiveCodes
+ | // evaluate aggregate function
| ${evaluateVariables(aggVals)}
| // update aggregation buffer
| ${updates.mkString("\n").trim}
@@ -650,8 +657,12 @@ case class TungstenAggregate(
val updateRowInVectorizedHashMap: Option[String] = {
if (isVectorizedHashMapEnabled) {
ctx.INPUT_ROW = vectorizedRowBuffer
- val vectorizedRowEvals =
- updateExpr.map(BindReferences.bindReference(_, inputAttr).genCode(ctx))
+ val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
+ val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
+ val effectiveCodes = subExprs.codes.mkString("\n")
+ val vectorizedRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
+ boundUpdateExpr.map(_.genCode(ctx))
+ }
val updateVectorizedRow = vectorizedRowEvals.zipWithIndex.map { case (ev, i) =>
val dt = updateExpr(i).dataType
ctx.updateColumn(vectorizedRowBuffer, dt, i, ev, updateExpr(i).nullable,
@@ -659,6 +670,8 @@ case class TungstenAggregate(
}
Option(
s"""
+ |// common sub-expressions
+ |$effectiveCodes
|// evaluate aggregate function
|${evaluateVariables(vectorizedRowEvals)}
|// update vectorized row
@@ -701,13 +714,19 @@ case class TungstenAggregate(
val updateRowInUnsafeRowMap: String = {
ctx.INPUT_ROW = unsafeRowBuffer
- val unsafeRowBufferEvals =
- updateExpr.map(BindReferences.bindReference(_, inputAttr).genCode(ctx))
+ val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
+ val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
+ val effectiveCodes = subExprs.codes.mkString("\n")
+ val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
+ boundUpdateExpr.map(_.genCode(ctx))
+ }
val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) =>
val dt = updateExpr(i).dataType
ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable)
}
s"""
+ |// common sub-expressions
+ |$effectiveCodes
|// evaluate aggregate function
|${evaluateVariables(unsafeRowBufferEvals)}
|// update unsafe row buffer
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index 535e64cb34..edca816cb1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -31,31 +31,9 @@ object TypedAggregateExpression {
def apply[BUF : Encoder, OUT : Encoder](
aggregator: Aggregator[_, BUF, OUT]): TypedAggregateExpression = {
val bufferEncoder = encoderFor[BUF]
- // We will insert the deserializer and function call expression at the bottom of each serializer
- // expression while executing `TypedAggregateExpression`, which means multiply serializer
- // expressions will all evaluate the same sub-expression at bottom. To avoid the re-evaluating,
- // here we always use one single serializer expression to serialize the buffer object into a
- // single-field row, no matter whether the encoder is flat or not. We also need to update the
- // deserializer to read in all fields from that single-field row.
- // TODO: remove this trick after we have better integration of subexpression elimination and
- // whole stage codegen.
- val bufferSerializer = if (bufferEncoder.flat) {
- bufferEncoder.namedExpressions.head
- } else {
- Alias(CreateStruct(bufferEncoder.serializer), "buffer")()
- }
-
- val bufferDeserializer = if (bufferEncoder.flat) {
- bufferEncoder.deserializer transformUp {
- case b: BoundReference => bufferSerializer.toAttribute
- }
- } else {
- bufferEncoder.deserializer transformUp {
- case UnresolvedAttribute(nameParts) =>
- assert(nameParts.length == 1)
- UnresolvedExtractValue(bufferSerializer.toAttribute, Literal(nameParts.head))
- case BoundReference(ordinal, dt, _) => GetStructField(bufferSerializer.toAttribute, ordinal)
- }
+ val bufferSerializer = bufferEncoder.namedExpressions
+ val bufferDeserializer = bufferEncoder.deserializer.transform {
+ case b: BoundReference => bufferSerializer(b.ordinal).toAttribute
}
val outputEncoder = encoderFor[OUT]
@@ -82,7 +60,7 @@ object TypedAggregateExpression {
case class TypedAggregateExpression(
aggregator: Aggregator[Any, Any, Any],
inputDeserializer: Option[Expression],
- bufferSerializer: NamedExpression,
+ bufferSerializer: Seq[NamedExpression],
bufferDeserializer: Expression,
outputSerializer: Seq[Expression],
outputExternalType: DataType,
@@ -106,11 +84,11 @@ case class TypedAggregateExpression(
private def bufferExternalType = bufferDeserializer.dataType
override lazy val aggBufferAttributes: Seq[AttributeReference] =
- bufferSerializer.toAttribute.asInstanceOf[AttributeReference] :: Nil
+ bufferSerializer.map(_.toAttribute.asInstanceOf[AttributeReference])
override lazy val initialValues: Seq[Expression] = {
val zero = Literal.fromObject(aggregator.zero, bufferExternalType)
- ReferenceToExpressions(bufferSerializer, zero :: Nil) :: Nil
+ bufferSerializer.map(ReferenceToExpressions(_, zero :: Nil))
}
override lazy val updateExpressions: Seq[Expression] = {
@@ -120,7 +98,7 @@ case class TypedAggregateExpression(
bufferExternalType,
bufferDeserializer :: inputDeserializer.get :: Nil)
- ReferenceToExpressions(bufferSerializer, reduced :: Nil) :: Nil
+ bufferSerializer.map(ReferenceToExpressions(_, reduced :: Nil))
}
override lazy val mergeExpressions: Seq[Expression] = {
@@ -136,7 +114,7 @@ case class TypedAggregateExpression(
bufferExternalType,
leftBuffer :: rightBuffer :: Nil)
- ReferenceToExpressions(bufferSerializer, merged :: Nil) :: Nil
+ bufferSerializer.map(ReferenceToExpressions(_, merged :: Nil))
}
override lazy val evaluateExpression: Expression = {