33 files changed, 827 insertions, 305 deletions
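The core idea running through this patch is that AggregateExpression now owns a resultId allocated at construction time and exposes a resultAttribute derived from it, so physical planning can refer to an aggregate's output directly instead of threading an aggregateFunctionToAttribute map around; canonicalization zeroes the id so semantically identical aggregates still compare equal. The sketch below is a minimal, self-contained Scala model of that pattern with simplified stand-in classes, not the real Catalyst code:

    import java.util.concurrent.atomic.AtomicLong

    // Simplified stand-ins for Catalyst's ExprId and Attribute, for illustration only.
    case class ExprId(id: Long)
    object ExprId {
      private val counter = new AtomicLong(0)
      def newId(): ExprId = ExprId(counter.incrementAndGet())
    }

    case class Attribute(name: String, exprId: ExprId)

    // A toy aggregate function plus the container that now owns a result id.
    case class SumOf(column: String) { override def toString: String = s"sum($column)" }

    case class AggExpr(function: SumOf, isDistinct: Boolean, resultId: ExprId) {
      // The attribute downstream operators use to refer to this aggregate's output,
      // replacing the old (function, isDistinct) -> attribute lookup map.
      lazy val resultAttribute: Attribute = Attribute(function.toString, resultId)
      // Plan comparison should ignore which concrete id happened to be allocated.
      def canonicalized: AggExpr = copy(resultId = ExprId(0))
    }

    object AggExpr {
      // Mirrors the new companion apply in the patch: allocate a fresh result id on construction.
      def apply(function: SumOf, isDistinct: Boolean): AggExpr =
        AggExpr(function, isDistinct, ExprId.newId())
    }

    object ResultIdDemo extends App {
      val a = AggExpr(SumOf("x"), isDistinct = false)
      val b = AggExpr(SumOf("x"), isDistinct = false)
      // Two instances of the same aggregate get distinct result attributes
      // (which is what the pivot rewrite in Analyzer relies on)...
      assert(a.resultAttribute != b.resultAttribute)
      // ...while canonicalization still treats them as the same computation.
      assert(a.canonicalized == b.canonicalized)
      println(a.resultAttribute)
    }

This is also why PlanTest's normalizeExprIds and the new canonicalized override both reset resultId to ExprId(0) before comparing plans.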
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d82ee3a205..05e2b9a447 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -336,6 +336,11 @@ class Analyzer( Last(ifExpr(expr), Literal(true)) case a: AggregateFunction => a.withNewChildren(a.children.map(ifExpr)) + }.transform { + // We are duplicating aggregates that are now computing a different value for each + // pivot value. + // TODO: Don't construct the physical container until after analysis. + case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId) } if (filteredAggregate.fastEquals(aggregate)) { throw new AnalysisException( @@ -1153,11 +1158,11 @@ class Analyzer( // Extract Windowed AggregateExpression case we @ WindowExpression( - AggregateExpression(function, mode, isDistinct), + ae @ AggregateExpression(function, _, _, _), spec: WindowSpecDefinition) => val newChildren = function.children.map(extractExpr) val newFunction = function.withNewChildren(newChildren).asInstanceOf[AggregateFunction] - val newAgg = AggregateExpression(newFunction, mode, isDistinct) + val newAgg = ae.copy(aggregateFunction = newFunction) seenWindowAggregates += newAgg WindowExpression(newAgg, spec) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 1d1e892e32..4880502398 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -76,7 +76,7 @@ trait CheckAnalysis { case g: GroupingID => failAnalysis(s"grouping_id() can only be used with GroupingSets/Cube/Rollup") - case w @ WindowExpression(AggregateExpression(_, _, true), _) => + case w @ WindowExpression(AggregateExpression(_, _, true, _), _) => failAnalysis(s"Distinct window functions are not supported: $w") case w @ WindowExpression(_: OffsetWindowFunction, WindowSpecDefinition(_, order, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala index 0d44d1dd96..0420b4b538 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala @@ -25,15 +25,18 @@ import org.apache.spark.sql.catalyst.trees.TreeNode package object errors { class TreeNodeException[TreeType <: TreeNode[_]]( - tree: TreeType, msg: String, cause: Throwable) + @transient val tree: TreeType, + msg: String, + cause: Throwable) extends Exception(msg, cause) { + val treeString = tree.toString + // Yes, this is the same as a default parameter, but... those don't seem to work with SBT // external project dependencies for some reason. 
def this(tree: TreeType, msg: String) = this(tree, msg, null) override def getMessage: String = { - val treeString = tree.toString s"${super.getMessage}, tree:${if (treeString contains "\n") "\n" else " "}$tree" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index ff3064ac66..d31ccf9985 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.types._ @@ -66,6 +67,19 @@ private[sql] case object NoOp extends Expression with Unevaluable { override def children: Seq[Expression] = Nil } +object AggregateExpression { + def apply( + aggregateFunction: AggregateFunction, + mode: AggregateMode, + isDistinct: Boolean): AggregateExpression = { + AggregateExpression( + aggregateFunction, + mode, + isDistinct, + NamedExpression.newExprId) + } +} + /** * A container for an [[AggregateFunction]] with its [[AggregateMode]] and a field * (`isDistinct`) indicating if DISTINCT keyword is specified for this function. @@ -73,10 +87,31 @@ private[sql] case object NoOp extends Expression with Unevaluable { private[sql] case class AggregateExpression( aggregateFunction: AggregateFunction, mode: AggregateMode, - isDistinct: Boolean) + isDistinct: Boolean, + resultId: ExprId) extends Expression with Unevaluable { + lazy val resultAttribute: Attribute = if (aggregateFunction.resolved) { + AttributeReference( + aggregateFunction.toString, + aggregateFunction.dataType, + aggregateFunction.nullable)(exprId = resultId) + } else { + // This is a bit of a hack. Really we should not be constructing this container and reasoning + // about datatypes / aggregation mode until after we have finished analysis and made it to + // planning. + UnresolvedAttribute(aggregateFunction.toString) + } + + // We compute the same thing regardless of our final result. 
+ override lazy val canonicalized: Expression = + AggregateExpression( + aggregateFunction.canonicalized.asInstanceOf[AggregateFunction], + mode, + isDistinct, + ExprId(0)) + override def children: Seq[Expression] = aggregateFunction :: Nil override def dataType: DataType = aggregateFunction.dataType override def foldable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 262582ca5d..2307122ea1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -329,7 +329,7 @@ case class PrettyAttribute( override def withName(newName: String): Attribute = throw new UnsupportedOperationException override def qualifier: Option[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException - override def nullable: Boolean = throw new UnsupportedOperationException + override def nullable: Boolean = true } object VirtualColumn { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a7a948ef1b..326933ec9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -534,7 +534,7 @@ object NullPropagation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { - case e @ AggregateExpression(Count(exprs), _, _) if !exprs.exists(nonNullLiteral) => + case e @ AggregateExpression(Count(exprs), _, _, _) if !exprs.exists(nonNullLiteral) => Cast(Literal(0L), e.dataType) case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType) case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType) @@ -547,9 +547,9 @@ object NullPropagation extends Rule[LogicalPlan] { Literal.create(null, e.dataType) case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r) case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l) - case e @ AggregateExpression(Count(exprs), mode, false) if !exprs.exists(_.nullable) => + case ae @ AggregateExpression(Count(exprs), _, false, _) if !exprs.exists(_.nullable) => // This rule should be only triggered when isDistinct field is false. - AggregateExpression(Count(Literal(1)), mode, isDistinct = false) + ae.copy(aggregateFunction = Count(Literal(1))) // For Coalesce, remove null literals. 
case e @ Coalesce(children) => @@ -1225,13 +1225,13 @@ object DecimalAggregates extends Rule[LogicalPlan] { private val MAX_DOUBLE_DIGITS = 15 def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), mode, isDistinct) + case ae @ AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), _, _, _) if prec + 10 <= MAX_LONG_DIGITS => - MakeDecimal(AggregateExpression(Sum(UnscaledValue(e)), mode, isDistinct), prec + 10, scale) + MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale) - case AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), mode, isDistinct) + case ae @ AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), _, _, _) if prec + 4 <= MAX_DOUBLE_DIGITS => - val newAggExpr = AggregateExpression(Average(UnscaledValue(e)), mode, isDistinct) + val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e))) Cast( Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)), DecimalType(prec + 4, scale + 4)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 9c927077d0..28d2c445b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -22,6 +22,7 @@ import scala.collection.mutable import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.IntegerType @@ -216,3 +217,75 @@ object IntegerIndex { case _ => None } } + +/** + * An extractor used when planning the physical execution of an aggregation. Compared with a logical + * aggregation, the following transformations are performed: + * - Unnamed grouping expressions are named so that they can be referred to across phases of + * aggregation + * - Aggregations that appear multiple times are deduplicated. + * - The compution of the aggregations themselves is separated from the final result. For example, + * the `count` in `count + 1` will be split into an [[AggregateExpression]] and a final + * computation that computes `count.resultAttribute + 1`. + */ +object PhysicalAggregation { + // groupingExpressions, aggregateExpressions, resultExpressions, child + type ReturnType = + (Seq[NamedExpression], Seq[AggregateExpression], Seq[NamedExpression], LogicalPlan) + + def unapply(a: Any): Option[ReturnType] = a match { + case logical.Aggregate(groupingExpressions, resultExpressions, child) => + // A single aggregate expression might appear multiple times in resultExpressions. + // In order to avoid evaluating an individual aggregate function multiple times, we'll + // build a set of the distinct aggregate expressions and build a function which can + // be used to re-write expressions so that they reference the single copy of the + // aggregate function which actually gets computed. 
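The dedup step described in the comment above can be sketched as a standalone simplification (the tiny expression classes here are hypothetical, not Catalyst's); the real extractor, which continues below, performs the same collect/distinct/rewrite over Catalyst expressions and additionally aliases unnamed grouping expressions:

    object DedupDemo extends App {
      // Hypothetical mini expression tree; stands in for Catalyst expressions.
      case class Agg(function: String, column: String) {
        val resultAttribute: String = s"$function($column)#result"
      }
      sealed trait Expr
      case class AggRef(agg: Agg) extends Expr
      case class Add(left: Expr, right: Expr) extends Expr
      case class Col(name: String) extends Expr

      // SELECT sum(x) + lit1, sum(x) + lit2: the same aggregate appears twice.
      val results: Seq[Expr] = Seq(
        Add(AggRef(Agg("sum", "x")), Col("lit1")),
        Add(AggRef(Agg("sum", "x")), Col("lit2")))

      // 1. Collect the distinct aggregates so each one is computed only once.
      def collectAggs(e: Expr): Seq[Agg] = e match {
        case AggRef(a) => Seq(a)
        case Add(l, r) => collectAggs(l) ++ collectAggs(r)
        case _         => Nil
      }
      val distinctAggs = results.flatMap(collectAggs).distinct
      assert(distinctAggs.size == 1)

      // 2. Rewrite the result expressions so they reference the single computed copy.
      def rewrite(e: Expr): Expr = e match {
        case AggRef(a) => Col(a.resultAttribute)
        case Add(l, r) => Add(rewrite(l), rewrite(r))
        case other     => other
      }
      println(results.map(rewrite))
    }
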
+ val aggregateExpressions = resultExpressions.flatMap { expr => + expr.collect { + case agg: AggregateExpression => agg + } + }.distinct + + val namedGroupingExpressions = groupingExpressions.map { + case ne: NamedExpression => ne -> ne + // If the expression is not a NamedExpressions, we add an alias. + // So, when we generate the result of the operator, the Aggregate Operator + // can directly get the Seq of attributes representing the grouping expressions. + case other => + val withAlias = Alias(other, other.toString)() + other -> withAlias + } + val groupExpressionMap = namedGroupingExpressions.toMap + + // The original `resultExpressions` are a set of expressions which may reference + // aggregate expressions, grouping column values, and constants. When aggregate operator + // emits output rows, we will use `resultExpressions` to generate an output projection + // which takes the grouping columns and final aggregate result buffer as input. + // Thus, we must re-write the result expressions so that their attributes match up with + // the attributes of the final result projection's input row: + val rewrittenResultExpressions = resultExpressions.map { expr => + expr.transformDown { + case ae: AggregateExpression => + // The final aggregation buffer's attributes will be `finalAggregationAttributes`, + // so replace each aggregate expression by its corresponding attribute in the set: + ae.resultAttribute + case expression => + // Since we're using `namedGroupingAttributes` to extract the grouping key + // columns, we need to replace grouping key expressions with their corresponding + // attributes. We do not rely on the equality check at here since attributes may + // differ cosmetically. Instead, we use semanticEquals. + groupExpressionMap.collectFirst { + case (expr, ne) if expr semanticEquals expression => ne.toAttribute + }.getOrElse(expression) + }.asInstanceOf[NamedExpression] + } + + Some(( + namedGroupingExpressions.map(_._2), + aggregateExpressions, + rewrittenResultExpressions, + child)) + + case _ => None + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index aa5d4330d3..7191936699 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Sample} import org.apache.spark.sql.catalyst.util._ @@ -38,6 +39,8 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) case a: Alias => Alias(a.child, a.name)(exprId = ExprId(0)) + case ae: AggregateExpression => + ae.copy(resultId = ExprId(0)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 912b84abc1..4843553211 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -21,6 +21,8 @@ import org.apache.spark.rdd.RDD import 
org.apache.spark.sql.{AnalysisException, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} /** * The primary workflow for executing relational queries using Spark. Designed to allow easy @@ -31,6 +33,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} */ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { + // TODO: Move the planner an optimizer into here from SessionState. + protected def planner = sqlContext.sessionState.planner + def assertAnalyzed(): Unit = try sqlContext.sessionState.analyzer.checkAnalysis(analyzed) catch { case e: AnalysisException => val ae = new AnalysisException(e.message, e.line, e.startPosition, Some(analyzed)) @@ -49,16 +54,31 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { lazy val sparkPlan: SparkPlan = { SQLContext.setActive(sqlContext) - sqlContext.sessionState.planner.plan(ReturnAnswer(optimizedPlan)).next() + planner.plan(ReturnAnswer(optimizedPlan)).next() } // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. - lazy val executedPlan: SparkPlan = sqlContext.sessionState.prepareForExecution.execute(sparkPlan) + lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan) /** Internal version of the RDD. Avoids copies and has no schema */ lazy val toRdd: RDD[InternalRow] = executedPlan.execute() + /** + * Prepares a planned [[SparkPlan]] for execution by inserting shuffle operations and internal + * row format conversions as needed. + */ + protected def prepareForExecution(plan: SparkPlan): SparkPlan = { + preparations.foldLeft(plan) { case (sp, rule) => rule.apply(sp) } + } + + /** A sequence of rules that will be applied in order to the physical plan before execution. 
*/ + protected def preparations: Seq[Rule[SparkPlan]] = Seq( + PlanSubqueries(sqlContext), + EnsureRequirements(sqlContext.conf), + CollapseCodegenStages(sqlContext.conf), + ReuseExchange(sqlContext.conf)) + protected def stringOrError[A](f: => A): String = try f.toString catch { case e: Throwable => e.toString } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 010ed7f500..b1b3d4ac81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -379,6 +379,13 @@ private[sql] trait LeafNode extends SparkPlan { override def producedAttributes: AttributeSet = outputSet } +object UnaryNode { + def unapply(a: Any): Option[(SparkPlan, SparkPlan)] = a match { + case s: SparkPlan if s.children.size == 1 => Some((s, s.children.head)) + case _ => None + } +} + private[sql] trait UnaryNode extends SparkPlan { def child: SparkPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 9da2c74c62..ac8072f3ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -26,13 +26,13 @@ import org.apache.spark.sql.internal.SQLConf class SparkPlanner( val sparkContext: SparkContext, val conf: SQLConf, - val experimentalMethods: ExperimentalMethods) + val extraStrategies: Seq[Strategy]) extends SparkStrategies { def numPartitions: Int = conf.numShufflePartitions def strategies: Seq[Strategy] = - experimentalMethods.extraStrategies ++ ( + extraStrategies ++ ( FileSourceStrategy :: DataSourceStrategy :: DDLStrategy :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 7a2e2b7382..5bcc172ca7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Strategy import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} @@ -204,28 +203,32 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } /** + * Used to plan aggregation queries that are computed incrementally as part of a + * [[org.apache.spark.sql.ContinuousQuery]]. 
Currently this rule is injected into the planner + * on-demand, only when planning in a [[org.apache.spark.sql.execution.streaming.StreamExecution]] + */ + object StatefulAggregationStrategy extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case PhysicalAggregation( + namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) => + + aggregate.Utils.planStreamingAggregation( + namedGroupingExpressions, + aggregateExpressions, + rewrittenResultExpressions, + planLater(child)) + + case _ => Nil + } + } + + /** * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface. */ object Aggregation extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.Aggregate(groupingExpressions, resultExpressions, child) => - // A single aggregate expression might appear multiple times in resultExpressions. - // In order to avoid evaluating an individual aggregate function multiple times, we'll - // build a set of the distinct aggregate expressions and build a function which can - // be used to re-write expressions so that they reference the single copy of the - // aggregate function which actually gets computed. - val aggregateExpressions = resultExpressions.flatMap { expr => - expr.collect { - case agg: AggregateExpression => agg - } - }.distinct - // For those distinct aggregate expressions, we create a map from the - // aggregate function to the corresponding attribute of the function. - val aggregateFunctionToAttribute = aggregateExpressions.map { agg => - val aggregateFunction = agg.aggregateFunction - val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute - (aggregateFunction, agg.isDistinct) -> attribute - }.toMap + case PhysicalAggregation( + groupingExpressions, aggregateExpressions, resultExpressions, child) => val (functionsWithDistinct, functionsWithoutDistinct) = aggregateExpressions.partition(_.isDistinct) @@ -233,41 +236,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // This is a sanity check. We should not reach here when we have multiple distinct // column sets. Our MultipleDistinctRewriter should take care this case. sys.error("You hit a query analyzer bug. Please report your query to " + - "Spark user mailing list.") - } - - val namedGroupingExpressions = groupingExpressions.map { - case ne: NamedExpression => ne -> ne - // If the expression is not a NamedExpressions, we add an alias. - // So, when we generate the result of the operator, the Aggregate Operator - // can directly get the Seq of attributes representing the grouping expressions. - case other => - val withAlias = Alias(other, other.toString)() - other -> withAlias - } - val groupExpressionMap = namedGroupingExpressions.toMap - - // The original `resultExpressions` are a set of expressions which may reference - // aggregate expressions, grouping column values, and constants. When aggregate operator - // emits output rows, we will use `resultExpressions` to generate an output projection - // which takes the grouping columns and final aggregate result buffer as input. 
- // Thus, we must re-write the result expressions so that their attributes match up with - // the attributes of the final result projection's input row: - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transformDown { - case AggregateExpression(aggregateFunction, _, isDistinct) => - // The final aggregation buffer's attributes will be `finalAggregationAttributes`, - // so replace each aggregate expression by its corresponding attribute in the set: - aggregateFunctionToAttribute(aggregateFunction, isDistinct) - case expression => - // Since we're using `namedGroupingAttributes` to extract the grouping key - // columns, we need to replace grouping key expressions with their corresponding - // attributes. We do not rely on the equality check at here since attributes may - // differ cosmetically. Instead, we use semanticEquals. - groupExpressionMap.collectFirst { - case (expr, ne) if expr semanticEquals expression => ne.toAttribute - }.getOrElse(expression) - }.asInstanceOf[NamedExpression] + "Spark user mailing list.") } val aggregateOperator = @@ -277,26 +246,23 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { "aggregate functions which don't support partial aggregation.") } else { aggregate.Utils.planAggregateWithoutPartial( - namedGroupingExpressions.map(_._2), + groupingExpressions, aggregateExpressions, - aggregateFunctionToAttribute, - rewrittenResultExpressions, + resultExpressions, planLater(child)) } } else if (functionsWithDistinct.isEmpty) { aggregate.Utils.planAggregateWithoutDistinct( - namedGroupingExpressions.map(_._2), + groupingExpressions, aggregateExpressions, - aggregateFunctionToAttribute, - rewrittenResultExpressions, + resultExpressions, planLater(child)) } else { aggregate.Utils.planAggregateWithOneDistinct( - namedGroupingExpressions.map(_._2), + groupingExpressions, functionsWithDistinct, functionsWithoutDistinct, - aggregateFunctionToAttribute, - rewrittenResultExpressions, + resultExpressions, planLater(child)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 270c09aff3..7acf020b28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -177,7 +177,7 @@ case class Window( case e @ WindowExpression(function, spec) => val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] function match { - case AggregateExpression(f, _, _) => collect("AGGREGATE", frame, e, f) + case AggregateExpression(f, _, _, _) => collect("AGGREGATE", frame, e, f) case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f) case f: OffsetWindowFunction => collect("OFFSET", frame, e, f) case f => sys.error(s"Unsupported window function: $f") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 213bca907b..ce504e20e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -242,9 +242,9 @@ class TungstenAggregationIterator( // Basically the value of the KVIterator returned by externalSorter // will be just aggregation buffer, so we rewrite the aggregateExpressions to reflect it. 
val newExpressions = aggregateExpressions.map { - case agg @ AggregateExpression(_, Partial, _) => + case agg @ AggregateExpression(_, Partial, _, _) => agg.copy(mode = PartialMerge) - case agg @ AggregateExpression(_, Complete, _) => + case agg @ AggregateExpression(_, Complete, _, _) => agg.copy(mode = Final) case other => other } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index 1e113ccd4e..4682949fa1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.streaming.{StateStoreRestore, StateStoreSave} /** * Utility functions used by the query planner to convert our plan to new aggregation code path. @@ -29,15 +30,11 @@ object Utils { def planAggregateWithoutPartial( groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], - aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete)) - val completeAggregateAttributes = completeAggregateExpressions.map { - expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) - } - + val completeAggregateAttributes = completeAggregateExpressions.map(_.resultAttribute) SortBasedAggregate( requiredChildDistributionExpressions = Some(groupingExpressions), groupingExpressions = groupingExpressions, @@ -83,7 +80,6 @@ object Utils { def planAggregateWithoutDistinct( groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], - aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { // Check if we can use TungstenAggregate. @@ -111,9 +107,7 @@ object Utils { val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) // The attributes of the final aggregation buffer, which is presented as input to the result // projection: - val finalAggregateAttributes = finalAggregateExpressions.map { - expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) - } + val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) val finalAggregate = createAggregate( requiredChildDistributionExpressions = Some(groupingAttributes), @@ -131,7 +125,6 @@ object Utils { groupingExpressions: Seq[NamedExpression], functionsWithDistinct: Seq[AggregateExpression], functionsWithoutDistinct: Seq[AggregateExpression], - aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { @@ -151,9 +144,7 @@ object Utils { // 1. Create an Aggregate Operator for partial aggregations. 
val partialAggregate: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) - val aggregateAttributes = aggregateExpressions.map { - expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) - } + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) // We will group by the original grouping expression, plus an additional expression for the // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping // expressions will be [key, value]. @@ -169,9 +160,7 @@ object Utils { // 2. Create an Aggregate Operator for partial merge aggregations. val partialMergeAggregate: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) - val aggregateAttributes = aggregateExpressions.map { - expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) - } + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) createAggregate( requiredChildDistributionExpressions = Some(groupingAttributes ++ distinctAttributes), @@ -190,7 +179,7 @@ object Utils { // Children of an AggregateFunction with DISTINCT keyword has already // been evaluated. At here, we need to replace original children // to AttributeReferences. - case agg @ AggregateExpression(aggregateFunction, mode, true) => + case agg @ AggregateExpression(aggregateFunction, mode, true, _) => aggregateFunction.transformDown(distinctColumnAttributeLookup) .asInstanceOf[AggregateFunction] } @@ -199,9 +188,7 @@ object Utils { val mergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) // The attributes of the final aggregation buffer, which is presented as input to the result // projection: - val mergeAggregateAttributes = mergeAggregateExpressions.map { - expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) - } + val mergeAggregateAttributes = mergeAggregateExpressions.map(_.resultAttribute) val (distinctAggregateExpressions, distinctAggregateAttributes) = rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) => // We rewrite the aggregate function to a non-distinct aggregation because @@ -211,7 +198,7 @@ object Utils { val expr = AggregateExpression(func, Partial, isDistinct = true) // Use original AggregationFunction to lookup attributes, which is used to build // aggregateFunctionToAttribute - val attr = aggregateFunctionToAttribute(functionsWithDistinct(i).aggregateFunction, true) + val attr = functionsWithDistinct(i).resultAttribute (expr, attr) }.unzip @@ -232,9 +219,7 @@ object Utils { val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) // The attributes of the final aggregation buffer, which is presented as input to the result // projection: - val finalAggregateAttributes = finalAggregateExpressions.map { - expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) - } + val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) val (distinctAggregateExpressions, distinctAggregateAttributes) = rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) => @@ -245,7 +230,7 @@ object Utils { val expr = AggregateExpression(func, Final, isDistinct = true) // Use original AggregationFunction to lookup attributes, which is used to build // aggregateFunctionToAttribute - val attr = aggregateFunctionToAttribute(functionsWithDistinct(i).aggregateFunction, true) + val attr = functionsWithDistinct(i).resultAttribute (expr, attr) }.unzip @@ -261,4 +246,90 @@ 
object Utils { finalAndCompleteAggregate :: Nil } + + /** + * Plans a streaming aggregation using the following progression: + * - Partial Aggregation + * - Shuffle + * - Partial Merge (now there is at most 1 tuple per group) + * - StateStoreRestore (now there is 1 tuple from this batch + optionally one from the previous) + * - PartialMerge (now there is at most 1 tuple per group) + * - StateStoreSave (saves the tuple for the next batch) + * - Complete (output the current result of the aggregation) + */ + def planStreamingAggregation( + groupingExpressions: Seq[NamedExpression], + functionsWithoutDistinct: Seq[AggregateExpression], + resultExpressions: Seq[NamedExpression], + child: SparkPlan): Seq[SparkPlan] = { + + val groupingAttributes = groupingExpressions.map(_.toAttribute) + + val partialAggregate: SparkPlan = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + // We will group by the original grouping expression, plus an additional expression for the + // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping + // expressions will be [key, value]. + createAggregate( + groupingExpressions = groupingExpressions, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + resultExpressions = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = child) + } + + val partialMerged1: SparkPlan = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + createAggregate( + requiredChildDistributionExpressions = + Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = partialAggregate) + } + + val restored = StateStoreRestore(groupingAttributes, None, partialMerged1) + + val partialMerged2: SparkPlan = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + createAggregate( + requiredChildDistributionExpressions = + Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = restored) + } + + val saved = StateStoreSave(groupingAttributes, None, partialMerged2) + + val finalAndCompleteAggregate: SparkPlan = { + val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) + // The attributes of the final aggregation buffer, which is presented as input to the result + // projection: + val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) + + createAggregate( + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = finalAggregateExpressions, + aggregateAttributes = finalAggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = 
resultExpressions, + child = saved) + } + + finalAndCompleteAggregate :: Nil + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala new file mode 100644 index 0000000000..aaced49dd1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -0,0 +1,72 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryNode} + +/** + * A variant of [[QueryExecution]] that allows the execution of the given [[LogicalPlan]] + * plan incrementally. Possibly preserving state in between each execution. + */ +class IncrementalExecution( + ctx: SQLContext, + logicalPlan: LogicalPlan, + checkpointLocation: String, + currentBatchId: Long) extends QueryExecution(ctx, logicalPlan) { + + // TODO: make this always part of planning. + val stateStrategy = sqlContext.sessionState.planner.StatefulAggregationStrategy :: Nil + + // Modified planner with stateful operations. + override def planner: SparkPlanner = + new SparkPlanner( + sqlContext.sparkContext, + sqlContext.conf, + stateStrategy) + + /** + * Records the current id for a given stateful operator in the query plan as the `state` + * preperation walks the query plan. + */ + private var operatorId = 0 + + /** Locates save/restore pairs surrounding aggregation. */ + val state = new Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = plan transform { + case StateStoreSave(keys, None, + UnaryNode(agg, + StateStoreRestore(keys2, None, child))) => + val stateId = OperatorStateId(checkpointLocation, operatorId, currentBatchId - 1) + operatorId += 1 + + StateStoreSave( + keys, + Some(stateId), + agg.withNewChildren( + StateStoreRestore( + keys, + Some(stateId), + child) :: Nil)) + } + } + + override def preparations: Seq[Rule[SparkPlan]] = state +: super.preparations +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala new file mode 100644 index 0000000000..595774761c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.execution +import org.apache.spark.sql.execution.streaming.state._ +import org.apache.spark.sql.execution.SparkPlan + +/** Used to identify the state store for a given operator. */ +case class OperatorStateId( + checkpointLocation: String, + operatorId: Long, + batchId: Long) + +/** + * An operator that saves or restores state from the [[StateStore]]. The [[OperatorStateId]] should + * be filled in by `prepareForExecution` in [[IncrementalExecution]]. + */ +trait StatefulOperator extends SparkPlan { + def stateId: Option[OperatorStateId] + + protected def getStateId: OperatorStateId = attachTree(this) { + stateId.getOrElse { + throw new IllegalStateException("State location not present for execution") + } + } +} + +/** + * For each input tuple, the key is calculated and the value from the [[StateStore]] is added + * to the stream (in addition to the input tuple) if present. + */ +case class StateStoreRestore( + keyExpressions: Seq[Attribute], + stateId: Option[OperatorStateId], + child: SparkPlan) extends execution.UnaryNode with StatefulOperator { + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsWithStateStore( + getStateId.checkpointLocation, + operatorId = getStateId.operatorId, + storeVersion = getStateId.batchId, + keyExpressions.toStructType, + child.output.toStructType, + new StateStoreConf(sqlContext.conf), + Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => + val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) + iter.flatMap { row => + val key = getKey(row) + val savedState = store.get(key) + row +: savedState.toSeq + } + } + } + override def output: Seq[Attribute] = child.output +} + +/** + * For each input tuple, the key is calculated and the tuple is `put` into the [[StateStore]]. 
+ */ +case class StateStoreSave( + keyExpressions: Seq[Attribute], + stateId: Option[OperatorStateId], + child: SparkPlan) extends execution.UnaryNode with StatefulOperator { + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsWithStateStore( + getStateId.checkpointLocation, + operatorId = getStateId.operatorId, + storeVersion = getStateId.batchId, + keyExpressions.toStructType, + child.output.toStructType, + new StateStoreConf(sqlContext.conf), + Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => + new Iterator[InternalRow] { + private[this] val baseIterator = iter + private[this] val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) + + override def hasNext: Boolean = { + if (!baseIterator.hasNext) { + store.commit() + false + } else { + true + } + } + + override def next(): InternalRow = { + val row = baseIterator.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key.copy(), row.copy()) + row + } + } + } + } + + override def output: Seq[Attribute] = child.output +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index c4e410d92c..511e30c70c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.util._ @@ -272,6 +273,8 @@ class StreamExecution( private def runBatch(): Unit = { val startTime = System.nanoTime() + // TODO: Move this to IncrementalExecution. + // Request unprocessed data from all sources. val newData = availableOffsets.flatMap { case (source, available) if committedOffsets.get(source).map(_ < available).getOrElse(true) => @@ -305,13 +308,14 @@ class StreamExecution( } val optimizerStart = System.nanoTime() - - lastExecution = new QueryExecution(sqlContext, newPlan) - val executedPlan = lastExecution.executedPlan + lastExecution = + new IncrementalExecution(sqlContext, newPlan, checkpointFile("state"), currentBatchId) + lastExecution.executedPlan val optimizerTime = (System.nanoTime() - optimizerStart).toDouble / 1000000 logDebug(s"Optimized batch in ${optimizerTime}ms") - val nextBatch = Dataset.ofRows(sqlContext, newPlan) + val nextBatch = + new Dataset(sqlContext, lastExecution, RowEncoder(lastExecution.analyzed.schema)) sink.addBatch(currentBatchId - 1, nextBatch) awaitBatchLock.synchronized { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 0f91e59e04..7d97f81b0f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -108,7 +108,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. 
*/ -class MemorySink(schema: StructType) extends Sink with Logging { +class MemorySink(val schema: StructType) extends Sink with Logging { /** An order list of batches that have been written to this [[Sink]]. */ private val batches = new ArrayBuffer[Array[Row]]() @@ -117,6 +117,8 @@ class MemorySink(schema: StructType) extends Sink with Logging { batches.flatten } + def lastBatch: Seq[Row] = batches.last + def toDebugString: String = synchronized { batches.zipWithIndex.map { case (b, i) => val dataStr = try b.mkString(" ") catch { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index ee015baf3f..998eb82de1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -81,7 +81,7 @@ private[state] class HDFSBackedStateStoreProvider( trait STATE case object UPDATING extends STATE case object COMMITTED extends STATE - case object CANCELLED extends STATE + case object ABORTED extends STATE private val newVersion = version + 1 private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}") @@ -94,15 +94,14 @@ private[state] class HDFSBackedStateStoreProvider( override def id: StateStoreId = HDFSBackedStateStoreProvider.this.id - /** - * Update the value of a key using the value generated by the update function. - * @note Do not mutate the retrieved value row as it will unexpectedly affect the previous - * versions of the store data. - */ - override def update(key: UnsafeRow, updateFunc: Option[UnsafeRow] => UnsafeRow): Unit = { - verify(state == UPDATING, "Cannot update after already committed or cancelled") - val oldValueOption = Option(mapToUpdate.get(key)) - val value = updateFunc(oldValueOption) + override def get(key: UnsafeRow): Option[UnsafeRow] = { + Option(mapToUpdate.get(key)) + } + + override def put(key: UnsafeRow, value: UnsafeRow): Unit = { + verify(state == UPDATING, "Cannot remove after already committed or cancelled") + + val isNewKey = !mapToUpdate.containsKey(key) mapToUpdate.put(key, value) Option(allUpdates.get(key)) match { @@ -115,8 +114,7 @@ private[state] class HDFSBackedStateStoreProvider( case None => // There was no prior update, so mark this as added or updated according to its presence // in previous version. - val update = - if (oldValueOption.nonEmpty) ValueUpdated(key, value) else ValueAdded(key, value) + val update = if (isNewKey) ValueAdded(key, value) else ValueUpdated(key, value) allUpdates.put(key, update) } writeToDeltaFile(tempDeltaFileStream, ValueUpdated(key, value)) @@ -148,7 +146,7 @@ private[state] class HDFSBackedStateStoreProvider( /** Commit all the updates that have been made to the store, and return the new version. */ override def commit(): Long = { - verify(state == UPDATING, "Cannot commit again after already committed or cancelled") + verify(state == UPDATING, "Cannot commit after already committed or cancelled") try { finalizeDeltaFile(tempDeltaFileStream) @@ -164,8 +162,8 @@ private[state] class HDFSBackedStateStoreProvider( } /** Cancel all the updates made on this store. This store will not be usable any more. 
*/ - override def cancel(): Unit = { - state = CANCELLED + override def abort(): Unit = { + state = ABORTED if (tempDeltaFileStream != null) { tempDeltaFileStream.close() } @@ -176,8 +174,8 @@ private[state] class HDFSBackedStateStoreProvider( } /** - * Get an iterator of all the store data. This can be called only after committing the - * updates. + * Get an iterator of all the store data. + * This can be called only after committing all the updates made in the current thread. */ override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = { verify(state == COMMITTED, "Cannot get iterator of store data before comitting") @@ -186,7 +184,7 @@ private[state] class HDFSBackedStateStoreProvider( /** * Get an iterator of all the updates made to the store in the current version. - * This can be called only after committing the updates. + * This can be called only after committing all the updates made in the current thread. */ override def updates(): Iterator[StoreUpdate] = { verify(state == COMMITTED, "Cannot get iterator of updates before committing") @@ -196,7 +194,7 @@ private[state] class HDFSBackedStateStoreProvider( /** * Whether all updates have been committed */ - override def hasCommitted: Boolean = { + override private[state] def hasCommitted: Boolean = { state == COMMITTED } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index ca5c864d9e..d60e6185ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -47,12 +47,11 @@ trait StateStore { /** Version of the data in this store before committing updates. */ def version: Long - /** - * Update the value of a key using the value generated by the update function. - * @note Do not mutate the retrieved value row as it will unexpectedly affect the previous - * versions of the store data. - */ - def update(key: UnsafeRow, updateFunc: Option[UnsafeRow] => UnsafeRow): Unit + /** Get the current value of a key. */ + def get(key: UnsafeRow): Option[UnsafeRow] + + /** Put a new value for a key. */ + def put(key: UnsafeRow, value: UnsafeRow) /** * Remove keys that match the following condition. @@ -65,24 +64,24 @@ trait StateStore { def commit(): Long /** Cancel all the updates that have been made to the store. */ - def cancel(): Unit + def abort(): Unit /** * Iterator of store data after a set of updates have been committed. - * This can be called only after commitUpdates() has been called in the current thread. + * This can be called only after committing all the updates made in the current thread. */ def iterator(): Iterator[(UnsafeRow, UnsafeRow)] /** * Iterator of the updates that have been committed. - * This can be called only after commitUpdates() has been called in the current thread. + * This can be called only after committing all the updates made in the current thread. 
*/ def updates(): Iterator[StoreUpdate] /** * Whether all updates have been committed */ - def hasCommitted: Boolean + private[state] def hasCommitted: Boolean } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index cca22a0af8..f0f1f3a1a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming.state import org.apache.spark.sql.internal.SQLConf /** A class that contains configuration parameters for [[StateStore]]s. */ -private[state] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable { +private[streaming] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable { def this() = this(new SQLConf) @@ -31,7 +31,7 @@ private[state] class StateStoreConf(@transient private val conf: SQLConf) extend val minVersionsToRetain = conf.getConf(STATE_STORE_MIN_VERSIONS_TO_RETAIN) } -private[state] object StateStoreConf { +private[streaming] object StateStoreConf { val empty = new StateStoreConf() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 3318660895..df3d82c113 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -54,17 +54,10 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = { var store: StateStore = null - - Utils.tryWithSafeFinally { - val storeId = StateStoreId(checkpointLocation, operatorId, partition.index) - store = StateStore.get( - storeId, keySchema, valueSchema, storeVersion, storeConf, confBroadcast.value.value) - val inputIter = dataRDD.iterator(partition, ctxt) - val outputIter = storeUpdateFunction(store, inputIter) - assert(store.hasCommitted) - outputIter - } { - if (store != null) store.cancel() - } + val storeId = StateStoreId(checkpointLocation, operatorId, partition.index) + store = StateStore.get( + storeId, keySchema, valueSchema, storeVersion, storeConf, confBroadcast.value.value) + val inputIter = dataRDD.iterator(partition, ctxt) + storeUpdateFunction(store, inputIter) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index b249e37921..9b6d0918e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -28,37 +28,36 @@ package object state { implicit class StateStoreOps[T: ClassTag](dataRDD: RDD[T]) { /** Map each partition of a RDD along with data in a [[StateStore]]. 
*/ - def mapPartitionWithStateStore[U: ClassTag]( - storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U], + def mapPartitionsWithStateStore[U: ClassTag]( + sqlContext: SQLContext, checkpointLocation: String, operatorId: Long, storeVersion: Long, keySchema: StructType, - valueSchema: StructType - )(implicit sqlContext: SQLContext): StateStoreRDD[T, U] = { + valueSchema: StructType)( + storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = { - mapPartitionWithStateStore( - storeUpdateFunction, + mapPartitionsWithStateStore( checkpointLocation, operatorId, storeVersion, keySchema, valueSchema, new StateStoreConf(sqlContext.conf), - Some(sqlContext.streams.stateStoreCoordinator)) + Some(sqlContext.streams.stateStoreCoordinator))( + storeUpdateFunction) } /** Map each partition of a RDD along with data in a [[StateStore]]. */ - private[state] def mapPartitionWithStateStore[U: ClassTag]( - storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U], + private[streaming] def mapPartitionsWithStateStore[U: ClassTag]( checkpointLocation: String, operatorId: Long, storeVersion: Long, keySchema: StructType, valueSchema: StructType, storeConf: StateStoreConf, - storeCoordinator: Option[StateStoreCoordinatorRef] - ): StateStoreRDD[T, U] = { + storeCoordinator: Option[StateStoreCoordinatorRef])( + storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = { val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction) new StateStoreRDD( dataRDD, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 0d580703f5..4b3091ba22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -17,12 +17,12 @@ package org.apache.spark.sql.execution +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.{expressions, InternalRow} import org.apache.spark.sql.catalyst.expressions.{ExprId, Literal, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.internal.SessionState import org.apache.spark.sql.types.DataType /** @@ -60,14 +60,13 @@ case class ScalarSubquery( } /** - * Convert the subquery from logical plan into executed plan. + * Plans scalar subqueries from that are present in the given [[SparkPlan]]. 
*/ -case class PlanSubqueries(sessionState: SessionState) extends Rule[SparkPlan] { +case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] { def apply(plan: SparkPlan): SparkPlan = { plan.transformAllExpressions { case subquery: expressions.ScalarSubquery => - val sparkPlan = sessionState.planner.plan(ReturnAnswer(subquery.query)).next() - val executedPlan = sessionState.prepareForExecution.execute(sparkPlan) + val executedPlan = new QueryExecution(sqlContext, subquery.plan).executedPlan ScalarSubquery(executedPlan, subquery.exprId) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index 844f3051fa..9cb356f1ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -84,10 +84,10 @@ abstract class Aggregator[-I, B, O] extends Serializable { implicit bEncoder: Encoder[B], cEncoder: Encoder[O]): TypedColumn[I, O] = { val expr = - new AggregateExpression( + AggregateExpression( TypedAggregateExpression(this), Complete, - false) + isDistinct = false) new TypedColumn[I, O](expr, encoderFor[O]) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index f7fdfacd31..cd3d254d1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -86,20 +86,8 @@ private[sql] class SessionState(ctx: SQLContext) { /** * Planner that converts optimized logical plans to physical plans. */ - lazy val planner: SparkPlanner = new SparkPlanner(ctx.sparkContext, conf, experimentalMethods) - - /** - * Prepares a planned [[SparkPlan]] for execution by inserting shuffle operations and internal - * row format conversions as needed. - */ - lazy val prepareForExecution = new RuleExecutor[SparkPlan] { - override val batches: Seq[Batch] = Seq( - Batch("Subquery", Once, PlanSubqueries(SessionState.this)), - Batch("Add exchange", Once, EnsureRequirements(conf)), - Batch("Whole stage codegen", Once, CollapseCodegenStages(conf)), - Batch("Reuse duplicated exchanges", Once, ReuseExchange(conf)) - ) - } + def planner: SparkPlanner = + new SparkPlanner(ctx.sparkContext, conf, experimentalMethods.extraStrategies) /** * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala index b5be7ef47e..550c3c6f9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala @@ -116,15 +116,30 @@ trait StreamTest extends QueryTest with Timeouts { def apply[A : Encoder](data: A*): CheckAnswerRows = { val encoder = encoderFor[A] val toExternalRow = RowEncoder(encoder.schema) - CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d)))) + CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))), false) } - def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows) + def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, false) } - case class CheckAnswerRows(expectedAnswer: Seq[Row]) + /** + * Checks to make sure that the current data stored in the sink matches the `expectedAnswer`. 
+ * This operation automatically blocks until all added data has been processed. + */ + object CheckLastBatch { + def apply[A : Encoder](data: A*): CheckAnswerRows = { + val encoder = encoderFor[A] + val toExternalRow = RowEncoder(encoder.schema) + CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))), true) + } + + def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, true) + } + + case class CheckAnswerRows(expectedAnswer: Seq[Row], lastOnly: Boolean) extends StreamAction with StreamMustBeRunning { - override def toString: String = s"CheckAnswer: ${expectedAnswer.mkString(",")}" + override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}" + private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer" } /** Stops the stream. It must currently be running. */ @@ -224,11 +239,8 @@ trait StreamTest extends QueryTest with Timeouts { """.stripMargin def verify(condition: => Boolean, message: String): Unit = { - try { - Assertions.assert(condition) - } catch { - case NonFatal(e) => - failTest(message, e) + if (!condition) { + failTest(message) } } @@ -351,7 +363,7 @@ trait StreamTest extends QueryTest with Timeouts { case a: AddData => awaiting.put(a.source, a.addData()) - case CheckAnswerRows(expectedAnswer) => + case CheckAnswerRows(expectedAnswer, lastOnly) => verify(currentStream != null, "stream not running") // Block until all data added has been processed @@ -361,12 +373,12 @@ trait StreamTest extends QueryTest with Timeouts { } } - val allData = try sink.allData catch { + val sparkAnswer = try if (lastOnly) sink.lastBatch else sink.allData catch { case e: Exception => failTest("Exception while getting data from sink", e) } - QueryTest.sameRows(expectedAnswer, allData).foreach { + QueryTest.sameRows(expectedAnswer, sparkAnswer).foreach { error => failTest(error) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index ed0d3f56e5..38318740a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -231,10 +231,8 @@ object SparkPlanTest { } private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = { - // A very simple resolver to make writing tests easier. In contrast to the real resolver - // this is always case sensitive and does not try to handle scoping or complex type resolution. 
- val resolvedPlan = sqlContext.sessionState.prepareForExecution.execute( - outputPlan transform { + val execution = new QueryExecution(sqlContext, null) { + override lazy val sparkPlan: SparkPlan = outputPlan transform { case plan: SparkPlan => val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap plan transformExpressions { @@ -243,8 +241,8 @@ object SparkPlanTest { sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) } } - ) - resolvedPlan.executeCollectPublic().toSeq + } + execution.executedPlan.executeCollectPublic().toSeq } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 85db05157c..6be94eb24f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} -import org.apache.spark.util.Utils +import org.apache.spark.util.{CompletionIterator, Utils} class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll { @@ -54,62 +54,93 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } test("versioning and immutability") { - quietly { - withSpark(new SparkContext(sparkConf)) { sc => - implicit val sqlContet = new SQLContext(sc) - val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString - val increment = (store: StateStore, iter: Iterator[String]) => { - iter.foreach { s => - store.update( - stringToRow(s), oldRow => { - val oldValue = oldRow.map(rowToInt).getOrElse(0) - intToRow(oldValue + 1) - }) - } - store.commit() - store.iterator().map(rowsToStringInt) - } - val opId = 0 - val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore( - increment, path, opId, storeVersion = 0, keySchema, valueSchema) - assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + withSpark(new SparkContext(sparkConf)) { sc => + val sqlContext = new SQLContext(sc) + val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val opId = 0 + val rdd1 = + makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)( + increment) + assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + + // Generate next version of stores + val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment) + assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) + + // Make sure the previous RDD still has the same data. 
+ assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + } + } - // Generate next version of stores - val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore( - increment, path, opId, storeVersion = 1, keySchema, valueSchema) - assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) + test("recovering from files") { + val opId = 0 + val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + + def makeStoreRDD( + sc: SparkContext, + seq: Seq[String], + storeVersion: Int): RDD[(String, Int)] = { + implicit val sqlContext = new SQLContext(sc) + makeRDD(sc, Seq("a")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion, keySchema, valueSchema)(increment) + } - // Make sure the previous RDD still has the same data. - assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + // Generate RDDs and state store data + withSpark(new SparkContext(sparkConf)) { sc => + for (i <- 1 to 20) { + require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i)) } } + + // With a new context, try using the earlier state store data + withSpark(new SparkContext(sparkConf)) { sc => + assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21)) + } } - test("recovering from files") { - quietly { - val opId = 0 + test("usage with iterators - only gets and only puts") { + withSpark(new SparkContext(sparkConf)) { sc => + implicit val sqlContext = new SQLContext(sc) val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val opId = 0 - def makeStoreRDD( - sc: SparkContext, - seq: Seq[String], - storeVersion: Int): RDD[(String, Int)] = { - implicit val sqlContext = new SQLContext(sc) - makeRDD(sc, Seq("a")).mapPartitionWithStateStore( - increment, path, opId, storeVersion, keySchema, valueSchema) + // Returns an iterator of the incremented value made into the store + def iteratorOfPuts(store: StateStore, iter: Iterator[String]): Iterator[(String, Int)] = { + val resIterator = iter.map { s => + val key = stringToRow(s) + val oldValue = store.get(key).map(rowToInt).getOrElse(0) + val newValue = oldValue + 1 + store.put(key, intToRow(newValue)) + (s, newValue) + } + CompletionIterator[(String, Int), Iterator[(String, Int)]](resIterator, { + store.commit() + }) } - // Generate RDDs and state store data - withSpark(new SparkContext(sparkConf)) { sc => - for (i <- 1 to 20) { - require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i)) + def iteratorOfGets( + store: StateStore, + iter: Iterator[String]): Iterator[(String, Option[Int])] = { + iter.map { s => + val key = stringToRow(s) + val value = store.get(key).map(rowToInt) + (s, value) } } - // With a new context, try using the earlier state store data - withSpark(new SparkContext(sparkConf)) { sc => - assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21)) - } + val rddOfGets1 = makeRDD(sc, Seq("a", "b", "c")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets) + assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" -> None)) + + val rddOfPuts = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfPuts) + assert(rddOfPuts.collect().toSet === Set("a" -> 1, "a" -> 2, "b" -> 1)) + + val rddOfGets2 = makeRDD(sc, Seq("a", "b", "c")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(iteratorOfGets) + assert(rddOfGets2.collect().toSet 
=== Set("a" -> Some(2), "b" -> Some(1), "c" -> None)) } } @@ -128,8 +159,8 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn coordinatorRef.getLocation(StateStoreId(path, opId, 0)) === Some(ExecutorCacheTaskLocation("host1", "exec1").toString)) - val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore( - increment, path, opId, storeVersion = 0, keySchema, valueSchema) + val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment) require(rdd.partitions.length === 2) assert( @@ -148,27 +179,16 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn test("distributed test") { quietly { withSpark(new SparkContext(sparkConf.setMaster("local-cluster[2, 1, 1024]"))) { sc => - implicit val sqlContet = new SQLContext(sc) + implicit val sqlContext = new SQLContext(sc) val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString - val increment = (store: StateStore, iter: Iterator[String]) => { - iter.foreach { s => - store.update( - stringToRow(s), oldRow => { - val oldValue = oldRow.map(rowToInt).getOrElse(0) - intToRow(oldValue + 1) - }) - } - store.commit() - store.iterator().map(rowsToStringInt) - } val opId = 0 - val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore( - increment, path, opId, storeVersion = 0, keySchema, valueSchema) + val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores - val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore( - increment, path, opId, storeVersion = 1, keySchema, valueSchema) + val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) // Make sure the previous RDD still has the same data. 
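Taken together, the hunks above and below show the usage pattern the reworked state store API is aiming for: the store update function is supplied as a trailing closure to mapPartitionsWithStateStore, values are read with get and written with put instead of the old update callback, and the commit is deferred until the returned iterator has been fully consumed (for example via CompletionIterator, as the iteratorOfPuts helper above does). A minimal sketch of that pattern follows; it is an illustration rather than part of the patch, and it assumes the suite's existing row-conversion helpers (stringToRow, intToRow, rowToInt) plus placeholder values for the checkpoint path.

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.util.CompletionIterator

// Sketch only: `checkpointPath` is a placeholder, and stringToRow / intToRow / rowToInt
// stand in for the row-conversion helpers this suite already defines.
def runningCounts(
    sc: SparkContext,
    words: RDD[String],
    checkpointPath: String,
    storeVersion: Long): RDD[(String, Int)] = {
  val sqlContext = new SQLContext(sc)
  val operatorId = 0L
  val keySchema = StructType(Seq(StructField("key", StringType)))
  val valueSchema = StructType(Seq(StructField("value", IntegerType)))

  // The update function now sits in its own parameter list, so it reads as a trailing closure.
  words.mapPartitionsWithStateStore(
      sqlContext, checkpointPath, operatorId, storeVersion, keySchema, valueSchema) {
    (store, iter) =>
      val updated = iter.map { s =>
        val key = stringToRow(s)
        val newValue = store.get(key).map(rowToInt).getOrElse(0) + 1
        store.put(key, intToRow(newValue))
        (s, newValue)
      }
      // Commit only after the downstream consumer has drained the iterator.
      CompletionIterator[(String, Int), Iterator[(String, Int)]](updated, store.commit())
  }
}

Deferring the commit into the iterator matters because, as the StateStoreRDD.compute change earlier in this patch shows, the RDD no longer asserts hasCommitted or aborts the store in a finally block; the update function itself now decides when (and whether) a new store version is committed.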
@@ -183,11 +203,9 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn private val increment = (store: StateStore, iter: Iterator[String]) => { iter.foreach { s => - store.update( - stringToRow(s), oldRow => { - val oldValue = oldRow.map(rowToInt).getOrElse(0) - intToRow(oldValue + 1) - }) + val key = stringToRow(s) + val oldValue = store.get(key).map(rowToInt).getOrElse(0) + store.put(key, intToRow(oldValue + 1)) } store.commit() store.iterator().map(rowsToStringInt) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 22b2f4f75d..0e5936d53f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -51,7 +51,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth StateStore.stop() } - test("update, remove, commit, and all data iterator") { + test("get, put, remove, commit, and all data iterator") { val provider = newStoreProvider() // Verify state before starting a new set of updates @@ -67,7 +67,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } // Verify state after updating - update(store, "a", 1) + put(store, "a", 1) intercept[IllegalStateException] { store.iterator() } @@ -77,8 +77,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth assert(provider.latestIterator().isEmpty) // Make updates, commit and then verify state - update(store, "b", 2) - update(store, "aa", 3) + put(store, "b", 2) + put(store, "aa", 3) remove(store, _.startsWith("a")) assert(store.commit() === 1) @@ -101,7 +101,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth val reloadedProvider = new HDFSBackedStateStoreProvider( store.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration) val reloadedStore = reloadedProvider.getStore(1) - update(reloadedStore, "c", 4) + put(reloadedStore, "c", 4) assert(reloadedStore.commit() === 2) assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4)) assert(getDataFromFiles(provider) === Set("b" -> 2, "c" -> 4)) @@ -112,6 +112,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth test("updates iterator with all combos of updates and removes") { val provider = newStoreProvider() var currentVersion: Int = 0 + def withStore(body: StateStore => Unit): Unit = { val store = provider.getStore(currentVersion) body(store) @@ -120,9 +121,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // New data should be seen in updates as value added, even if they had multiple updates withStore { store => - update(store, "a", 1) - update(store, "aa", 1) - update(store, "aa", 2) + put(store, "a", 1) + put(store, "aa", 1) + put(store, "aa", 2) store.commit() assert(updatesToSet(store.updates()) === Set(Added("a", 1), Added("aa", 2))) assert(rowsToSet(store.iterator()) === Set("a" -> 1, "aa" -> 2)) @@ -131,8 +132,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Multiple updates to same key should be collapsed in the updates as a single value update // Keys that have not been updated should not appear in the updates withStore { store => - update(store, "a", 4) - update(store, "a", 6) + put(store, "a", 4) + 
put(store, "a", 6) store.commit() assert(updatesToSet(store.updates()) === Set(Updated("a", 6))) assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2)) @@ -140,9 +141,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Keys added, updated and finally removed before commit should not appear in updates withStore { store => - update(store, "b", 4) // Added, finally removed - update(store, "bb", 5) // Added, updated, finally removed - update(store, "bb", 6) + put(store, "b", 4) // Added, finally removed + put(store, "bb", 5) // Added, updated, finally removed + put(store, "bb", 6) remove(store, _.startsWith("b")) store.commit() assert(updatesToSet(store.updates()) === Set.empty) @@ -153,7 +154,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Removed, but re-added data should be seen in updates as a value update withStore { store => remove(store, _.startsWith("a")) - update(store, "a", 10) + put(store, "a", 10) store.commit() assert(updatesToSet(store.updates()) === Set(Updated("a", 10), Removed("aa"))) assert(rowsToSet(store.iterator()) === Set("a" -> 10)) @@ -163,14 +164,14 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth test("cancel") { val provider = newStoreProvider() val store = provider.getStore(0) - update(store, "a", 1) + put(store, "a", 1) store.commit() assert(rowsToSet(store.iterator()) === Set("a" -> 1)) // cancelUpdates should not change the data in the files val store1 = provider.getStore(1) - update(store1, "b", 1) - store1.cancel() + put(store1, "b", 1) + store1.abort() assert(getDataFromFiles(provider) === Set("a" -> 1)) } @@ -183,7 +184,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Prepare some data in the stoer val store = provider.getStore(0) - update(store, "a", 1) + put(store, "a", 1) assert(store.commit() === 1) assert(rowsToSet(store.iterator()) === Set("a" -> 1)) @@ -193,14 +194,14 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Update store version with some data val store1 = provider.getStore(1) - update(store1, "b", 1) + put(store1, "b", 1) assert(store1.commit() === 2) assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1)) assert(getDataFromFiles(provider) === Set("a" -> 1, "b" -> 1)) // Overwrite the version with other data val store2 = provider.getStore(1) - update(store2, "c", 1) + put(store2, "c", 1) assert(store2.commit() === 2) assert(rowsToSet(store2.iterator()) === Set("a" -> 1, "c" -> 1)) assert(getDataFromFiles(provider) === Set("a" -> 1, "c" -> 1)) @@ -213,7 +214,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth def updateVersionTo(targetVersion: Int): Unit = { for (i <- currentVersion + 1 to targetVersion) { val store = provider.getStore(currentVersion) - update(store, "a", i) + put(store, "a", i) store.commit() currentVersion += 1 } @@ -264,7 +265,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth for (i <- 1 to 20) { val store = provider.getStore(i - 1) - update(store, "a", i) + put(store, "a", i) store.commit() provider.doMaintenance() // do cleanup } @@ -284,7 +285,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth val provider = newStoreProvider(minDeltasForSnapshot = 5) for (i <- 1 to 6) { val store = provider.getStore(i - 1) - update(store, "a", i) + put(store, "a", i) store.commit() provider.doMaintenance() // do cleanup } @@ 
-333,7 +334,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Increase version of the store val store0 = StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf) assert(store0.version === 0) - update(store0, "a", 1) + put(store0, "a", 1) store0.commit() assert(StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf).version == 1) @@ -345,7 +346,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth val store1 = StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf) assert(StateStore.isLoaded(storeId)) - update(store1, "a", 2) + put(store1, "a", 2) assert(store1.commit() === 2) assert(rowsToSet(store1.iterator()) === Set("a" -> 2)) } @@ -371,7 +372,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth for (i <- 1 to 20) { val store = StateStore.get( storeId, keySchema, valueSchema, i - 1, storeConf, hadoopConf) - update(store, "a", i) + put(store, "a", i) store.commit() } eventually(timeout(10 seconds)) { @@ -507,8 +508,12 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth store.remove(row => condition(rowToString(row))) } - private def update(store: StateStore, key: String, value: Int): Unit = { - store.update(stringToRow(key), _ => intToRow(value)) + private def put(store: StateStore, key: String, value: Int): Unit = { + store.put(stringToRow(key), intToRow(value)) + } + + private def get(store: StateStore, key: String): Option[Int] = { + store.get(stringToRow(key)).map(rowToInt) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala new file mode 100644 index 0000000000..b63ce89d18 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.SparkException +import org.apache.spark.sql.{Encoder, StreamTest, SumOf, TypedColumn} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext + +object FailureSinglton { + var firstTime = true +} + +class StreamingAggregationSuite extends StreamTest with SharedSQLContext { + + import testImplicits._ + + test("simple count") { + val inputData = MemoryStream[Int] + + val aggregated = + inputData.toDF() + .groupBy($"value") + .agg(count("*")) + .as[(Int, Long)] + + testStream(aggregated)( + AddData(inputData, 3), + CheckLastBatch((3, 1)), + AddData(inputData, 3, 2), + CheckLastBatch((3, 2), (2, 1)), + StopStream, + StartStream, + AddData(inputData, 3, 2, 1), + CheckLastBatch((3, 3), (2, 2), (1, 1)), + // By default we run in new tuple mode. + AddData(inputData, 4, 4, 4, 4), + CheckLastBatch((4, 4)) + ) + } + + test("multiple keys") { + val inputData = MemoryStream[Int] + + val aggregated = + inputData.toDF() + .groupBy($"value", $"value" + 1) + .agg(count("*")) + .as[(Int, Int, Long)] + + testStream(aggregated)( + AddData(inputData, 1, 2), + CheckLastBatch((1, 2, 1), (2, 3, 1)), + AddData(inputData, 1, 2), + CheckLastBatch((1, 2, 2), (2, 3, 2)) + ) + } + + test("multiple aggregations") { + val inputData = MemoryStream[Int] + + val aggregated = + inputData.toDF() + .groupBy($"value") + .agg(count("*") as 'count) + .groupBy($"value" % 2) + .agg(sum($"count")) + .as[(Int, Long)] + + testStream(aggregated)( + AddData(inputData, 1, 2, 3, 4), + CheckLastBatch((0, 2), (1, 2)), + AddData(inputData, 1, 3, 5), + CheckLastBatch((1, 5)) + ) + } + + testQuietly("midbatch failure") { + val inputData = MemoryStream[Int] + FailureSinglton.firstTime = true + val aggregated = + inputData.toDS() + .map { i => + if (i == 4 && FailureSinglton.firstTime) { + FailureSinglton.firstTime = false + sys.error("injected failure") + } + + i + } + .groupBy($"value") + .agg(count("*")) + .as[(Int, Long)] + + testStream(aggregated)( + StartStream, + AddData(inputData, 1, 2, 3, 4), + ExpectFailure[SparkException](), + StartStream, + CheckLastBatch((1, 1), (2, 1), (3, 1), (4, 1)) + ) + } + + test("typed aggregators") { + def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = + new SumOf(f).toColumn + + val inputData = MemoryStream[(String, Int)] + val aggregated = inputData.toDS().groupByKey(_._1).agg(sum(_._2)) + + testStream(aggregated)( + AddData(inputData, ("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)), + CheckLastBatch(("a", 30), ("b", 3), ("c", 1)) + ) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index 2bdb428e9d..ff40c366c8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -77,8 +77,9 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx) /** * Planner that takes into account Hive-specific strategies. */ - override lazy val planner: SparkPlanner = { - new SparkPlanner(ctx.sparkContext, conf, experimentalMethods) with HiveStrategies { + override def planner: SparkPlanner = { + new SparkPlanner(ctx.sparkContext, conf, experimentalMethods.extraStrategies) + with HiveStrategies { override val hiveContext = ctx override def strategies: Seq[Strategy] = { |