path: root/sql/catalyst
author     Michael Armbrust <michael@databricks.com>   2016-04-01 15:15:16 -0700
committer  Michael Armbrust <michael@databricks.com>   2016-04-01 15:15:16 -0700
commit     0fc4aaa71c5b4531b3a7c8ac71d62ea8e66b6f0c (patch)
tree       5c72eb22fb2ef033a6d08f989dcf4fa18d66a84f /sql/catalyst
parent     0b7d4966ca7e02f351c4b92a74789cef4799fcb1 (diff)
[SPARK-14255][SQL] Streaming Aggregation
This PR adds the ability to perform aggregations inside of a `ContinuousQuery`. To implement this feature, the planning of aggregation has been augmented with a new `StatefulAggregationStrategy`. Unlike batch aggregation, stateful aggregation uses the `StateStore` (introduced in #11645) to persist the results of partial aggregation across different invocations. The resulting physical plan performs the aggregation using the following progression:
 - Partial Aggregation
 - Shuffle
 - Partial Merge (now there is at most 1 tuple per group)
 - StateStoreRestore (now there is 1 tuple from this batch + optionally one from the previous)
 - Partial Merge (now there is at most 1 tuple per group)
 - StateStoreSave (saves the tuple for the next batch)
 - Complete (output the current result of the aggregation)

The following refactoring was also performed to allow us to plug into existing code:
 - The get/put implementation is taken from #12013
 - The logic for breaking down and de-duping the physical execution of aggregation has been moved into a new pattern, `PhysicalAggregation`
 - The `AttributeReference` used to identify the result of an `AggregateFunction` has been moved into the `AggregateExpression` container. This change moves the reference into the same object as the other intermediate references used in aggregation and eliminates the need to pass around a `Map[(AggregateFunction, Boolean), Attribute]`. Further cleanup (using a different aggregation container for logical/physical plans) is deferred to a followup.
 - Some planning logic is moved from the `SessionState` into the `QueryExecution` to make it easier to override in the streaming case.
 - The ability to write a `StreamTest` that checks only the output of the last batch has been added to simulate the future addition of output modes.

Author: Michael Armbrust <michael@databricks.com>

Closes #12048 from marmbrus/statefulAgg.
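As a rough illustration of the progression above, here is a minimal, self-contained sketch of how a per-group count flows through the stateful phases across two batches. The types are toys and the mutable `Map` is a stand-in for the `StateStore`; these are not the actual Spark operators:

```scala
import scala.collection.mutable

// Toy model of the stateful aggregation progression. Names are illustrative
// stand-ins for the physical operators described above.
object StatefulCountSketch {
  // Stand-in for the StateStore: (group -> count) persisted across batches.
  private val store = mutable.Map.empty[String, Long]

  def runBatch(batch: Seq[String]): Map[String, Long] = {
    // Partial Aggregation + Shuffle + Partial Merge: at most 1 tuple per group.
    val partial = batch.groupBy(identity).map { case (k, vs) => k -> vs.size.toLong }
    // StateStoreRestore + Partial Merge: combine this batch's tuple with the
    // one saved by the previous batch, if any.
    val merged = partial.map { case (k, c) => k -> (c + store.getOrElse(k, 0L)) }
    // StateStoreSave: persist the merged tuples for the next batch.
    merged.foreach { case (k, c) => store(k) = c }
    // Complete: output the current result of the aggregation.
    merged
  }
}

// StatefulCountSketch.runBatch(Seq("a", "b", "a"))  // Map(a -> 2, b -> 1)
// StatefulCountSketch.runBatch(Seq("a"))            // Map(a -> 3)
```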
Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala                |  9
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala           |  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala                   |  7
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala | 37
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala     |  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala              | 14
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala                | 73
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala                   |  3
8 files changed, 133 insertions, 14 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index d82ee3a205..05e2b9a447 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -336,6 +336,11 @@ class Analyzer(
Last(ifExpr(expr), Literal(true))
case a: AggregateFunction =>
a.withNewChildren(a.children.map(ifExpr))
+ }.transform {
+ // We are duplicating aggregates that are now computing a different value for each
+ // pivot value.
+ // TODO: Don't construct the physical container until after analysis.
+ case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId)
}
if (filteredAggregate.fastEquals(aggregate)) {
throw new AnalysisException(
@@ -1153,11 +1158,11 @@ class Analyzer(
// Extract Windowed AggregateExpression
case we @ WindowExpression(
- AggregateExpression(function, mode, isDistinct),
+ ae @ AggregateExpression(function, _, _, _),
spec: WindowSpecDefinition) =>
val newChildren = function.children.map(extractExpr)
val newFunction = function.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
- val newAgg = AggregateExpression(newFunction, mode, isDistinct)
+ val newAgg = ae.copy(aggregateFunction = newFunction)
seenWindowAggregates += newAgg
WindowExpression(newAgg, spec)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 1d1e892e32..4880502398 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -76,7 +76,7 @@ trait CheckAnalysis {
case g: GroupingID =>
failAnalysis(s"grouping_id() can only be used with GroupingSets/Cube/Rollup")
- case w @ WindowExpression(AggregateExpression(_, _, true), _) =>
+ case w @ WindowExpression(AggregateExpression(_, _, true, _), _) =>
failAnalysis(s"Distinct window functions are not supported: $w")
case w @ WindowExpression(_: OffsetWindowFunction, WindowSpecDefinition(_, order,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala
index 0d44d1dd96..0420b4b538 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala
@@ -25,15 +25,18 @@ import org.apache.spark.sql.catalyst.trees.TreeNode
package object errors {
class TreeNodeException[TreeType <: TreeNode[_]](
- tree: TreeType, msg: String, cause: Throwable)
+ @transient val tree: TreeType,
+ msg: String,
+ cause: Throwable)
extends Exception(msg, cause) {
+ val treeString = tree.toString
+
// Yes, this is the same as a default parameter, but... those don't seem to work with SBT
// external project dependencies for some reason.
def this(tree: TreeType, msg: String) = this(tree, msg, null)
override def getMessage: String = {
- val treeString = tree.toString
s"${super.getMessage}, tree:${if (treeString contains "\n") "\n" else " "}$tree"
}
}
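A brief note on the `TreeNodeException` change above: `tree` is now `@transient` and its rendering is captured eagerly in `treeString`. The patch does not state the motivation, but a plausible reading is that it guards against serializing an exception whose tree is lost in transit; a hedged sketch of that pattern, under that assumption:

```scala
// Illustrative only: capture a printable form of a @transient field eagerly,
// so the message survives serialization of the exception. After
// deserialization, a @transient field comes back null, so a message built
// lazily from it would be useless.
class EagerMessageException[T](@transient val payload: T, msg: String)
  extends Exception(msg) {
  // Rendered once, while `payload` is still available.
  val payloadString: String = String.valueOf(payload)
  override def getMessage: String = s"${super.getMessage}, payload: $payloadString"
}
```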
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index ff3064ac66..d31ccf9985 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.types._
@@ -66,6 +67,19 @@ private[sql] case object NoOp extends Expression with Unevaluable {
override def children: Seq[Expression] = Nil
}
+object AggregateExpression {
+ def apply(
+ aggregateFunction: AggregateFunction,
+ mode: AggregateMode,
+ isDistinct: Boolean): AggregateExpression = {
+ AggregateExpression(
+ aggregateFunction,
+ mode,
+ isDistinct,
+ NamedExpression.newExprId)
+ }
+}
+
/**
* A container for an [[AggregateFunction]] with its [[AggregateMode]] and a field
* (`isDistinct`) indicating if DISTINCT keyword is specified for this function.
@@ -73,10 +87,31 @@ private[sql] case object NoOp extends Expression with Unevaluable {
private[sql] case class AggregateExpression(
aggregateFunction: AggregateFunction,
mode: AggregateMode,
- isDistinct: Boolean)
+ isDistinct: Boolean,
+ resultId: ExprId)
extends Expression
with Unevaluable {
+ lazy val resultAttribute: Attribute = if (aggregateFunction.resolved) {
+ AttributeReference(
+ aggregateFunction.toString,
+ aggregateFunction.dataType,
+ aggregateFunction.nullable)(exprId = resultId)
+ } else {
+ // This is a bit of a hack. Really we should not be constructing this container and reasoning
+ // about datatypes / aggregation mode until after we have finished analysis and made it to
+ // planning.
+ UnresolvedAttribute(aggregateFunction.toString)
+ }
+
+ // We compute the same thing regardless of our final result.
+ override lazy val canonicalized: Expression =
+ AggregateExpression(
+ aggregateFunction.canonicalized.asInstanceOf[AggregateFunction],
+ mode,
+ isDistinct,
+ ExprId(0))
+
override def children: Seq[Expression] = aggregateFunction :: Nil
override def dataType: DataType = aggregateFunction.dataType
override def foldable: Boolean = false
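To make the effect of the new `resultId` field concrete, here is a minimal sketch (the helper is hypothetical, not part of this patch) of how planning code can now recover the output attribute of an aggregate directly from the expression, rather than threading a `Map[(AggregateFunction, Boolean), Attribute]` around:

```scala
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression

// Hypothetical helper: rewrite a result expression so that each aggregate is
// replaced by the attribute the final aggregation buffer will bind it to.
// The attribute's ExprId is ae.resultId, carried by the expression itself.
def rewriteToFinalAttributes(resultExpr: Expression): Expression =
  resultExpr.transformDown {
    case ae: AggregateExpression => ae.resultAttribute
  }
```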
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 262582ca5d..2307122ea1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -329,7 +329,7 @@ case class PrettyAttribute(
override def withName(newName: String): Attribute = throw new UnsupportedOperationException
override def qualifier: Option[String] = throw new UnsupportedOperationException
override def exprId: ExprId = throw new UnsupportedOperationException
- override def nullable: Boolean = throw new UnsupportedOperationException
+ override def nullable: Boolean = true
}
object VirtualColumn {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index a7a948ef1b..326933ec9e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -534,7 +534,7 @@ object NullPropagation extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
- case e @ AggregateExpression(Count(exprs), _, _) if !exprs.exists(nonNullLiteral) =>
+ case e @ AggregateExpression(Count(exprs), _, _, _) if !exprs.exists(nonNullLiteral) =>
Cast(Literal(0L), e.dataType)
case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
@@ -547,9 +547,9 @@ object NullPropagation extends Rule[LogicalPlan] {
Literal.create(null, e.dataType)
case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
- case e @ AggregateExpression(Count(exprs), mode, false) if !exprs.exists(_.nullable) =>
+ case ae @ AggregateExpression(Count(exprs), _, false, _) if !exprs.exists(_.nullable) =>
// This rule should be only triggered when isDistinct field is false.
- AggregateExpression(Count(Literal(1)), mode, isDistinct = false)
+ ae.copy(aggregateFunction = Count(Literal(1)))
// For Coalesce, remove null literals.
case e @ Coalesce(children) =>
@@ -1225,13 +1225,13 @@ object DecimalAggregates extends Rule[LogicalPlan] {
private val MAX_DOUBLE_DIGITS = 15
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- case AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), mode, isDistinct)
+ case ae @ AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), _, _, _)
if prec + 10 <= MAX_LONG_DIGITS =>
- MakeDecimal(AggregateExpression(Sum(UnscaledValue(e)), mode, isDistinct), prec + 10, scale)
+ MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale)
- case AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), mode, isDistinct)
+ case ae @ AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), _, _, _)
if prec + 4 <= MAX_DOUBLE_DIGITS =>
- val newAggExpr = AggregateExpression(Average(UnscaledValue(e)), mode, isDistinct)
+ val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e)))
Cast(
Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
DecimalType(prec + 4, scale + 4))
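A quick worked example of the `DecimalAggregates` rewrite above (toy arithmetic, assuming the precision bound holds so the unscaled sum fits in a `Long`): a `Decimal(prec, scale)` value can be represented as an unscaled `Long`, so the sum runs on cheap `Long` arithmetic and the scale is reattached once at the end:

```scala
// Sum of Decimal(4, 2) values via their unscaled Long representation.
val scale = 2
val values = Seq(BigDecimal("1.25"), BigDecimal("3.75"))
// UnscaledValue: 1.25 -> 125, 3.75 -> 375
val unscaledSum = values.map(v => (v * BigDecimal(10).pow(scale)).toLongExact).sum
// MakeDecimal(..., prec + 10, scale): reattach the scale to the Long sum.
val result = BigDecimal(unscaledSum) / BigDecimal(10).pow(scale)  // 5.00
```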
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 9c927077d0..28d2c445b1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.IntegerType
@@ -216,3 +217,75 @@ object IntegerIndex {
case _ => None
}
}
+
+/**
+ * An extractor used when planning the physical execution of an aggregation. Compared with a logical
+ * aggregation, the following transformations are performed:
+ * - Unnamed grouping expressions are named so that they can be referred to across phases of
+ * aggregation
+ * - Aggregations that appear multiple times are deduplicated.
+ * - The computation of the aggregations themselves is separated from the final result. For example,
+ * the `count` in `count + 1` will be split into an [[AggregateExpression]] and a final
+ * computation that computes `count.resultAttribute + 1`.
+ */
+object PhysicalAggregation {
+ // groupingExpressions, aggregateExpressions, resultExpressions, child
+ type ReturnType =
+ (Seq[NamedExpression], Seq[AggregateExpression], Seq[NamedExpression], LogicalPlan)
+
+ def unapply(a: Any): Option[ReturnType] = a match {
+ case logical.Aggregate(groupingExpressions, resultExpressions, child) =>
+ // A single aggregate expression might appear multiple times in resultExpressions.
+ // In order to avoid evaluating an individual aggregate function multiple times, we'll
+ // build a set of the distinct aggregate expressions and build a function which can
+ // be used to re-write expressions so that they reference the single copy of the
+ // aggregate function which actually gets computed.
+ val aggregateExpressions = resultExpressions.flatMap { expr =>
+ expr.collect {
+ case agg: AggregateExpression => agg
+ }
+ }.distinct
+
+ val namedGroupingExpressions = groupingExpressions.map {
+ case ne: NamedExpression => ne -> ne
+ // If the expression is not a NamedExpression, we add an alias.
+ // So, when we generate the result of the operator, the Aggregate Operator
+ // can directly get the Seq of attributes representing the grouping expressions.
+ case other =>
+ val withAlias = Alias(other, other.toString)()
+ other -> withAlias
+ }
+ val groupExpressionMap = namedGroupingExpressions.toMap
+
+ // The original `resultExpressions` are a set of expressions which may reference
+ // aggregate expressions, grouping column values, and constants. When aggregate operator
+ // emits output rows, we will use `resultExpressions` to generate an output projection
+ // which takes the grouping columns and final aggregate result buffer as input.
+ // Thus, we must re-write the result expressions so that their attributes match up with
+ // the attributes of the final result projection's input row:
+ val rewrittenResultExpressions = resultExpressions.map { expr =>
+ expr.transformDown {
+ case ae: AggregateExpression =>
+ // The final aggregation buffer's attributes will be `finalAggregationAttributes`,
+ // so replace each aggregate expression by its corresponding attribute in the set:
+ ae.resultAttribute
+ case expression =>
+ // Since we're using `namedGroupingAttributes` to extract the grouping key
+ // columns, we need to replace grouping key expressions with their corresponding
+ // attributes. We do not rely on the equality check here since attributes may
+ // differ cosmetically. Instead, we use semanticEquals.
+ groupExpressionMap.collectFirst {
+ case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+ }.getOrElse(expression)
+ }.asInstanceOf[NamedExpression]
+ }
+
+ Some((
+ namedGroupingExpressions.map(_._2),
+ aggregateExpressions,
+ rewrittenResultExpressions,
+ child))
+
+ case _ => None
+ }
+}
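To show how the new extractor is meant to be consumed, here is a skeleton of planner code matching on `PhysicalAggregation`. The object and body are illustrative; in this PR the real consumer is the `StatefulAggregationStrategy` on the execution side:

```scala
import org.apache.spark.sql.catalyst.planning.PhysicalAggregation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// Illustrative only: destructure a logical Aggregate with the extractor.
object ExampleAggInspector {
  def describe(plan: LogicalPlan): String = plan match {
    case PhysicalAggregation(groupingExprs, aggregateExprs, resultExprs, child) =>
      // groupingExprs: grouping keys, aliased where they were unnamed
      // aggregateExprs: deduplicated AggregateExpressions to actually compute
      // resultExprs: final projection over resultAttributes and grouping keys
      s"${aggregateExprs.size} aggregates over ${groupingExprs.size} keys"
    case _ => "not an aggregation"
  }
}
```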
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index aa5d4330d3..7191936699 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.plans
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Sample}
import org.apache.spark.sql.catalyst.util._
@@ -38,6 +39,8 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
case a: Alias =>
Alias(a.child, a.name)(exprId = ExprId(0))
+ case ae: AggregateExpression =>
+ ae.copy(resultId = ExprId(0))
}
}