authorMichael Armbrust <michael@databricks.com>2016-04-01 15:15:16 -0700
committerMichael Armbrust <michael@databricks.com>2016-04-01 15:15:16 -0700
commit0fc4aaa71c5b4531b3a7c8ac71d62ea8e66b6f0c (patch)
tree5c72eb22fb2ef033a6d08f989dcf4fa18d66a84f /sql
parent0b7d4966ca7e02f351c4b92a74789cef4799fcb1 (diff)
[SPARK-14255][SQL] Streaming Aggregation
This PR adds the ability to perform aggregations inside of a `ContinuousQuery`. To implement this feature, the planning of aggregations has been augmented with a new `StatefulAggregationStrategy`. Unlike batch aggregation, stateful aggregation uses the `StateStore` (introduced in #11645) to persist the results of partial aggregation across different invocations. The resulting physical plan performs the aggregation using the following progression:
- Partial Aggregation
- Shuffle
- Partial Merge (now there is at most 1 tuple per group)
- StateStoreRestore (now there is 1 tuple from this batch + optionally one from the previous)
- Partial Merge (now there is at most 1 tuple per group)
- StateStoreSave (saves the tuple for the next batch)
- Complete (output the current result of the aggregation)

The following refactoring was also performed to allow us to plug into existing code:
- The get/put implementation is taken from #12013.
- The logic for breaking down and de-duping the physical execution of aggregation has been moved into a new pattern, `PhysicalAggregation`.
- The `AttributeReference` used to identify the result of an `AggregateFunction` has been moved into the `AggregateExpression` container. This change moves the reference into the same object as the other intermediate references used in aggregation and eliminates the need to pass around a `Map[(AggregateFunction, Boolean), Attribute]`. Further clean-up (using a different aggregation container for logical/physical plans) is deferred to a follow-up.
- Some planning logic is moved from the `SessionState` into the `QueryExecution` to make it easier to override in the streaming case.
- The ability to write a `StreamTest` that checks only the output of the last batch has been added to simulate the future addition of output modes.

Author: Michael Armbrust <michael@databricks.com>

Closes #12048 from marmbrus/statefulAgg.
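For context, the kind of query this enables is an ordinary `groupBy`/`agg` over a streaming source. Below is a minimal sketch modeled on the simple-count case in the new `StreamingAggregationSuite`; it is an illustration, not an excerpt, and assumes an implicit `SQLContext` is in scope:

```scala
// Sketch only: a running count over an in-memory streaming source. Each batch is
// planned by the new StatefulAggregationStrategy and persists its partial results
// in the StateStore between invocations.
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.functions.count

def runningCounts(implicit sqlContext: SQLContext) = {
  import sqlContext.implicits._

  val inputData = MemoryStream[Int]      // streaming test source touched by this PR
  inputData.toDF()
    .groupBy($"value")
    .agg(count("*"))                     // stateful aggregation across batches
    .as[(Int, Long)]
}
```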
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala9
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala7
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala37
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala14
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala73
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala3
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala24
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala7
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala92
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala121
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala72
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala119
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala12
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala36
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala19
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala17
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala21
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala11
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala16
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala36
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala10
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala152
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala61
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala132
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala5
33 files changed, 827 insertions, 305 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index d82ee3a205..05e2b9a447 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -336,6 +336,11 @@ class Analyzer(
Last(ifExpr(expr), Literal(true))
case a: AggregateFunction =>
a.withNewChildren(a.children.map(ifExpr))
+ }.transform {
+ // We are duplicating aggregates that are now computing a different value for each
+ // pivot value.
+ // TODO: Don't construct the physical container until after analysis.
+ case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId)
}
if (filteredAggregate.fastEquals(aggregate)) {
throw new AnalysisException(
@@ -1153,11 +1158,11 @@ class Analyzer(
// Extract Windowed AggregateExpression
case we @ WindowExpression(
- AggregateExpression(function, mode, isDistinct),
+ ae @ AggregateExpression(function, _, _, _),
spec: WindowSpecDefinition) =>
val newChildren = function.children.map(extractExpr)
val newFunction = function.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
- val newAgg = AggregateExpression(newFunction, mode, isDistinct)
+ val newAgg = ae.copy(aggregateFunction = newFunction)
seenWindowAggregates += newAgg
WindowExpression(newAgg, spec)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 1d1e892e32..4880502398 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -76,7 +76,7 @@ trait CheckAnalysis {
case g: GroupingID =>
failAnalysis(s"grouping_id() can only be used with GroupingSets/Cube/Rollup")
- case w @ WindowExpression(AggregateExpression(_, _, true), _) =>
+ case w @ WindowExpression(AggregateExpression(_, _, true, _), _) =>
failAnalysis(s"Distinct window functions are not supported: $w")
case w @ WindowExpression(_: OffsetWindowFunction, WindowSpecDefinition(_, order,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala
index 0d44d1dd96..0420b4b538 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala
@@ -25,15 +25,18 @@ import org.apache.spark.sql.catalyst.trees.TreeNode
package object errors {
class TreeNodeException[TreeType <: TreeNode[_]](
- tree: TreeType, msg: String, cause: Throwable)
+ @transient val tree: TreeType,
+ msg: String,
+ cause: Throwable)
extends Exception(msg, cause) {
+ val treeString = tree.toString
+
// Yes, this is the same as a default parameter, but... those don't seem to work with SBT
// external project dependencies for some reason.
def this(tree: TreeType, msg: String) = this(tree, msg, null)
override def getMessage: String = {
- val treeString = tree.toString
s"${super.getMessage}, tree:${if (treeString contains "\n") "\n" else " "}$tree"
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index ff3064ac66..d31ccf9985 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.types._
@@ -66,6 +67,19 @@ private[sql] case object NoOp extends Expression with Unevaluable {
override def children: Seq[Expression] = Nil
}
+object AggregateExpression {
+ def apply(
+ aggregateFunction: AggregateFunction,
+ mode: AggregateMode,
+ isDistinct: Boolean): AggregateExpression = {
+ AggregateExpression(
+ aggregateFunction,
+ mode,
+ isDistinct,
+ NamedExpression.newExprId)
+ }
+}
+
/**
* A container for an [[AggregateFunction]] with its [[AggregateMode]] and a field
* (`isDistinct`) indicating if DISTINCT keyword is specified for this function.
@@ -73,10 +87,31 @@ private[sql] case object NoOp extends Expression with Unevaluable {
private[sql] case class AggregateExpression(
aggregateFunction: AggregateFunction,
mode: AggregateMode,
- isDistinct: Boolean)
+ isDistinct: Boolean,
+ resultId: ExprId)
extends Expression
with Unevaluable {
+ lazy val resultAttribute: Attribute = if (aggregateFunction.resolved) {
+ AttributeReference(
+ aggregateFunction.toString,
+ aggregateFunction.dataType,
+ aggregateFunction.nullable)(exprId = resultId)
+ } else {
+ // This is a bit of a hack. Really we should not be constructing this container and reasoning
+ // about datatypes / aggregation mode until after we have finished analysis and made it to
+ // planning.
+ UnresolvedAttribute(aggregateFunction.toString)
+ }
+
+ // We compute the same thing regardless of our final result.
+ override lazy val canonicalized: Expression =
+ AggregateExpression(
+ aggregateFunction.canonicalized.asInstanceOf[AggregateFunction],
+ mode,
+ isDistinct,
+ ExprId(0))
+
override def children: Seq[Expression] = aggregateFunction :: Nil
override def dataType: DataType = aggregateFunction.dataType
override def foldable: Boolean = false
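To see what the new `resultId`/`resultAttribute` fields buy, here is a hypothetical planner-side helper (not part of this patch) written as it might appear inside `org.apache.spark.sql.execution.aggregate`; it replaces lookups through the old `Map[(AggregateFunction, Boolean), Attribute]`:

```scala
package org.apache.spark.sql.execution.aggregate

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Final}

object ResultAttributeExample {
  // The result attribute now lives on the expression itself, keyed by its stable
  // resultId, so copying the mode (Partial / PartialMerge / Final) does not change
  // which attribute identifies the aggregate's output.
  def finalAggregateAttributes(aggregates: Seq[AggregateExpression]): Seq[Attribute] =
    aggregates.map(_.copy(mode = Final)).map(_.resultAttribute)
}
```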
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 262582ca5d..2307122ea1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -329,7 +329,7 @@ case class PrettyAttribute(
override def withName(newName: String): Attribute = throw new UnsupportedOperationException
override def qualifier: Option[String] = throw new UnsupportedOperationException
override def exprId: ExprId = throw new UnsupportedOperationException
- override def nullable: Boolean = throw new UnsupportedOperationException
+ override def nullable: Boolean = true
}
object VirtualColumn {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index a7a948ef1b..326933ec9e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -534,7 +534,7 @@ object NullPropagation extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
- case e @ AggregateExpression(Count(exprs), _, _) if !exprs.exists(nonNullLiteral) =>
+ case e @ AggregateExpression(Count(exprs), _, _, _) if !exprs.exists(nonNullLiteral) =>
Cast(Literal(0L), e.dataType)
case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
@@ -547,9 +547,9 @@ object NullPropagation extends Rule[LogicalPlan] {
Literal.create(null, e.dataType)
case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
- case e @ AggregateExpression(Count(exprs), mode, false) if !exprs.exists(_.nullable) =>
+ case ae @ AggregateExpression(Count(exprs), _, false, _) if !exprs.exists(_.nullable) =>
// This rule should be only triggered when isDistinct field is false.
- AggregateExpression(Count(Literal(1)), mode, isDistinct = false)
+ ae.copy(aggregateFunction = Count(Literal(1)))
// For Coalesce, remove null literals.
case e @ Coalesce(children) =>
@@ -1225,13 +1225,13 @@ object DecimalAggregates extends Rule[LogicalPlan] {
private val MAX_DOUBLE_DIGITS = 15
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- case AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), mode, isDistinct)
+ case ae @ AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), _, _, _)
if prec + 10 <= MAX_LONG_DIGITS =>
- MakeDecimal(AggregateExpression(Sum(UnscaledValue(e)), mode, isDistinct), prec + 10, scale)
+ MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale)
- case AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), mode, isDistinct)
+ case ae @ AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), _, _, _)
if prec + 4 <= MAX_DOUBLE_DIGITS =>
- val newAggExpr = AggregateExpression(Average(UnscaledValue(e)), mode, isDistinct)
+ val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e)))
Cast(
Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
DecimalType(prec + 4, scale + 4))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 9c927077d0..28d2c445b1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.IntegerType
@@ -216,3 +217,75 @@ object IntegerIndex {
case _ => None
}
}
+
+/**
+ * An extractor used when planning the physical execution of an aggregation. Compared with a logical
+ * aggregation, the following transformations are performed:
+ * - Unnamed grouping expressions are named so that they can be referred to across phases of
+ * aggregation
+ * - Aggregations that appear multiple times are deduplicated.
+ * - The computation of the aggregations themselves is separated from the final result. For example,
+ * the `count` in `count + 1` will be split into an [[AggregateExpression]] and a final
+ * computation that computes `count.resultAttribute + 1`.
+ */
+object PhysicalAggregation {
+ // groupingExpressions, aggregateExpressions, resultExpressions, child
+ type ReturnType =
+ (Seq[NamedExpression], Seq[AggregateExpression], Seq[NamedExpression], LogicalPlan)
+
+ def unapply(a: Any): Option[ReturnType] = a match {
+ case logical.Aggregate(groupingExpressions, resultExpressions, child) =>
+ // A single aggregate expression might appear multiple times in resultExpressions.
+ // In order to avoid evaluating an individual aggregate function multiple times, we'll
+ // build a set of the distinct aggregate expressions and build a function which can
+ // be used to re-write expressions so that they reference the single copy of the
+ // aggregate function which actually gets computed.
+ val aggregateExpressions = resultExpressions.flatMap { expr =>
+ expr.collect {
+ case agg: AggregateExpression => agg
+ }
+ }.distinct
+
+ val namedGroupingExpressions = groupingExpressions.map {
+ case ne: NamedExpression => ne -> ne
+ // If the expression is not a NamedExpression, we add an alias.
+ // So, when we generate the result of the operator, the Aggregate Operator
+ // can directly get the Seq of attributes representing the grouping expressions.
+ case other =>
+ val withAlias = Alias(other, other.toString)()
+ other -> withAlias
+ }
+ val groupExpressionMap = namedGroupingExpressions.toMap
+
+ // The original `resultExpressions` are a set of expressions which may reference
+ // aggregate expressions, grouping column values, and constants. When aggregate operator
+ // emits output rows, we will use `resultExpressions` to generate an output projection
+ // which takes the grouping columns and final aggregate result buffer as input.
+ // Thus, we must re-write the result expressions so that their attributes match up with
+ // the attributes of the final result projection's input row:
+ val rewrittenResultExpressions = resultExpressions.map { expr =>
+ expr.transformDown {
+ case ae: AggregateExpression =>
+ // The final aggregation buffer's attributes will be `finalAggregationAttributes`,
+ // so replace each aggregate expression by its corresponding attribute in the set:
+ ae.resultAttribute
+ case expression =>
+ // Since we're using `namedGroupingAttributes` to extract the grouping key
+ // columns, we need to replace grouping key expressions with their corresponding
+ // attributes. We do not rely on the equality check at here since attributes may
+ // differ cosmetically. Instead, we use semanticEquals.
+ groupExpressionMap.collectFirst {
+ case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+ }.getOrElse(expression)
+ }.asInstanceOf[NamedExpression]
+ }
+
+ Some((
+ namedGroupingExpressions.map(_._2),
+ aggregateExpressions,
+ rewrittenResultExpressions,
+ child))
+
+ case _ => None
+ }
+}
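A strategy that consumes this extractor looks roughly like the skeleton below (a hedged sketch; the real consumers are the rewritten `Aggregation` strategy and the new `StatefulAggregationStrategy` further down in this diff):

```scala
// Hypothetical strategy skeleton: PhysicalAggregation hands back named grouping
// expressions, de-duplicated aggregate expressions, and result expressions that
// already refer to each aggregate's resultAttribute.
import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.planning.PhysicalAggregation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan

object ExampleAggregationStrategy extends Strategy {
  override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case PhysicalAggregation(groupingExpressions, aggregateExpressions, resultExpressions, child) =>
      // Build physical aggregate operators here (e.g. via aggregate.Utils).
      Nil
    case _ => Nil
  }
}
```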
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index aa5d4330d3..7191936699 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.plans
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Sample}
import org.apache.spark.sql.catalyst.util._
@@ -38,6 +39,8 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
case a: Alias =>
Alias(a.child, a.name)(exprId = ExprId(0))
+ case ae: AggregateExpression =>
+ ae.copy(resultId = ExprId(0))
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 912b84abc1..4843553211 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -21,6 +21,8 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{AnalysisException, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}
/**
* The primary workflow for executing relational queries using Spark. Designed to allow easy
@@ -31,6 +33,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
*/
class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {
+ // TODO: Move the planner an optimizer into here from SessionState.
+ protected def planner = sqlContext.sessionState.planner
+
def assertAnalyzed(): Unit = try sqlContext.sessionState.analyzer.checkAnalysis(analyzed) catch {
case e: AnalysisException =>
val ae = new AnalysisException(e.message, e.line, e.startPosition, Some(analyzed))
@@ -49,16 +54,31 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {
lazy val sparkPlan: SparkPlan = {
SQLContext.setActive(sqlContext)
- sqlContext.sessionState.planner.plan(ReturnAnswer(optimizedPlan)).next()
+ planner.plan(ReturnAnswer(optimizedPlan)).next()
}
// executedPlan should not be used to initialize any SparkPlan. It should be
// only used for execution.
- lazy val executedPlan: SparkPlan = sqlContext.sessionState.prepareForExecution.execute(sparkPlan)
+ lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan)
/** Internal version of the RDD. Avoids copies and has no schema */
lazy val toRdd: RDD[InternalRow] = executedPlan.execute()
+ /**
+ * Prepares a planned [[SparkPlan]] for execution by inserting shuffle operations and internal
+ * row format conversions as needed.
+ */
+ protected def prepareForExecution(plan: SparkPlan): SparkPlan = {
+ preparations.foldLeft(plan) { case (sp, rule) => rule.apply(sp) }
+ }
+
+ /** A sequence of rules that will be applied in order to the physical plan before execution. */
+ protected def preparations: Seq[Rule[SparkPlan]] = Seq(
+ PlanSubqueries(sqlContext),
+ EnsureRequirements(sqlContext.conf),
+ CollapseCodegenStages(sqlContext.conf),
+ ReuseExchange(sqlContext.conf))
+
protected def stringOrError[A](f: => A): String =
try f.toString catch { case e: Throwable => e.toString }
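The new `preparations`/`prepareForExecution` split is the extension point that streaming uses. A hypothetical subclass (names invented for illustration) would prepend its own rule as sketched below, which is exactly the pattern `IncrementalExecution` follows later in this diff:

```scala
// Sketch: a QueryExecution subclass that adds a physical-plan rule before execution.
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{QueryExecution, SparkPlan}

class LoggingExecution(ctx: SQLContext, plan: LogicalPlan)
  extends QueryExecution(ctx, plan) {

  // A no-op rule that just logs the plan it is given.
  private val logPlan = new Rule[SparkPlan] {
    override def apply(plan: SparkPlan): SparkPlan = {
      logInfo(s"Preparing physical plan:\n$plan")
      plan
    }
  }

  // prepareForExecution folds these rules left-to-right over the SparkPlan.
  override protected def preparations: Seq[Rule[SparkPlan]] =
    logPlan +: super.preparations
}
```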
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 010ed7f500..b1b3d4ac81 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -379,6 +379,13 @@ private[sql] trait LeafNode extends SparkPlan {
override def producedAttributes: AttributeSet = outputSet
}
+object UnaryNode {
+ def unapply(a: Any): Option[(SparkPlan, SparkPlan)] = a match {
+ case s: SparkPlan if s.children.size == 1 => Some((s, s.children.head))
+ case _ => None
+ }
+}
+
private[sql] trait UnaryNode extends SparkPlan {
def child: SparkPlan
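The companion extractor added above matches any physical node with exactly one child, which is what lets `IncrementalExecution` pattern-match whatever aggregate sits between save and restore without naming it. A small hedged usage sketch:

```scala
// Sketch: UnaryNode(node, child) matches any SparkPlan with a single child.
import org.apache.spark.sql.execution.{SparkPlan, UnaryNode}

def describe(plan: SparkPlan): String = plan match {
  case UnaryNode(node, child) => s"${node.nodeName} wraps ${child.nodeName}"
  case other                  => s"${other.nodeName} has ${other.children.size} children"
}
```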
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
index 9da2c74c62..ac8072f3ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
@@ -26,13 +26,13 @@ import org.apache.spark.sql.internal.SQLConf
class SparkPlanner(
val sparkContext: SparkContext,
val conf: SQLConf,
- val experimentalMethods: ExperimentalMethods)
+ val extraStrategies: Seq[Strategy])
extends SparkStrategies {
def numPartitions: Int = conf.numShufflePartitions
def strategies: Seq[Strategy] =
- experimentalMethods.extraStrategies ++ (
+ extraStrategies ++ (
FileSourceStrategy ::
DataSourceStrategy ::
DDLStrategy ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 7a2e2b7382..5bcc172ca7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan}
@@ -204,28 +203,32 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
/**
+ * Used to plan aggregation queries that are computed incrementally as part of a
+ * [[org.apache.spark.sql.ContinuousQuery]]. Currently this rule is injected into the planner
+ * on-demand, only when planning in a [[org.apache.spark.sql.execution.streaming.StreamExecution]]
+ */
+ object StatefulAggregationStrategy extends Strategy {
+ override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case PhysicalAggregation(
+ namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) =>
+
+ aggregate.Utils.planStreamingAggregation(
+ namedGroupingExpressions,
+ aggregateExpressions,
+ rewrittenResultExpressions,
+ planLater(child))
+
+ case _ => Nil
+ }
+ }
+
+ /**
* Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface.
*/
object Aggregation extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case logical.Aggregate(groupingExpressions, resultExpressions, child) =>
- // A single aggregate expression might appear multiple times in resultExpressions.
- // In order to avoid evaluating an individual aggregate function multiple times, we'll
- // build a set of the distinct aggregate expressions and build a function which can
- // be used to re-write expressions so that they reference the single copy of the
- // aggregate function which actually gets computed.
- val aggregateExpressions = resultExpressions.flatMap { expr =>
- expr.collect {
- case agg: AggregateExpression => agg
- }
- }.distinct
- // For those distinct aggregate expressions, we create a map from the
- // aggregate function to the corresponding attribute of the function.
- val aggregateFunctionToAttribute = aggregateExpressions.map { agg =>
- val aggregateFunction = agg.aggregateFunction
- val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
- (aggregateFunction, agg.isDistinct) -> attribute
- }.toMap
+ case PhysicalAggregation(
+ groupingExpressions, aggregateExpressions, resultExpressions, child) =>
val (functionsWithDistinct, functionsWithoutDistinct) =
aggregateExpressions.partition(_.isDistinct)
@@ -233,41 +236,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// This is a sanity check. We should not reach here when we have multiple distinct
// column sets. Our MultipleDistinctRewriter should take care this case.
sys.error("You hit a query analyzer bug. Please report your query to " +
- "Spark user mailing list.")
- }
-
- val namedGroupingExpressions = groupingExpressions.map {
- case ne: NamedExpression => ne -> ne
- // If the expression is not a NamedExpressions, we add an alias.
- // So, when we generate the result of the operator, the Aggregate Operator
- // can directly get the Seq of attributes representing the grouping expressions.
- case other =>
- val withAlias = Alias(other, other.toString)()
- other -> withAlias
- }
- val groupExpressionMap = namedGroupingExpressions.toMap
-
- // The original `resultExpressions` are a set of expressions which may reference
- // aggregate expressions, grouping column values, and constants. When aggregate operator
- // emits output rows, we will use `resultExpressions` to generate an output projection
- // which takes the grouping columns and final aggregate result buffer as input.
- // Thus, we must re-write the result expressions so that their attributes match up with
- // the attributes of the final result projection's input row:
- val rewrittenResultExpressions = resultExpressions.map { expr =>
- expr.transformDown {
- case AggregateExpression(aggregateFunction, _, isDistinct) =>
- // The final aggregation buffer's attributes will be `finalAggregationAttributes`,
- // so replace each aggregate expression by its corresponding attribute in the set:
- aggregateFunctionToAttribute(aggregateFunction, isDistinct)
- case expression =>
- // Since we're using `namedGroupingAttributes` to extract the grouping key
- // columns, we need to replace grouping key expressions with their corresponding
- // attributes. We do not rely on the equality check at here since attributes may
- // differ cosmetically. Instead, we use semanticEquals.
- groupExpressionMap.collectFirst {
- case (expr, ne) if expr semanticEquals expression => ne.toAttribute
- }.getOrElse(expression)
- }.asInstanceOf[NamedExpression]
+ "Spark user mailing list.")
}
val aggregateOperator =
@@ -277,26 +246,23 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
"aggregate functions which don't support partial aggregation.")
} else {
aggregate.Utils.planAggregateWithoutPartial(
- namedGroupingExpressions.map(_._2),
+ groupingExpressions,
aggregateExpressions,
- aggregateFunctionToAttribute,
- rewrittenResultExpressions,
+ resultExpressions,
planLater(child))
}
} else if (functionsWithDistinct.isEmpty) {
aggregate.Utils.planAggregateWithoutDistinct(
- namedGroupingExpressions.map(_._2),
+ groupingExpressions,
aggregateExpressions,
- aggregateFunctionToAttribute,
- rewrittenResultExpressions,
+ resultExpressions,
planLater(child))
} else {
aggregate.Utils.planAggregateWithOneDistinct(
- namedGroupingExpressions.map(_._2),
+ groupingExpressions,
functionsWithDistinct,
functionsWithoutDistinct,
- aggregateFunctionToAttribute,
- rewrittenResultExpressions,
+ resultExpressions,
planLater(child))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
index 270c09aff3..7acf020b28 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -177,7 +177,7 @@ case class Window(
case e @ WindowExpression(function, spec) =>
val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
function match {
- case AggregateExpression(f, _, _) => collect("AGGREGATE", frame, e, f)
+ case AggregateExpression(f, _, _, _) => collect("AGGREGATE", frame, e, f)
case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f)
case f: OffsetWindowFunction => collect("OFFSET", frame, e, f)
case f => sys.error(s"Unsupported window function: $f")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 213bca907b..ce504e20e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -242,9 +242,9 @@ class TungstenAggregationIterator(
// Basically the value of the KVIterator returned by externalSorter
// will be just aggregation buffer, so we rewrite the aggregateExpressions to reflect it.
val newExpressions = aggregateExpressions.map {
- case agg @ AggregateExpression(_, Partial, _) =>
+ case agg @ AggregateExpression(_, Partial, _, _) =>
agg.copy(mode = PartialMerge)
- case agg @ AggregateExpression(_, Complete, _) =>
+ case agg @ AggregateExpression(_, Complete, _, _) =>
agg.copy(mode = Final)
case other => other
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index 1e113ccd4e..4682949fa1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.streaming.{StateStoreRestore, StateStoreSave}
/**
* Utility functions used by the query planner to convert our plan to new aggregation code path.
@@ -29,15 +30,11 @@ object Utils {
def planAggregateWithoutPartial(
groupingExpressions: Seq[NamedExpression],
aggregateExpressions: Seq[AggregateExpression],
- aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute],
resultExpressions: Seq[NamedExpression],
child: SparkPlan): Seq[SparkPlan] = {
val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete))
- val completeAggregateAttributes = completeAggregateExpressions.map {
- expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
- }
-
+ val completeAggregateAttributes = completeAggregateExpressions.map(_.resultAttribute)
SortBasedAggregate(
requiredChildDistributionExpressions = Some(groupingExpressions),
groupingExpressions = groupingExpressions,
@@ -83,7 +80,6 @@ object Utils {
def planAggregateWithoutDistinct(
groupingExpressions: Seq[NamedExpression],
aggregateExpressions: Seq[AggregateExpression],
- aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute],
resultExpressions: Seq[NamedExpression],
child: SparkPlan): Seq[SparkPlan] = {
// Check if we can use TungstenAggregate.
@@ -111,9 +107,7 @@ object Utils {
val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final))
// The attributes of the final aggregation buffer, which is presented as input to the result
// projection:
- val finalAggregateAttributes = finalAggregateExpressions.map {
- expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
- }
+ val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute)
val finalAggregate = createAggregate(
requiredChildDistributionExpressions = Some(groupingAttributes),
@@ -131,7 +125,6 @@ object Utils {
groupingExpressions: Seq[NamedExpression],
functionsWithDistinct: Seq[AggregateExpression],
functionsWithoutDistinct: Seq[AggregateExpression],
- aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute],
resultExpressions: Seq[NamedExpression],
child: SparkPlan): Seq[SparkPlan] = {
@@ -151,9 +144,7 @@ object Utils {
// 1. Create an Aggregate Operator for partial aggregations.
val partialAggregate: SparkPlan = {
val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
- val aggregateAttributes = aggregateExpressions.map {
- expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
- }
+ val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
// We will group by the original grouping expression, plus an additional expression for the
// DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping
// expressions will be [key, value].
@@ -169,9 +160,7 @@ object Utils {
// 2. Create an Aggregate Operator for partial merge aggregations.
val partialMergeAggregate: SparkPlan = {
val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
- val aggregateAttributes = aggregateExpressions.map {
- expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
- }
+ val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
createAggregate(
requiredChildDistributionExpressions =
Some(groupingAttributes ++ distinctAttributes),
@@ -190,7 +179,7 @@ object Utils {
// Children of an AggregateFunction with DISTINCT keyword has already
// been evaluated. At here, we need to replace original children
// to AttributeReferences.
- case agg @ AggregateExpression(aggregateFunction, mode, true) =>
+ case agg @ AggregateExpression(aggregateFunction, mode, true, _) =>
aggregateFunction.transformDown(distinctColumnAttributeLookup)
.asInstanceOf[AggregateFunction]
}
@@ -199,9 +188,7 @@ object Utils {
val mergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
// The attributes of the final aggregation buffer, which is presented as input to the result
// projection:
- val mergeAggregateAttributes = mergeAggregateExpressions.map {
- expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
- }
+ val mergeAggregateAttributes = mergeAggregateExpressions.map(_.resultAttribute)
val (distinctAggregateExpressions, distinctAggregateAttributes) =
rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) =>
// We rewrite the aggregate function to a non-distinct aggregation because
@@ -211,7 +198,7 @@ object Utils {
val expr = AggregateExpression(func, Partial, isDistinct = true)
// Use original AggregationFunction to lookup attributes, which is used to build
// aggregateFunctionToAttribute
- val attr = aggregateFunctionToAttribute(functionsWithDistinct(i).aggregateFunction, true)
+ val attr = functionsWithDistinct(i).resultAttribute
(expr, attr)
}.unzip
@@ -232,9 +219,7 @@ object Utils {
val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final))
// The attributes of the final aggregation buffer, which is presented as input to the result
// projection:
- val finalAggregateAttributes = finalAggregateExpressions.map {
- expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
- }
+ val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute)
val (distinctAggregateExpressions, distinctAggregateAttributes) =
rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) =>
@@ -245,7 +230,7 @@ object Utils {
val expr = AggregateExpression(func, Final, isDistinct = true)
// Use original AggregationFunction to lookup attributes, which is used to build
// aggregateFunctionToAttribute
- val attr = aggregateFunctionToAttribute(functionsWithDistinct(i).aggregateFunction, true)
+ val attr = functionsWithDistinct(i).resultAttribute
(expr, attr)
}.unzip
@@ -261,4 +246,90 @@ object Utils {
finalAndCompleteAggregate :: Nil
}
+
+ /**
+ * Plans a streaming aggregation using the following progression:
+ * - Partial Aggregation
+ * - Shuffle
+ * - Partial Merge (now there is at most 1 tuple per group)
+ * - StateStoreRestore (now there is 1 tuple from this batch + optionally one from the previous)
+ * - PartialMerge (now there is at most 1 tuple per group)
+ * - StateStoreSave (saves the tuple for the next batch)
+ * - Complete (output the current result of the aggregation)
+ */
+ def planStreamingAggregation(
+ groupingExpressions: Seq[NamedExpression],
+ functionsWithoutDistinct: Seq[AggregateExpression],
+ resultExpressions: Seq[NamedExpression],
+ child: SparkPlan): Seq[SparkPlan] = {
+
+ val groupingAttributes = groupingExpressions.map(_.toAttribute)
+
+ val partialAggregate: SparkPlan = {
+ val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
+ val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
+ // We will group by the original grouping expression, plus an additional expression for the
+ // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping
+ // expressions will be [key, value].
+ createAggregate(
+ groupingExpressions = groupingExpressions,
+ aggregateExpressions = aggregateExpressions,
+ aggregateAttributes = aggregateAttributes,
+ resultExpressions = groupingAttributes ++
+ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
+ child = child)
+ }
+
+ val partialMerged1: SparkPlan = {
+ val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
+ val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
+ createAggregate(
+ requiredChildDistributionExpressions =
+ Some(groupingAttributes),
+ groupingExpressions = groupingAttributes,
+ aggregateExpressions = aggregateExpressions,
+ aggregateAttributes = aggregateAttributes,
+ initialInputBufferOffset = groupingAttributes.length,
+ resultExpressions = groupingAttributes ++
+ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
+ child = partialAggregate)
+ }
+
+ val restored = StateStoreRestore(groupingAttributes, None, partialMerged1)
+
+ val partialMerged2: SparkPlan = {
+ val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
+ val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
+ createAggregate(
+ requiredChildDistributionExpressions =
+ Some(groupingAttributes),
+ groupingExpressions = groupingAttributes,
+ aggregateExpressions = aggregateExpressions,
+ aggregateAttributes = aggregateAttributes,
+ initialInputBufferOffset = groupingAttributes.length,
+ resultExpressions = groupingAttributes ++
+ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
+ child = restored)
+ }
+
+ val saved = StateStoreSave(groupingAttributes, None, partialMerged2)
+
+ val finalAndCompleteAggregate: SparkPlan = {
+ val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final))
+ // The attributes of the final aggregation buffer, which is presented as input to the result
+ // projection:
+ val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute)
+
+ createAggregate(
+ requiredChildDistributionExpressions = Some(groupingAttributes),
+ groupingExpressions = groupingAttributes,
+ aggregateExpressions = finalAggregateExpressions,
+ aggregateAttributes = finalAggregateAttributes,
+ initialInputBufferOffset = groupingAttributes.length,
+ resultExpressions = resultExpressions,
+ child = saved)
+ }
+
+ finalAndCompleteAggregate :: Nil
+ }
}
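For orientation, the operator nesting produced by `planStreamingAggregation` for a simple grouped count is sketched below. This is an illustration derived from the code above, not captured plan output; the aggregate nodes may be Tungsten- or sort-based depending on what `createAggregate` selects, and the state ids remain `None` until `IncrementalExecution` fills them in:

```scala
// Illustrative nesting only (the top operator produces the final output):
//
//   Aggregate(Final)                          <- finalAndCompleteAggregate
//     StateStoreSave(groupingAttributes)      <- saved
//       Aggregate(PartialMerge)               <- partialMerged2
//         StateStoreRestore(groupingAttributes)
//           Aggregate(PartialMerge)           <- partialMerged1, after the shuffle
//             Aggregate(Partial)              <- partialAggregate
//               <child>
```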
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
new file mode 100644
index 0000000000..aaced49dd1
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -0,0 +1,72 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryNode}
+
+/**
+ * A variant of [[QueryExecution]] that allows the given [[LogicalPlan]] to be executed
+ * incrementally, possibly preserving state in between each execution.
+ */
+class IncrementalExecution(
+ ctx: SQLContext,
+ logicalPlan: LogicalPlan,
+ checkpointLocation: String,
+ currentBatchId: Long) extends QueryExecution(ctx, logicalPlan) {
+
+ // TODO: make this always part of planning.
+ val stateStrategy = sqlContext.sessionState.planner.StatefulAggregationStrategy :: Nil
+
+ // Modified planner with stateful operations.
+ override def planner: SparkPlanner =
+ new SparkPlanner(
+ sqlContext.sparkContext,
+ sqlContext.conf,
+ stateStrategy)
+
+ /**
+ * Records the current id for a given stateful operator in the query plan as the `state`
+ * preparation rule walks the query plan.
+ */
+ private var operatorId = 0
+
+ /** Locates save/restore pairs surrounding aggregation. */
+ val state = new Rule[SparkPlan] {
+ override def apply(plan: SparkPlan): SparkPlan = plan transform {
+ case StateStoreSave(keys, None,
+ UnaryNode(agg,
+ StateStoreRestore(keys2, None, child))) =>
+ val stateId = OperatorStateId(checkpointLocation, operatorId, currentBatchId - 1)
+ operatorId += 1
+
+ StateStoreSave(
+ keys,
+ Some(stateId),
+ agg.withNewChildren(
+ StateStoreRestore(
+ keys,
+ Some(stateId),
+ child) :: Nil))
+ }
+ }
+
+ override def preparations: Seq[Rule[SparkPlan]] = state +: super.preparations
+}
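A hedged sketch of how a caller drives this class (mirroring the `StreamExecution` change later in this diff; `checkpointDir` and `batchId` are hypothetical values):

```scala
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.streaming.IncrementalExecution

def planBatch(
    sqlContext: SQLContext,
    logicalPlan: LogicalPlan,
    checkpointDir: String,
    batchId: Long): SparkPlan = {
  val execution =
    new IncrementalExecution(sqlContext, logicalPlan, checkpointDir + "/state", batchId)
  // The `state` rule runs before the usual preparations, pairing each StateStoreSave
  // with the StateStoreRestore beneath it and giving both the same OperatorStateId,
  // whose store version is `batchId - 1`.
  execution.executedPlan
}
```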
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala
new file mode 100644
index 0000000000..595774761c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.execution
+import org.apache.spark.sql.execution.streaming.state._
+import org.apache.spark.sql.execution.SparkPlan
+
+/** Used to identify the state store for a given operator. */
+case class OperatorStateId(
+ checkpointLocation: String,
+ operatorId: Long,
+ batchId: Long)
+
+/**
+ * An operator that saves or restores state from the [[StateStore]]. The [[OperatorStateId]] should
+ * be filled in by `prepareForExecution` in [[IncrementalExecution]].
+ */
+trait StatefulOperator extends SparkPlan {
+ def stateId: Option[OperatorStateId]
+
+ protected def getStateId: OperatorStateId = attachTree(this) {
+ stateId.getOrElse {
+ throw new IllegalStateException("State location not present for execution")
+ }
+ }
+}
+
+/**
+ * For each input tuple, the key is calculated and the value from the [[StateStore]] is added
+ * to the stream (in addition to the input tuple) if present.
+ */
+case class StateStoreRestore(
+ keyExpressions: Seq[Attribute],
+ stateId: Option[OperatorStateId],
+ child: SparkPlan) extends execution.UnaryNode with StatefulOperator {
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ child.execute().mapPartitionsWithStateStore(
+ getStateId.checkpointLocation,
+ operatorId = getStateId.operatorId,
+ storeVersion = getStateId.batchId,
+ keyExpressions.toStructType,
+ child.output.toStructType,
+ new StateStoreConf(sqlContext.conf),
+ Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
+ val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
+ iter.flatMap { row =>
+ val key = getKey(row)
+ val savedState = store.get(key)
+ row +: savedState.toSeq
+ }
+ }
+ }
+ override def output: Seq[Attribute] = child.output
+}
+
+/**
+ * For each input tuple, the key is calculated and the tuple is `put` into the [[StateStore]].
+ */
+case class StateStoreSave(
+ keyExpressions: Seq[Attribute],
+ stateId: Option[OperatorStateId],
+ child: SparkPlan) extends execution.UnaryNode with StatefulOperator {
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ child.execute().mapPartitionsWithStateStore(
+ getStateId.checkpointLocation,
+ operatorId = getStateId.operatorId,
+ storeVersion = getStateId.batchId,
+ keyExpressions.toStructType,
+ child.output.toStructType,
+ new StateStoreConf(sqlContext.conf),
+ Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
+ new Iterator[InternalRow] {
+ private[this] val baseIterator = iter
+ private[this] val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
+
+ override def hasNext: Boolean = {
+ if (!baseIterator.hasNext) {
+ store.commit()
+ false
+ } else {
+ true
+ }
+ }
+
+ override def next(): InternalRow = {
+ val row = baseIterator.next().asInstanceOf[UnsafeRow]
+ val key = getKey(row)
+ store.put(key.copy(), row.copy())
+ row
+ }
+ }
+ }
+ }
+
+ override def output: Seq[Attribute] = child.output
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index c4e410d92c..511e30c70c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -27,6 +27,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.internal.Logging
import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.util._
@@ -272,6 +273,8 @@ class StreamExecution(
private def runBatch(): Unit = {
val startTime = System.nanoTime()
+ // TODO: Move this to IncrementalExecution.
+
// Request unprocessed data from all sources.
val newData = availableOffsets.flatMap {
case (source, available) if committedOffsets.get(source).map(_ < available).getOrElse(true) =>
@@ -305,13 +308,14 @@ class StreamExecution(
}
val optimizerStart = System.nanoTime()
-
- lastExecution = new QueryExecution(sqlContext, newPlan)
- val executedPlan = lastExecution.executedPlan
+ lastExecution =
+ new IncrementalExecution(sqlContext, newPlan, checkpointFile("state"), currentBatchId)
+ lastExecution.executedPlan
val optimizerTime = (System.nanoTime() - optimizerStart).toDouble / 1000000
logDebug(s"Optimized batch in ${optimizerTime}ms")
- val nextBatch = Dataset.ofRows(sqlContext, newPlan)
+ val nextBatch =
+ new Dataset(sqlContext, lastExecution, RowEncoder(lastExecution.analyzed.schema))
sink.addBatch(currentBatchId - 1, nextBatch)
awaitBatchLock.synchronized {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 0f91e59e04..7d97f81b0f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -108,7 +108,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
* A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit
* tests and does not provide durability.
*/
-class MemorySink(schema: StructType) extends Sink with Logging {
+class MemorySink(val schema: StructType) extends Sink with Logging {
/** An ordered list of batches that have been written to this [[Sink]]. */
private val batches = new ArrayBuffer[Array[Row]]()
@@ -117,6 +117,8 @@ class MemorySink(schema: StructType) extends Sink with Logging {
batches.flatten
}
+ def lastBatch: Seq[Row] = batches.last
+
def toDebugString: String = synchronized {
batches.zipWithIndex.map { case (b, i) =>
val dataStr = try b.mkString(" ") catch {
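The new `lastBatch` accessor is what allows a `StreamTest` to assert on only the most recent batch, as the commit message describes. A hedged sketch of such a test follows (the `AddData`/`CheckLastBatch` action names are taken from the test changes in this patch but shown here only as an illustration):

```scala
// Sketch of a last-batch assertion in the style of the new StreamingAggregationSuite.
import org.apache.spark.sql.StreamTest
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.functions.count
import org.apache.spark.sql.test.SharedSQLContext

class RunningCountSuite extends StreamTest with SharedSQLContext {
  import testImplicits._

  test("running count checks only the latest batch") {
    val inputData = MemoryStream[Int]
    val aggregated = inputData.toDF()
      .groupBy($"value")
      .agg(count("*"))
      .as[(Int, Long)]

    testStream(aggregated)(
      AddData(inputData, 3),
      CheckLastBatch((3, 1)),          // reads MemorySink.lastBatch
      AddData(inputData, 3, 2),
      CheckLastBatch((3, 2), (2, 1))   // running counts, latest batch only
    )
  }
}
```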
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index ee015baf3f..998eb82de1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -81,7 +81,7 @@ private[state] class HDFSBackedStateStoreProvider(
trait STATE
case object UPDATING extends STATE
case object COMMITTED extends STATE
- case object CANCELLED extends STATE
+ case object ABORTED extends STATE
private val newVersion = version + 1
private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}")
@@ -94,15 +94,14 @@ private[state] class HDFSBackedStateStoreProvider(
override def id: StateStoreId = HDFSBackedStateStoreProvider.this.id
- /**
- * Update the value of a key using the value generated by the update function.
- * @note Do not mutate the retrieved value row as it will unexpectedly affect the previous
- * versions of the store data.
- */
- override def update(key: UnsafeRow, updateFunc: Option[UnsafeRow] => UnsafeRow): Unit = {
- verify(state == UPDATING, "Cannot update after already committed or cancelled")
- val oldValueOption = Option(mapToUpdate.get(key))
- val value = updateFunc(oldValueOption)
+ override def get(key: UnsafeRow): Option[UnsafeRow] = {
+ Option(mapToUpdate.get(key))
+ }
+
+ override def put(key: UnsafeRow, value: UnsafeRow): Unit = {
+ verify(state == UPDATING, "Cannot put after already committed or cancelled")
+
+ val isNewKey = !mapToUpdate.containsKey(key)
mapToUpdate.put(key, value)
Option(allUpdates.get(key)) match {
@@ -115,8 +114,7 @@ private[state] class HDFSBackedStateStoreProvider(
case None =>
// There was no prior update, so mark this as added or updated according to its presence
// in previous version.
- val update =
- if (oldValueOption.nonEmpty) ValueUpdated(key, value) else ValueAdded(key, value)
+ val update = if (isNewKey) ValueAdded(key, value) else ValueUpdated(key, value)
allUpdates.put(key, update)
}
writeToDeltaFile(tempDeltaFileStream, ValueUpdated(key, value))
@@ -148,7 +146,7 @@ private[state] class HDFSBackedStateStoreProvider(
/** Commit all the updates that have been made to the store, and return the new version. */
override def commit(): Long = {
- verify(state == UPDATING, "Cannot commit again after already committed or cancelled")
+ verify(state == UPDATING, "Cannot commit after already committed or cancelled")
try {
finalizeDeltaFile(tempDeltaFileStream)
@@ -164,8 +162,8 @@ private[state] class HDFSBackedStateStoreProvider(
}
/** Cancel all the updates made on this store. This store will not be usable any more. */
- override def cancel(): Unit = {
- state = CANCELLED
+ override def abort(): Unit = {
+ state = ABORTED
if (tempDeltaFileStream != null) {
tempDeltaFileStream.close()
}
@@ -176,8 +174,8 @@ private[state] class HDFSBackedStateStoreProvider(
}
/**
- * Get an iterator of all the store data. This can be called only after committing the
- * updates.
+ * Get an iterator of all the store data.
+ * This can be called only after committing all the updates made in the current thread.
*/
override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = {
verify(state == COMMITTED, "Cannot get iterator of store data before committing")
@@ -186,7 +184,7 @@ private[state] class HDFSBackedStateStoreProvider(
/**
* Get an iterator of all the updates made to the store in the current version.
- * This can be called only after committing the updates.
+ * This can be called only after committing all the updates made in the current thread.
*/
override def updates(): Iterator[StoreUpdate] = {
verify(state == COMMITTED, "Cannot get iterator of updates before committing")
@@ -196,7 +194,7 @@ private[state] class HDFSBackedStateStoreProvider(
/**
* Whether all updates have been committed
*/
- override def hasCommitted: Boolean = {
+ override private[state] def hasCommitted: Boolean = {
state == COMMITTED
}
}
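The hunks above replace the callback-style update with explicit get/put and rename the CANCELLED state to ABORTED. To make the life cycle that the verify calls enforce easier to follow, here is a minimal, self-contained sketch of that UPDATING / COMMITTED / ABORTED state machine. It is illustrative only, not the provider's implementation: plain Strings stand in for UnsafeRow, and all delta-file and HDFS handling is omitted.

    // Illustrative toy only, not HDFSBackedStateStoreProvider: Strings replace UnsafeRow
    // and no delta files are written; it just shows which calls are legal in which state.
    import scala.collection.mutable

    class ToyStateStore(previous: Map[String, String]) {
      private sealed trait State
      private case object Updating extends State
      private case object Committed extends State
      private case object Aborted extends State

      private var state: State = Updating
      private val map = mutable.Map(previous.toSeq: _*)

      private def verify(condition: Boolean, msg: String): Unit =
        if (!condition) throw new IllegalStateException(msg)

      def get(key: String): Option[String] = map.get(key)

      def put(key: String, value: String): Unit = {
        verify(state == Updating, "Cannot put after already committed or aborted")
        map.put(key, value)
      }

      def commit(): Unit = {
        verify(state == Updating, "Cannot commit after already committed or aborted")
        state = Committed   // the real provider also finalizes its delta file here
      }

      def abort(): Unit = {
        state = Aborted     // the store is no longer usable afterwards
      }

      def iterator(): Iterator[(String, String)] = {
        verify(state == Committed, "Cannot get iterator of store data before committing")
        map.iterator
      }
    }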
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index ca5c864d9e..d60e6185ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -47,12 +47,11 @@ trait StateStore {
/** Version of the data in this store before committing updates. */
def version: Long
- /**
- * Update the value of a key using the value generated by the update function.
- * @note Do not mutate the retrieved value row as it will unexpectedly affect the previous
- * versions of the store data.
- */
- def update(key: UnsafeRow, updateFunc: Option[UnsafeRow] => UnsafeRow): Unit
+ /** Get the current value of a key. */
+ def get(key: UnsafeRow): Option[UnsafeRow]
+
+ /** Put a new value for a key. */
+ def put(key: UnsafeRow, value: UnsafeRow)
/**
* Remove keys that match the following condition.
@@ -65,24 +64,24 @@ trait StateStore {
def commit(): Long
/** Abort all the updates that have been made to the store. */
- def cancel(): Unit
+ def abort(): Unit
/**
* Iterator of store data after a set of updates have been committed.
- * This can be called only after commitUpdates() has been called in the current thread.
+ * This can be called only after committing all the updates made in the current thread.
*/
def iterator(): Iterator[(UnsafeRow, UnsafeRow)]
/**
* Iterator of the updates that have been committed.
- * This can be called only after commitUpdates() has been called in the current thread.
+ * This can be called only after committing all the updates made in the current thread.
*/
def updates(): Iterator[StoreUpdate]
/**
* Whether all updates have been committed
*/
- def hasCommitted: Boolean
+ private[state] def hasCommitted: Boolean
}
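With the trait reduced to get and put, callers perform the read-modify-write themselves. The following is a hedged sketch of a counting update against the new interface; the stringToRow/intToRow/rowToInt converters are assumed helpers (UnsafeRow projections like the ones used in the test suites later in this patch), not part of the trait.

    package org.apache.spark.sql.execution.streaming.state

    import org.apache.spark.sql.catalyst.expressions.UnsafeRow

    // Hedged sketch: a read-modify-write counter built from the new get/put primitives.
    // The three row converters are assumed to be supplied by the caller.
    object StateStoreGetPutSketch {
      def incrementCounts(
          store: StateStore,
          keys: Iterator[String],
          stringToRow: String => UnsafeRow,
          intToRow: Int => UnsafeRow,
          rowToInt: UnsafeRow => Int): Long = {
        keys.foreach { s =>
          val key = stringToRow(s)
          val current = store.get(key).map(rowToInt).getOrElse(0)  // read the previous count, if any
          store.put(key, intToRow(current + 1))                    // write the incremented count
        }
        store.commit()  // returns the new version of the store
      }
    }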
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
index cca22a0af8..f0f1f3a1a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming.state
import org.apache.spark.sql.internal.SQLConf
/** A class that contains configuration parameters for [[StateStore]]s. */
-private[state] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable {
+private[streaming] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable {
def this() = this(new SQLConf)
@@ -31,7 +31,7 @@ private[state] class StateStoreConf(@transient private val conf: SQLConf) extend
val minVersionsToRetain = conf.getConf(STATE_STORE_MIN_VERSIONS_TO_RETAIN)
}
-private[state] object StateStoreConf {
+private[streaming] object StateStoreConf {
val empty = new StateStoreConf()
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
index 3318660895..df3d82c113 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
@@ -54,17 +54,10 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = {
var store: StateStore = null
-
- Utils.tryWithSafeFinally {
- val storeId = StateStoreId(checkpointLocation, operatorId, partition.index)
- store = StateStore.get(
- storeId, keySchema, valueSchema, storeVersion, storeConf, confBroadcast.value.value)
- val inputIter = dataRDD.iterator(partition, ctxt)
- val outputIter = storeUpdateFunction(store, inputIter)
- assert(store.hasCommitted)
- outputIter
- } {
- if (store != null) store.cancel()
- }
+ val storeId = StateStoreId(checkpointLocation, operatorId, partition.index)
+ store = StateStore.get(
+ storeId, keySchema, valueSchema, storeVersion, storeConf, confBroadcast.value.value)
+ val inputIter = dataRDD.iterator(partition, ctxt)
+ storeUpdateFunction(store, inputIter)
}
}
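Note that compute no longer wraps the update function in tryWithSafeFinally, so committing (or aborting) the store becomes the update function's responsibility, and the commit must happen only after the returned iterator has been fully consumed. One way to arrange that, mirrored by the "usage with iterators" test added later in this patch, is a CompletionIterator. A hedged sketch, again assuming the row converters exist:

    package org.apache.spark.sql.execution.streaming.state

    import org.apache.spark.sql.catalyst.expressions.UnsafeRow
    import org.apache.spark.util.CompletionIterator

    // Hedged sketch: commit only once the returned iterator has been exhausted, which is
    // now the update function's job rather than StateStoreRDD.compute's.
    // CompletionIterator is Spark-internal, so the sketch lives inside the Spark tree.
    object CommitOnExhaustionSketch {
      def countingUpdateFunction(
          stringToRow: String => UnsafeRow,
          intToRow: Int => UnsafeRow,
          rowToInt: UnsafeRow => Int)(
          store: StateStore,
          input: Iterator[String]): Iterator[(String, Int)] = {
        val updated = input.map { s =>
          val key = stringToRow(s)
          val newValue = store.get(key).map(rowToInt).getOrElse(0) + 1
          store.put(key, intToRow(newValue))
          (s, newValue)
        }
        // The completion function runs after the last element has been produced.
        CompletionIterator[(String, Int), Iterator[(String, Int)]](updated, store.commit())
      }
    }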
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
index b249e37921..9b6d0918e2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
@@ -28,37 +28,36 @@ package object state {
implicit class StateStoreOps[T: ClassTag](dataRDD: RDD[T]) {
/** Map each partition of an RDD along with data in a [[StateStore]]. */
- def mapPartitionWithStateStore[U: ClassTag](
- storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U],
+ def mapPartitionsWithStateStore[U: ClassTag](
+ sqlContext: SQLContext,
checkpointLocation: String,
operatorId: Long,
storeVersion: Long,
keySchema: StructType,
- valueSchema: StructType
- )(implicit sqlContext: SQLContext): StateStoreRDD[T, U] = {
+ valueSchema: StructType)(
+ storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = {
- mapPartitionWithStateStore(
- storeUpdateFunction,
+ mapPartitionsWithStateStore(
checkpointLocation,
operatorId,
storeVersion,
keySchema,
valueSchema,
new StateStoreConf(sqlContext.conf),
- Some(sqlContext.streams.stateStoreCoordinator))
+ Some(sqlContext.streams.stateStoreCoordinator))(
+ storeUpdateFunction)
}
/** Map each partition of an RDD along with data in a [[StateStore]]. */
- private[state] def mapPartitionWithStateStore[U: ClassTag](
- storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U],
+ private[streaming] def mapPartitionsWithStateStore[U: ClassTag](
checkpointLocation: String,
operatorId: Long,
storeVersion: Long,
keySchema: StructType,
valueSchema: StructType,
storeConf: StateStoreConf,
- storeCoordinator: Option[StateStoreCoordinatorRef]
- ): StateStoreRDD[T, U] = {
+ storeCoordinator: Option[StateStoreCoordinatorRef])(
+ storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = {
val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction)
new StateStoreRDD(
dataRDD,
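The helper is renamed to mapPartitionsWithStateStore, takes the SQLContext explicitly, and moves the update function into a trailing curried parameter so call sites can pass it last or as a block. A hedged sketch of such a call site, assuming the caller supplies the schemas and the update function:

    package org.apache.spark.sql.execution.streaming.state

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.types.StructType

    // Hedged sketch of a call site for the renamed helper: the update function is now the
    // trailing, curried argument. Schemas and the update function are assumed to be
    // supplied by the caller, as in StateStoreRDDSuite below.
    object MapPartitionsWithStateStoreSketch {
      def countsPerKey(
          sqlContext: SQLContext,
          words: RDD[String],
          checkpointLocation: String,
          keySchema: StructType,
          valueSchema: StructType,
          update: (StateStore, Iterator[String]) => Iterator[(String, Int)]): RDD[(String, Int)] = {
        words.mapPartitionsWithStateStore(
          sqlContext, checkpointLocation, operatorId = 0, storeVersion = 0,
          keySchema, valueSchema)(update)
      }
    }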
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index 0d580703f5..4b3091ba22 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -17,12 +17,12 @@
package org.apache.spark.sql.execution
+import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.{expressions, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{ExprId, Literal, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.internal.SessionState
import org.apache.spark.sql.types.DataType
/**
@@ -60,14 +60,13 @@ case class ScalarSubquery(
}
/**
- * Convert the subquery from logical plan into executed plan.
+ * Plans scalar subqueries that are present in the given [[SparkPlan]].
*/
-case class PlanSubqueries(sessionState: SessionState) extends Rule[SparkPlan] {
+case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] {
def apply(plan: SparkPlan): SparkPlan = {
plan.transformAllExpressions {
case subquery: expressions.ScalarSubquery =>
- val sparkPlan = sessionState.planner.plan(ReturnAnswer(subquery.query)).next()
- val executedPlan = sessionState.prepareForExecution.execute(sparkPlan)
+ val executedPlan = new QueryExecution(sqlContext, subquery.plan).executedPlan
ScalarSubquery(executedPlan, subquery.exprId)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 844f3051fa..9cb356f1ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -84,10 +84,10 @@ abstract class Aggregator[-I, B, O] extends Serializable {
implicit bEncoder: Encoder[B],
cEncoder: Encoder[O]): TypedColumn[I, O] = {
val expr =
- new AggregateExpression(
+ AggregateExpression(
TypedAggregateExpression(this),
Complete,
- false)
+ isDistinct = false)
new TypedColumn[I, O](expr, encoderFor[O])
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index f7fdfacd31..cd3d254d1e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -86,20 +86,8 @@ private[sql] class SessionState(ctx: SQLContext) {
/**
* Planner that converts optimized logical plans to physical plans.
*/
- lazy val planner: SparkPlanner = new SparkPlanner(ctx.sparkContext, conf, experimentalMethods)
-
- /**
- * Prepares a planned [[SparkPlan]] for execution by inserting shuffle operations and internal
- * row format conversions as needed.
- */
- lazy val prepareForExecution = new RuleExecutor[SparkPlan] {
- override val batches: Seq[Batch] = Seq(
- Batch("Subquery", Once, PlanSubqueries(SessionState.this)),
- Batch("Add exchange", Once, EnsureRequirements(conf)),
- Batch("Whole stage codegen", Once, CollapseCodegenStages(conf)),
- Batch("Reuse duplicated exchanges", Once, ReuseExchange(conf))
- )
- }
+ def planner: SparkPlanner =
+ new SparkPlanner(ctx.sparkContext, conf, experimentalMethods.extraStrategies)
/**
* An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s
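Every batch deleted above ran Once, so an equivalent preparation step reduces to a single left fold of the same rules (PlanSubqueries, EnsureRequirements, CollapseCodegenStages, ReuseExchange) over the physical plan. A hedged sketch of that fold, deliberately generic rather than a copy of the QueryExecution code referenced elsewhere in this patch:

    package org.apache.spark.sql.execution

    import org.apache.spark.sql.catalyst.rules.Rule

    // Hedged sketch: apply a sequence of SparkPlan preparation rules once each, in order.
    object PrepareForExecutionSketch {
      def prepare(plan: SparkPlan, preparations: Seq[Rule[SparkPlan]]): SparkPlan =
        preparations.foldLeft(plan)((current, rule) => rule(current))
    }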
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
index b5be7ef47e..550c3c6f9c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
@@ -116,15 +116,30 @@ trait StreamTest extends QueryTest with Timeouts {
def apply[A : Encoder](data: A*): CheckAnswerRows = {
val encoder = encoderFor[A]
val toExternalRow = RowEncoder(encoder.schema)
- CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))))
+ CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))), false)
}
- def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows)
+ def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, false)
}
- case class CheckAnswerRows(expectedAnswer: Seq[Row])
+ /**
+ * Checks to make sure that the most recent batch of data written to the sink matches the `expectedAnswer`.
+ * This operation automatically blocks until all added data has been processed.
+ */
+ object CheckLastBatch {
+ def apply[A : Encoder](data: A*): CheckAnswerRows = {
+ val encoder = encoderFor[A]
+ val toExternalRow = RowEncoder(encoder.schema)
+ CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))), true)
+ }
+
+ def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, true)
+ }
+
+ case class CheckAnswerRows(expectedAnswer: Seq[Row], lastOnly: Boolean)
extends StreamAction with StreamMustBeRunning {
- override def toString: String = s"CheckAnswer: ${expectedAnswer.mkString(",")}"
+ override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}"
+ private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer"
}
/** Stops the stream. It must currently be running. */
@@ -224,11 +239,8 @@ trait StreamTest extends QueryTest with Timeouts {
""".stripMargin
def verify(condition: => Boolean, message: String): Unit = {
- try {
- Assertions.assert(condition)
- } catch {
- case NonFatal(e) =>
- failTest(message, e)
+ if (!condition) {
+ failTest(message)
}
}
@@ -351,7 +363,7 @@ trait StreamTest extends QueryTest with Timeouts {
case a: AddData =>
awaiting.put(a.source, a.addData())
- case CheckAnswerRows(expectedAnswer) =>
+ case CheckAnswerRows(expectedAnswer, lastOnly) =>
verify(currentStream != null, "stream not running")
// Block until all data added has been processed
@@ -361,12 +373,12 @@ trait StreamTest extends QueryTest with Timeouts {
}
}
- val allData = try sink.allData catch {
+ val sparkAnswer = try if (lastOnly) sink.lastBatch else sink.allData catch {
case e: Exception =>
failTest("Exception while getting data from sink", e)
}
- QueryTest.sameRows(expectedAnswer, allData).foreach {
+ QueryTest.sameRows(expectedAnswer, sparkAnswer).foreach {
error => failTest(error)
}
}
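CheckLastBatch reuses CheckAnswerRows with lastOnly = true, so the harness compares against sink.lastBatch rather than sink.allData. A hedged usage sketch in a hypothetical suite that mixes in StreamTest with SharedSQLContext, like the new suite added later in this patch:

    package org.apache.spark.sql.streaming

    import org.apache.spark.sql.StreamTest
    import org.apache.spark.sql.execution.streaming.MemoryStream
    import org.apache.spark.sql.test.SharedSQLContext

    // Hedged usage sketch: CheckAnswer compares against everything the sink has received
    // so far, while CheckLastBatch compares only the rows from the most recent batch.
    class CheckLastBatchSketchSuite extends StreamTest with SharedSQLContext {
      import testImplicits._

      test("CheckAnswer vs CheckLastBatch") {
        val inputData = MemoryStream[Int]
        val mapped = inputData.toDS().map(_ + 1)

        testStream(mapped)(
          AddData(inputData, 1, 2),
          CheckAnswer(2, 3),    // all rows emitted so far
          AddData(inputData, 3),
          CheckLastBatch(4)     // only the rows from the latest batch
        )
      }
    }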
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index ed0d3f56e5..38318740a5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -231,10 +231,8 @@ object SparkPlanTest {
}
private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = {
- // A very simple resolver to make writing tests easier. In contrast to the real resolver
- // this is always case sensitive and does not try to handle scoping or complex type resolution.
- val resolvedPlan = sqlContext.sessionState.prepareForExecution.execute(
- outputPlan transform {
+ val execution = new QueryExecution(sqlContext, null) {
+ override lazy val sparkPlan: SparkPlan = outputPlan transform {
case plan: SparkPlan =>
val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
plan transformExpressions {
@@ -243,8 +241,8 @@ object SparkPlanTest {
sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
}
}
- )
- resolvedPlan.executeCollectPublic().toSeq
+ }
+ execution.executedPlan.executeCollectPublic().toSeq
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
index 85db05157c..6be94eb24f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
@@ -33,7 +33,7 @@ import org.apache.spark.scheduler.ExecutorCacheTaskLocation
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.util.quietly
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{CompletionIterator, Utils}
class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll {
@@ -54,62 +54,93 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
}
test("versioning and immutability") {
- quietly {
- withSpark(new SparkContext(sparkConf)) { sc =>
- implicit val sqlContet = new SQLContext(sc)
- val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
- val increment = (store: StateStore, iter: Iterator[String]) => {
- iter.foreach { s =>
- store.update(
- stringToRow(s), oldRow => {
- val oldValue = oldRow.map(rowToInt).getOrElse(0)
- intToRow(oldValue + 1)
- })
- }
- store.commit()
- store.iterator().map(rowsToStringInt)
- }
- val opId = 0
- val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore(
- increment, path, opId, storeVersion = 0, keySchema, valueSchema)
- assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+ withSpark(new SparkContext(sparkConf)) { sc =>
+ val sqlContext = new SQLContext(sc)
+ val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
+ val opId = 0
+ val rdd1 =
+ makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore(
+ sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(
+ increment)
+ assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+
+ // Generate next version of stores
+ val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionsWithStateStore(
+ sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment)
+ assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
+
+ // Make sure the previous RDD still has the same data.
+ assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+ }
+ }
- // Generate next version of stores
- val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore(
- increment, path, opId, storeVersion = 1, keySchema, valueSchema)
- assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
+ test("recovering from files") {
+ val opId = 0
+ val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
+
+ def makeStoreRDD(
+ sc: SparkContext,
+ seq: Seq[String],
+ storeVersion: Int): RDD[(String, Int)] = {
+ implicit val sqlContext = new SQLContext(sc)
+ makeRDD(sc, Seq("a")).mapPartitionsWithStateStore(
+ sqlContext, path, opId, storeVersion, keySchema, valueSchema)(increment)
+ }
- // Make sure the previous RDD still has the same data.
- assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+ // Generate RDDs and state store data
+ withSpark(new SparkContext(sparkConf)) { sc =>
+ for (i <- 1 to 20) {
+ require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i))
}
}
+
+ // With a new context, try using the earlier state store data
+ withSpark(new SparkContext(sparkConf)) { sc =>
+ assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21))
+ }
}
- test("recovering from files") {
- quietly {
- val opId = 0
+ test("usage with iterators - only gets and only puts") {
+ withSpark(new SparkContext(sparkConf)) { sc =>
+ implicit val sqlContext = new SQLContext(sc)
val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
+ val opId = 0
- def makeStoreRDD(
- sc: SparkContext,
- seq: Seq[String],
- storeVersion: Int): RDD[(String, Int)] = {
- implicit val sqlContext = new SQLContext(sc)
- makeRDD(sc, Seq("a")).mapPartitionWithStateStore(
- increment, path, opId, storeVersion, keySchema, valueSchema)
+ // Returns an iterator of the incremented values that were put into the store
+ def iteratorOfPuts(store: StateStore, iter: Iterator[String]): Iterator[(String, Int)] = {
+ val resIterator = iter.map { s =>
+ val key = stringToRow(s)
+ val oldValue = store.get(key).map(rowToInt).getOrElse(0)
+ val newValue = oldValue + 1
+ store.put(key, intToRow(newValue))
+ (s, newValue)
+ }
+ CompletionIterator[(String, Int), Iterator[(String, Int)]](resIterator, {
+ store.commit()
+ })
}
- // Generate RDDs and state store data
- withSpark(new SparkContext(sparkConf)) { sc =>
- for (i <- 1 to 20) {
- require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i))
+ def iteratorOfGets(
+ store: StateStore,
+ iter: Iterator[String]): Iterator[(String, Option[Int])] = {
+ iter.map { s =>
+ val key = stringToRow(s)
+ val value = store.get(key).map(rowToInt)
+ (s, value)
}
}
- // With a new context, try using the earlier state store data
- withSpark(new SparkContext(sparkConf)) { sc =>
- assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21))
- }
+ val rddOfGets1 = makeRDD(sc, Seq("a", "b", "c")).mapPartitionsWithStateStore(
+ sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets)
+ assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" -> None))
+
+ val rddOfPuts = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore(
+ sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfPuts)
+ assert(rddOfPuts.collect().toSet === Set("a" -> 1, "a" -> 2, "b" -> 1))
+
+ val rddOfGets2 = makeRDD(sc, Seq("a", "b", "c")).mapPartitionsWithStateStore(
+ sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(iteratorOfGets)
+ assert(rddOfGets2.collect().toSet === Set("a" -> Some(2), "b" -> Some(1), "c" -> None))
}
}
@@ -128,8 +159,8 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
coordinatorRef.getLocation(StateStoreId(path, opId, 0)) ===
Some(ExecutorCacheTaskLocation("host1", "exec1").toString))
- val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore(
- increment, path, opId, storeVersion = 0, keySchema, valueSchema)
+ val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore(
+ sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment)
require(rdd.partitions.length === 2)
assert(
@@ -148,27 +179,16 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
test("distributed test") {
quietly {
withSpark(new SparkContext(sparkConf.setMaster("local-cluster[2, 1, 1024]"))) { sc =>
- implicit val sqlContet = new SQLContext(sc)
+ implicit val sqlContext = new SQLContext(sc)
val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
- val increment = (store: StateStore, iter: Iterator[String]) => {
- iter.foreach { s =>
- store.update(
- stringToRow(s), oldRow => {
- val oldValue = oldRow.map(rowToInt).getOrElse(0)
- intToRow(oldValue + 1)
- })
- }
- store.commit()
- store.iterator().map(rowsToStringInt)
- }
val opId = 0
- val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore(
- increment, path, opId, storeVersion = 0, keySchema, valueSchema)
+ val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore(
+ sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment)
assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
// Generate next version of stores
- val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore(
- increment, path, opId, storeVersion = 1, keySchema, valueSchema)
+ val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionsWithStateStore(
+ sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment)
assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
// Make sure the previous RDD still has the same data.
@@ -183,11 +203,9 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
private val increment = (store: StateStore, iter: Iterator[String]) => {
iter.foreach { s =>
- store.update(
- stringToRow(s), oldRow => {
- val oldValue = oldRow.map(rowToInt).getOrElse(0)
- intToRow(oldValue + 1)
- })
+ val key = stringToRow(s)
+ val oldValue = store.get(key).map(rowToInt).getOrElse(0)
+ store.put(key, intToRow(oldValue + 1))
}
store.commit()
store.iterator().map(rowsToStringInt)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index 22b2f4f75d..0e5936d53f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -51,7 +51,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
StateStore.stop()
}
- test("update, remove, commit, and all data iterator") {
+ test("get, put, remove, commit, and all data iterator") {
val provider = newStoreProvider()
// Verify state before starting a new set of updates
@@ -67,7 +67,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
}
// Verify state after updating
- update(store, "a", 1)
+ put(store, "a", 1)
intercept[IllegalStateException] {
store.iterator()
}
@@ -77,8 +77,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
assert(provider.latestIterator().isEmpty)
// Make updates, commit and then verify state
- update(store, "b", 2)
- update(store, "aa", 3)
+ put(store, "b", 2)
+ put(store, "aa", 3)
remove(store, _.startsWith("a"))
assert(store.commit() === 1)
@@ -101,7 +101,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
val reloadedProvider = new HDFSBackedStateStoreProvider(
store.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration)
val reloadedStore = reloadedProvider.getStore(1)
- update(reloadedStore, "c", 4)
+ put(reloadedStore, "c", 4)
assert(reloadedStore.commit() === 2)
assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4))
assert(getDataFromFiles(provider) === Set("b" -> 2, "c" -> 4))
@@ -112,6 +112,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
test("updates iterator with all combos of updates and removes") {
val provider = newStoreProvider()
var currentVersion: Int = 0
+
def withStore(body: StateStore => Unit): Unit = {
val store = provider.getStore(currentVersion)
body(store)
@@ -120,9 +121,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
// New data should be seen in updates as value added, even if they had multiple updates
withStore { store =>
- update(store, "a", 1)
- update(store, "aa", 1)
- update(store, "aa", 2)
+ put(store, "a", 1)
+ put(store, "aa", 1)
+ put(store, "aa", 2)
store.commit()
assert(updatesToSet(store.updates()) === Set(Added("a", 1), Added("aa", 2)))
assert(rowsToSet(store.iterator()) === Set("a" -> 1, "aa" -> 2))
@@ -131,8 +132,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
// Multiple updates to same key should be collapsed in the updates as a single value update
// Keys that have not been updated should not appear in the updates
withStore { store =>
- update(store, "a", 4)
- update(store, "a", 6)
+ put(store, "a", 4)
+ put(store, "a", 6)
store.commit()
assert(updatesToSet(store.updates()) === Set(Updated("a", 6)))
assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2))
@@ -140,9 +141,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
// Keys added, updated and finally removed before commit should not appear in updates
withStore { store =>
- update(store, "b", 4) // Added, finally removed
- update(store, "bb", 5) // Added, updated, finally removed
- update(store, "bb", 6)
+ put(store, "b", 4) // Added, finally removed
+ put(store, "bb", 5) // Added, updated, finally removed
+ put(store, "bb", 6)
remove(store, _.startsWith("b"))
store.commit()
assert(updatesToSet(store.updates()) === Set.empty)
@@ -153,7 +154,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
// Removed, but re-added data should be seen in updates as a value update
withStore { store =>
remove(store, _.startsWith("a"))
- update(store, "a", 10)
+ put(store, "a", 10)
store.commit()
assert(updatesToSet(store.updates()) === Set(Updated("a", 10), Removed("aa")))
assert(rowsToSet(store.iterator()) === Set("a" -> 10))
@@ -163,14 +164,14 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
test("cancel") {
val provider = newStoreProvider()
val store = provider.getStore(0)
- update(store, "a", 1)
+ put(store, "a", 1)
store.commit()
assert(rowsToSet(store.iterator()) === Set("a" -> 1))
// aborting the updates should not change the data in the files
val store1 = provider.getStore(1)
- update(store1, "b", 1)
- store1.cancel()
+ put(store1, "b", 1)
+ store1.abort()
assert(getDataFromFiles(provider) === Set("a" -> 1))
}
@@ -183,7 +184,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
// Prepare some data in the store
val store = provider.getStore(0)
- update(store, "a", 1)
+ put(store, "a", 1)
assert(store.commit() === 1)
assert(rowsToSet(store.iterator()) === Set("a" -> 1))
@@ -193,14 +194,14 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
// Update store version with some data
val store1 = provider.getStore(1)
- update(store1, "b", 1)
+ put(store1, "b", 1)
assert(store1.commit() === 2)
assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1))
assert(getDataFromFiles(provider) === Set("a" -> 1, "b" -> 1))
// Overwrite the version with other data
val store2 = provider.getStore(1)
- update(store2, "c", 1)
+ put(store2, "c", 1)
assert(store2.commit() === 2)
assert(rowsToSet(store2.iterator()) === Set("a" -> 1, "c" -> 1))
assert(getDataFromFiles(provider) === Set("a" -> 1, "c" -> 1))
@@ -213,7 +214,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
def updateVersionTo(targetVersion: Int): Unit = {
for (i <- currentVersion + 1 to targetVersion) {
val store = provider.getStore(currentVersion)
- update(store, "a", i)
+ put(store, "a", i)
store.commit()
currentVersion += 1
}
@@ -264,7 +265,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
for (i <- 1 to 20) {
val store = provider.getStore(i - 1)
- update(store, "a", i)
+ put(store, "a", i)
store.commit()
provider.doMaintenance() // do cleanup
}
@@ -284,7 +285,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
val provider = newStoreProvider(minDeltasForSnapshot = 5)
for (i <- 1 to 6) {
val store = provider.getStore(i - 1)
- update(store, "a", i)
+ put(store, "a", i)
store.commit()
provider.doMaintenance() // do cleanup
}
@@ -333,7 +334,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
// Increase version of the store
val store0 = StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf)
assert(store0.version === 0)
- update(store0, "a", 1)
+ put(store0, "a", 1)
store0.commit()
assert(StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf).version == 1)
@@ -345,7 +346,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
val store1 = StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf)
assert(StateStore.isLoaded(storeId))
- update(store1, "a", 2)
+ put(store1, "a", 2)
assert(store1.commit() === 2)
assert(rowsToSet(store1.iterator()) === Set("a" -> 2))
}
@@ -371,7 +372,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
for (i <- 1 to 20) {
val store = StateStore.get(
storeId, keySchema, valueSchema, i - 1, storeConf, hadoopConf)
- update(store, "a", i)
+ put(store, "a", i)
store.commit()
}
eventually(timeout(10 seconds)) {
@@ -507,8 +508,12 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
store.remove(row => condition(rowToString(row)))
}
- private def update(store: StateStore, key: String, value: Int): Unit = {
- store.update(stringToRow(key), _ => intToRow(value))
+ private def put(store: StateStore, key: String, value: Int): Unit = {
+ store.put(stringToRow(key), intToRow(value))
+ }
+
+ private def get(store: StateStore, key: String): Option[Int] = {
+ store.get(stringToRow(key)).map(rowToInt)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
new file mode 100644
index 0000000000..b63ce89d18
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.{Encoder, StreamTest, SumOf, TypedColumn}
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
+
+object FailureSinglton {
+ var firstTime = true
+}
+
+class StreamingAggregationSuite extends StreamTest with SharedSQLContext {
+
+ import testImplicits._
+
+ test("simple count") {
+ val inputData = MemoryStream[Int]
+
+ val aggregated =
+ inputData.toDF()
+ .groupBy($"value")
+ .agg(count("*"))
+ .as[(Int, Long)]
+
+ testStream(aggregated)(
+ AddData(inputData, 3),
+ CheckLastBatch((3, 1)),
+ AddData(inputData, 3, 2),
+ CheckLastBatch((3, 2), (2, 1)),
+ StopStream,
+ StartStream,
+ AddData(inputData, 3, 2, 1),
+ CheckLastBatch((3, 3), (2, 2), (1, 1)),
+ // By default we run in new tuple mode.
+ AddData(inputData, 4, 4, 4, 4),
+ CheckLastBatch((4, 4))
+ )
+ }
+
+ test("multiple keys") {
+ val inputData = MemoryStream[Int]
+
+ val aggregated =
+ inputData.toDF()
+ .groupBy($"value", $"value" + 1)
+ .agg(count("*"))
+ .as[(Int, Int, Long)]
+
+ testStream(aggregated)(
+ AddData(inputData, 1, 2),
+ CheckLastBatch((1, 2, 1), (2, 3, 1)),
+ AddData(inputData, 1, 2),
+ CheckLastBatch((1, 2, 2), (2, 3, 2))
+ )
+ }
+
+ test("multiple aggregations") {
+ val inputData = MemoryStream[Int]
+
+ val aggregated =
+ inputData.toDF()
+ .groupBy($"value")
+ .agg(count("*") as 'count)
+ .groupBy($"value" % 2)
+ .agg(sum($"count"))
+ .as[(Int, Long)]
+
+ testStream(aggregated)(
+ AddData(inputData, 1, 2, 3, 4),
+ CheckLastBatch((0, 2), (1, 2)),
+ AddData(inputData, 1, 3, 5),
+ CheckLastBatch((1, 5))
+ )
+ }
+
+ testQuietly("midbatch failure") {
+ val inputData = MemoryStream[Int]
+ FailureSinglton.firstTime = true
+ val aggregated =
+ inputData.toDS()
+ .map { i =>
+ if (i == 4 && FailureSinglton.firstTime) {
+ FailureSinglton.firstTime = false
+ sys.error("injected failure")
+ }
+
+ i
+ }
+ .groupBy($"value")
+ .agg(count("*"))
+ .as[(Int, Long)]
+
+ testStream(aggregated)(
+ StartStream,
+ AddData(inputData, 1, 2, 3, 4),
+ ExpectFailure[SparkException](),
+ StartStream,
+ CheckLastBatch((1, 1), (2, 1), (3, 1), (4, 1))
+ )
+ }
+
+ test("typed aggregators") {
+ def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] =
+ new SumOf(f).toColumn
+
+ val inputData = MemoryStream[(String, Int)]
+ val aggregated = inputData.toDS().groupByKey(_._1).agg(sum(_._2))
+
+ testStream(aggregated)(
+ AddData(inputData, ("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)),
+ CheckLastBatch(("a", 30), ("b", 3), ("c", 1))
+ )
+ }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
index 2bdb428e9d..ff40c366c8 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
@@ -77,8 +77,9 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx)
/**
* Planner that takes into account Hive-specific strategies.
*/
- override lazy val planner: SparkPlanner = {
- new SparkPlanner(ctx.sparkContext, conf, experimentalMethods) with HiveStrategies {
+ override def planner: SparkPlanner = {
+ new SparkPlanner(ctx.sparkContext, conf, experimentalMethods.extraStrategies)
+ with HiveStrategies {
override val hiveContext = ctx
override def strategies: Seq[Strategy] = {