author     Michael Armbrust <michael@databricks.com>    2016-04-01 15:15:16 -0700
committer  Michael Armbrust <michael@databricks.com>    2016-04-01 15:15:16 -0700
commit     0fc4aaa71c5b4531b3a7c8ac71d62ea8e66b6f0c (patch)
tree       5c72eb22fb2ef033a6d08f989dcf4fa18d66a84f /sql/catalyst
parent     0b7d4966ca7e02f351c4b92a74789cef4799fcb1 (diff)
[SPARK-14255][SQL] Streaming Aggregation
This PR adds the ability to perform aggregations inside of a `ContinuousQuery`. To implement this feature, the planning of aggregation has been augmented with a new `StatefulAggregationStrategy`. Unlike batch aggregation, stateful aggregation uses the `StateStore` (introduced in #11645) to persist the results of partial aggregation across different invocations. The resulting physical plan performs the aggregation using the following progression (a simplified sketch follows the list):
- Partial Aggregation
- Shuffle
- Partial Merge (now there is at most 1 tuple per group)
- StateStoreRestore (now there is 1 tuple from this batch + optionally one from the previous)
- Partial Merge (now there is at most 1 tuple per group)
- StateStoreSave (saves the tuple for the next batch)
- Complete (output the current result of the aggregation)
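As a rough illustration, the following self-contained Scala sketch mimics this progression for a simple count, with a plain in-memory `Map` standing in for the `StateStore`. Every name in it is illustrative only, not a Spark API:

```scala
// Sketch of the stateful-aggregation progression for a count. The mutable Map is a
// stand-in for the StateStore; a real plan runs these phases as physical operators.
object StatefulCountSketch {
  private var stateStore = Map.empty[String, Long] // stand-in for StateStore

  def runBatch(batch: Seq[String]): Map[String, Long] = {
    // Partial Aggregation (per partition; collapsed to a single partition here)
    val partial = batch.groupBy(identity).map { case (g, rows) => g -> rows.size.toLong }
    // Shuffle + Partial Merge would leave at most one tuple per group; already true here.
    // StateStoreRestore + Partial Merge: combine this batch with the previous batch's tuple.
    val merged = (partial.keySet ++ stateStore.keySet).map { g =>
      g -> (partial.getOrElse(g, 0L) + stateStore.getOrElse(g, 0L))
    }.toMap
    // StateStoreSave: persist the merged tuples for the next batch.
    stateStore = merged
    // Complete: output the current result of the aggregation.
    merged
  }
}

// StatefulCountSketch.runBatch(Seq("a", "b", "a"))  => Map(a -> 2, b -> 1)
// StatefulCountSketch.runBatch(Seq("b"))            => Map(a -> 2, b -> 2)
```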
The following refactoring was also performed to allow us to plug into existing code:
- The get/put implementation is taken from #12013
- The logic for breaking down and de-duping the physical execution of aggregation has been moved into a new pattern, `PhysicalAggregation` (a usage sketch follows this list)
- The `AttributeReference` used to identify the result of an `AggregateFunction` has been moved into the `AggregateExpression` container. This change moves the reference into the same object as the other intermediate references used in aggregation and eliminates the need to pass around a `Map[(AggregateFunction, Boolean), Attribute]`. Further cleanup (using a different aggregation container for logical/physical plans) is deferred to a followup.
- Some planning logic is moved from the `SessionState` into the `QueryExecution` to make it easier to override in the streaming case.
- The ability to write a `StreamTest` that checks only the output of the last batch has been added to simulate the future addition of output modes.
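To illustrate the last point, such a test looks roughly like the sketch below. It assumes the `MemoryStream`/`testStream` harness from Spark's test sources together with the new `CheckLastBatch` action, so treat the exact shapes as approximate:

```scala
// Inside a suite extending StreamTest: CheckLastBatch verifies only the output
// of the most recent batch, rather than all rows ever produced.
val inputData = MemoryStream[Int]
val aggregated = inputData.toDF().groupBy($"value").agg(count("*"))

testStream(aggregated)(
  AddData(inputData, 3),
  CheckLastBatch((3, 1)), // first batch: value 3 seen once
  AddData(inputData, 3),
  CheckLastBatch((3, 2))  // state restored across batches: count is now 2
)
```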
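Similarly, for the `PhysicalAggregation` pattern, here is a minimal sketch of how a planner strategy might consume it. The `describeAggregation` helper is hypothetical; the extractor's return shape is taken from the patterns.scala diff below:

```scala
import org.apache.spark.sql.catalyst.planning.PhysicalAggregation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// Hypothetical helper: the extractor yields named grouping expressions, deduplicated
// AggregateExpressions, and result expressions rewritten to use each resultAttribute.
def describeAggregation(plan: LogicalPlan): Option[String] = plan match {
  case PhysicalAggregation(grouping, aggregates, results, child) =>
    Some(s"${aggregates.size} aggregate(s) over ${grouping.size} grouping key(s), " +
      s"projecting ${results.size} column(s) from ${child.nodeName}")
  case _ => None
}
```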
Author: Michael Armbrust <michael@databricks.com>
Closes #12048 from marmbrus/statefulAgg.
Diffstat (limited to 'sql/catalyst')
8 files changed, 133 insertions, 14 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index d82ee3a205..05e2b9a447 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -336,6 +336,11 @@ class Analyzer(
             Last(ifExpr(expr), Literal(true))
           case a: AggregateFunction =>
             a.withNewChildren(a.children.map(ifExpr))
+        }.transform {
+          // We are duplicating aggregates that are now computing a different value for each
+          // pivot value.
+          // TODO: Don't construct the physical container until after analysis.
+          case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId)
         }
         if (filteredAggregate.fastEquals(aggregate)) {
           throw new AnalysisException(
@@ -1153,11 +1158,11 @@ class Analyzer(
 
           // Extract Windowed AggregateExpression
           case we @ WindowExpression(
-              AggregateExpression(function, mode, isDistinct),
+              ae @ AggregateExpression(function, _, _, _),
              spec: WindowSpecDefinition) =>
             val newChildren = function.children.map(extractExpr)
             val newFunction = function.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
-            val newAgg = AggregateExpression(newFunction, mode, isDistinct)
+            val newAgg = ae.copy(aggregateFunction = newFunction)
             seenWindowAggregates += newAgg
             WindowExpression(newAgg, spec)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 1d1e892e32..4880502398 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -76,7 +76,7 @@ trait CheckAnalysis {
           case g: GroupingID =>
             failAnalysis(s"grouping_id() can only be used with GroupingSets/Cube/Rollup")
 
-          case w @ WindowExpression(AggregateExpression(_, _, true), _) =>
+          case w @ WindowExpression(AggregateExpression(_, _, true, _), _) =>
             failAnalysis(s"Distinct window functions are not supported: $w")
 
           case w @ WindowExpression(_: OffsetWindowFunction, WindowSpecDefinition(_, order,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala
index 0d44d1dd96..0420b4b538 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala
@@ -25,15 +25,18 @@ import org.apache.spark.sql.catalyst.trees.TreeNode
 package object errors {
 
   class TreeNodeException[TreeType <: TreeNode[_]](
-      tree: TreeType, msg: String, cause: Throwable)
+      @transient val tree: TreeType,
+      msg: String,
+      cause: Throwable)
     extends Exception(msg, cause) {
 
+    val treeString = tree.toString
+
     // Yes, this is the same as a default parameter, but... those don't seem to work with SBT
     // external project dependencies for some reason.
     def this(tree: TreeType, msg: String) = this(tree, msg, null)
 
     override def getMessage: String = {
-      val treeString = tree.toString
       s"${super.getMessage}, tree:${if (treeString contains "\n") "\n" else " "}$tree"
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index ff3064ac66..d31ccf9985 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.types._
@@ -66,6 +67,19 @@ private[sql] case object NoOp extends Expression with Unevaluable {
   override def children: Seq[Expression] = Nil
 }
 
+object AggregateExpression {
+  def apply(
+      aggregateFunction: AggregateFunction,
+      mode: AggregateMode,
+      isDistinct: Boolean): AggregateExpression = {
+    AggregateExpression(
+      aggregateFunction,
+      mode,
+      isDistinct,
+      NamedExpression.newExprId)
+  }
+}
+
 /**
  * A container for an [[AggregateFunction]] with its [[AggregateMode]] and a field
  * (`isDistinct`) indicating if DISTINCT keyword is specified for this function.
@@ -73,10 +87,31 @@ private[sql] case object NoOp extends Expression with Unevaluable {
 private[sql] case class AggregateExpression(
     aggregateFunction: AggregateFunction,
     mode: AggregateMode,
-    isDistinct: Boolean)
+    isDistinct: Boolean,
+    resultId: ExprId)
   extends Expression
   with Unevaluable {
 
+  lazy val resultAttribute: Attribute = if (aggregateFunction.resolved) {
+    AttributeReference(
+      aggregateFunction.toString,
+      aggregateFunction.dataType,
+      aggregateFunction.nullable)(exprId = resultId)
+  } else {
+    // This is a bit of a hack. Really we should not be constructing this container and reasoning
+    // about datatypes / aggregation mode until after we have finished analysis and made it to
+    // planning.
+    UnresolvedAttribute(aggregateFunction.toString)
+  }
+
+  // We compute the same thing regardless of our final result.
+  override lazy val canonicalized: Expression =
+    AggregateExpression(
+      aggregateFunction.canonicalized.asInstanceOf[AggregateFunction],
+      mode,
+      isDistinct,
+      ExprId(0))
+
   override def children: Seq[Expression] = aggregateFunction :: Nil
   override def dataType: DataType = aggregateFunction.dataType
   override def foldable: Boolean = false
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 262582ca5d..2307122ea1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -329,7 +329,7 @@ case class PrettyAttribute(
   override def withName(newName: String): Attribute = throw new UnsupportedOperationException
   override def qualifier: Option[String] = throw new UnsupportedOperationException
   override def exprId: ExprId = throw new UnsupportedOperationException
-  override def nullable: Boolean = throw new UnsupportedOperationException
+  override def nullable: Boolean = true
 }
 
 object VirtualColumn {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index a7a948ef1b..326933ec9e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -534,7 +534,7 @@ object NullPropagation extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case q: LogicalPlan => q transformExpressionsUp {
-      case e @ AggregateExpression(Count(exprs), _, _) if !exprs.exists(nonNullLiteral) =>
+      case e @ AggregateExpression(Count(exprs), _, _, _) if !exprs.exists(nonNullLiteral) =>
         Cast(Literal(0L), e.dataType)
       case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
       case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
@@ -547,9 +547,9 @@ object NullPropagation extends Rule[LogicalPlan] {
         Literal.create(null, e.dataType)
       case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
       case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
-      case e @ AggregateExpression(Count(exprs), mode, false) if !exprs.exists(_.nullable) =>
+      case ae @ AggregateExpression(Count(exprs), _, false, _) if !exprs.exists(_.nullable) =>
         // This rule should be only triggered when isDistinct field is false.
-        AggregateExpression(Count(Literal(1)), mode, isDistinct = false)
+        ae.copy(aggregateFunction = Count(Literal(1)))
 
       // For Coalesce, remove null literals.
       case e @ Coalesce(children) =>
@@ -1225,13 +1225,13 @@ object DecimalAggregates extends Rule[LogicalPlan] {
   private val MAX_DOUBLE_DIGITS = 15
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
-    case AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), mode, isDistinct)
+    case ae @ AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), _, _, _)
         if prec + 10 <= MAX_LONG_DIGITS =>
-      MakeDecimal(AggregateExpression(Sum(UnscaledValue(e)), mode, isDistinct), prec + 10, scale)
+      MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale)
 
-    case AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), mode, isDistinct)
+    case ae @ AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), _, _, _)
         if prec + 4 <= MAX_DOUBLE_DIGITS =>
-      val newAggExpr = AggregateExpression(Average(UnscaledValue(e)), mode, isDistinct)
+      val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e)))
       Cast(
         Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
         DecimalType(prec + 4, scale + 4))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 9c927077d0..28d2c445b1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types.IntegerType
@@ -216,3 +217,75 @@ object IntegerIndex {
     case _ => None
   }
 }
+
+/**
+ * An extractor used when planning the physical execution of an aggregation. Compared with a
+ * logical aggregation, the following transformations are performed:
+ *  - Unnamed grouping expressions are named so that they can be referred to across phases of
+ *    aggregation
+ *  - Aggregations that appear multiple times are deduplicated.
+ *  - The compution of the aggregations themselves is separated from the final result. For
+ *    example, the `count` in `count + 1` will be split into an [[AggregateExpression]] and a final
+ *    computation that computes `count.resultAttribute + 1`.
+ */
+object PhysicalAggregation {
+  // groupingExpressions, aggregateExpressions, resultExpressions, child
+  type ReturnType =
+    (Seq[NamedExpression], Seq[AggregateExpression], Seq[NamedExpression], LogicalPlan)
+
+  def unapply(a: Any): Option[ReturnType] = a match {
+    case logical.Aggregate(groupingExpressions, resultExpressions, child) =>
+      // A single aggregate expression might appear multiple times in resultExpressions.
+      // In order to avoid evaluating an individual aggregate function multiple times, we'll
+      // build a set of the distinct aggregate expressions and build a function which can
+      // be used to re-write expressions so that they reference the single copy of the
+      // aggregate function which actually gets computed.
+      val aggregateExpressions = resultExpressions.flatMap { expr =>
+        expr.collect {
+          case agg: AggregateExpression => agg
+        }
+      }.distinct
+
+      val namedGroupingExpressions = groupingExpressions.map {
+        case ne: NamedExpression => ne -> ne
+        // If the expression is not a NamedExpressions, we add an alias.
+        // So, when we generate the result of the operator, the Aggregate Operator
+        // can directly get the Seq of attributes representing the grouping expressions.
+        case other =>
+          val withAlias = Alias(other, other.toString)()
+          other -> withAlias
+      }
+      val groupExpressionMap = namedGroupingExpressions.toMap
+
+      // The original `resultExpressions` are a set of expressions which may reference
+      // aggregate expressions, grouping column values, and constants. When aggregate operator
+      // emits output rows, we will use `resultExpressions` to generate an output projection
+      // which takes the grouping columns and final aggregate result buffer as input.
+      // Thus, we must re-write the result expressions so that their attributes match up with
+      // the attributes of the final result projection's input row:
+      val rewrittenResultExpressions = resultExpressions.map { expr =>
+        expr.transformDown {
+          case ae: AggregateExpression =>
+            // The final aggregation buffer's attributes will be `finalAggregationAttributes`,
+            // so replace each aggregate expression by its corresponding attribute in the set:
+            ae.resultAttribute
+          case expression =>
+            // Since we're using `namedGroupingAttributes` to extract the grouping key
+            // columns, we need to replace grouping key expressions with their corresponding
+            // attributes. We do not rely on the equality check at here since attributes may
+            // differ cosmetically. Instead, we use semanticEquals.
+            groupExpressionMap.collectFirst {
+              case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+            }.getOrElse(expression)
+        }.asInstanceOf[NamedExpression]
+      }
+
+      Some((
+        namedGroupingExpressions.map(_._2),
+        aggregateExpressions,
+        rewrittenResultExpressions,
+        child))
+
+    case _ => None
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index aa5d4330d3..7191936699 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.plans
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Sample}
 import org.apache.spark.sql.catalyst.util._
 
@@ -38,6 +39,8 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
         AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
       case a: Alias =>
         Alias(a.child, a.name)(exprId = ExprId(0))
+      case ae: AggregateExpression =>
+        ae.copy(resultId = ExprId(0))
     }
   }