author     Herman van Hovell <hvanhovell@questtec.nl>   2015-11-06 16:04:20 -0800
committer  Michael Armbrust <michael@databricks.com>    2015-11-06 16:04:20 -0800
commit     6d0ead322e72303c6444c6ac641378a4690cde96 (patch)
tree       649387cb64a232032ddc4062b787f31aef62a900
parent     1ab72b08601a1c8a674bdd3fab84d9804899b2c7 (diff)
[SPARK-9241][SQL] Supporting multiple DISTINCT columns (2) - Rewriting Rule
The second PR for SPARK-9241; this adds support for multiple distinct columns to the new aggregation code path.

This PR solves the multiple DISTINCT column problem by rewriting these Aggregates into an Expand-Aggregate-Aggregate combination. See the [JIRA ticket](https://issues.apache.org/jira/browse/SPARK-9241) for some information on this.

The advantages over the - competing - [first PR](https://github.com/apache/spark/pull/9280) are:
- This can use the faster TungstenAggregate code path.
- It is impossible to OOM due to an ```OpenHashSet``` allocating too much memory.

However, this will multiply the number of input rows by the number of distinct clauses (plus one), and put a lot more memory pressure on the aggregation code path itself.

The location of this Rule is a bit funny, and should probably change when the old aggregation path is changed.

cc yhuai - Could you also tell me where to add tests for this?

Author: Herman van Hovell <hvanhovell@questtec.nl>

Closes #9406 from hvanhovell/SPARK-9241-rewriter.
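To make the rewrite concrete, here is a plain-Scala sketch of the Expand-Aggregate-Aggregate combination for a query like `SELECT COUNT(DISTINCT a), COUNT(DISTINCT b), SUM(c) FROM t`. It has no Spark dependency; the table, column names, and input rows are made up for illustration, and it mirrors the rule's group-id scheme below (gid 0 for the regular aggregates, gids 1..n for the distinct groups):

```scala
object RewriteSketch extends App {
  // Hypothetical input rows (a, b, c).
  val rows = Seq((1, 10, 100L), (1, 20, 100L), (2, 20, 300L))

  // Expand: every input row becomes one row per group, tagged with a group id.
  type Row = (Option[Int], Option[Int], Int, Option[Long]) // (a, b, gid, c)
  val expanded: Seq[Row] = rows.flatMap { case (a, b, c) => Seq[Row](
    (None,    None,    0, Some(c)), // gid 0: regular aggregates (carries c)
    (Some(a), None,    1, None),    // gid 1: COUNT(DISTINCT a) (b, c nulled)
    (None,    Some(b), 2, None))    // gid 2: COUNT(DISTINCT b) (a, c nulled)
  }

  // First aggregate: grouping by (a, b, gid) de-duplicates the distinct
  // children within each group and evaluates SUM(c) for the gid 0 rows.
  val firstAgg = expanded
    .groupBy { case (a, b, gid, _) => (a, b, gid) }
    .map { case ((a, b, gid), g) => (a, b, gid, g.flatMap(_._4).sum) }

  // Second aggregate: count the surviving keys of each distinct group and
  // pick up the regular group's partial result.
  val countDistinctA = firstAgg.count { case (a, _, gid, _) => gid == 1 && a.isDefined }
  val countDistinctB = firstAgg.count { case (_, b, gid, _) => gid == 2 && b.isDefined }
  val sumC = firstAgg.collect { case (_, _, 0, s) => s }.sum

  println((countDistinctA, countDistinctB, sumC)) // (2, 2, 500)
}
```

Note the row multiplication in the expand step: three input rows become nine, which is exactly the "number of distinct clauses (plus one)" cost mentioned above.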
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala      |   2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala      | 186
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala |   6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala              |   6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala     |  80
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala                     |   2
6 files changed, 238 insertions(+), 44 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
index 54df96cd24..ec0c8b483a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
@@ -49,4 +49,6 @@ case class Count(child: Expression) extends DeclarativeAggregate {
)
override val evaluateExpression = Cast(count, LongType)
+
+ override def defaultResult: Option[Literal] = Option(Literal(0L))
}
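The override above encodes the one special case the rewrite has to handle: COUNT returns 0 on empty input, while most aggregates return NULL. A minimal plain-Scala analogue of that difference (illustrative only, not Spark API):

```scala
// COUNT has a non-null result on empty input; SUM and friends do not.
val empty = Seq.empty[Long]
val count: Long = empty.size.toLong               // 0, never null
val sum: Option[Long] = empty.reduceOption(_ + _) // None, i.e. SQL NULL
```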
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
index 644c6211d5..39010c3be6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
@@ -20,8 +20,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
-import org.apache.spark.sql.types.{StructType, MapType, ArrayType}
+import org.apache.spark.sql.catalyst.plans.logical.{Expand, Aggregate, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.types.{IntegerType, StructType, MapType, ArrayType}
/**
* Utility functions used by the query planner to convert our plan to new aggregation code path.
@@ -41,7 +42,7 @@ object Utils {
private def doConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
case p: Aggregate if supportsGroupingKeySchema(p) =>
- val converted = p.transformExpressionsDown {
+ val converted = MultipleDistinctRewriter.rewrite(p.transformExpressionsDown {
case expressions.Average(child) =>
aggregate.AggregateExpression2(
aggregateFunction = aggregate.Average(child),
@@ -144,7 +145,8 @@ object Utils {
aggregateFunction = aggregate.VarianceSamp(child),
mode = aggregate.Complete,
isDistinct = false)
- }
+ })
+
// Check if there is any expressions.AggregateExpression1 left.
// If so, we cannot convert this plan.
val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr =>
@@ -156,6 +158,7 @@ object Utils {
}
// Check if there are multiple distinct columns.
+ // TODO remove this.
val aggregateExpressions = converted.aggregateExpressions.flatMap { expr =>
expr.collect {
case agg: AggregateExpression2 => agg
@@ -213,3 +216,178 @@ object Utils {
case other => None
}
}
+
+/**
+ * This rule rewrites an aggregate query with multiple distinct clauses into an expanded double
+ * aggregation, in which the regular aggregation expressions and every distinct clause are
+ * aggregated in separate groups. The results are then combined in a second aggregate.
+ *
+ * TODO Expression canonicalization.
+ * TODO Eliminate foldable expressions from distinct clauses.
+ * TODO This eliminates all distinct expressions. We could safely pass one to the aggregate
+ * operator. Perhaps this is a good thing? It is much simpler to plan later on...
+ */
+object MultipleDistinctRewriter extends Rule[LogicalPlan] {
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case a: Aggregate => rewrite(a)
+ case p => p
+ }
+
+ def rewrite(a: Aggregate): Aggregate = {
+
+ // Collect all aggregate expressions.
+ val aggExpressions = a.aggregateExpressions.flatMap { e =>
+ e.collect {
+ case ae: AggregateExpression2 => ae
+ }
+ }
+
+ // Extract distinct aggregate expressions.
+ val distinctAggGroups = aggExpressions
+ .filter(_.isDistinct)
+ .groupBy(_.aggregateFunction.children.toSet)
+
+ // Only continue to rewrite if there is more than one distinct group.
+ if (distinctAggGroups.size > 1) {
+ // Create the attributes for the grouping id and the group by clause.
+ val gid = new AttributeReference("gid", IntegerType, false)()
+ val groupByMap = a.groupingExpressions.collect {
+ case ne: NamedExpression => ne -> ne.toAttribute
+ case e => e -> new AttributeReference(e.prettyName, e.dataType, e.nullable)()
+ }
+ val groupByAttrs = groupByMap.map(_._2)
+
+ // Functions used to modify aggregate functions and their inputs.
+ def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e))
+ def patchAggregateFunctionChildren(
+ af: AggregateFunction2,
+ id: Literal,
+ attrs: Map[Expression, Expression]): AggregateFunction2 = {
+ af.withNewChildren(af.children.map { case afc =>
+ evalWithinGroup(id, attrs(afc))
+ }).asInstanceOf[AggregateFunction2]
+ }
+
+ // Setup unique distinct aggregate children.
+ val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq
+ val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair).toMap
+ val distinctAggChildAttrs = distinctAggChildAttrMap.values.toSeq
+
+ // Setup expand & aggregate operators for distinct aggregate expressions.
+ val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map {
+ case ((group, expressions), i) =>
+ val id = Literal(i + 1)
+
+ // Expand projection
+ val projection = distinctAggChildren.map {
+ case e if group.contains(e) => e
+ case e => nullify(e)
+ } :+ id
+
+ // Final aggregate
+ val operators = expressions.map { e =>
+ val af = e.aggregateFunction
+ val naf = patchAggregateFunctionChildren(af, id, distinctAggChildAttrMap)
+ (e, e.copy(aggregateFunction = naf, isDistinct = false))
+ }
+
+ (projection, operators)
+ }
+
+ // Setup expand for the 'regular' aggregate expressions.
+ val regularAggExprs = aggExpressions.filter(!_.isDistinct)
+ val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct
+ val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair).toMap
+
+ // Setup aggregates for 'regular' aggregate expressions.
+ val regularGroupId = Literal(0)
+ val regularAggOperatorMap = regularAggExprs.map { e =>
+ // Perform the actual aggregation in the initial aggregate.
+ val af = patchAggregateFunctionChildren(
+ e.aggregateFunction,
+ regularGroupId,
+ regularAggChildAttrMap)
+ val a = Alias(e.copy(aggregateFunction = af), e.toString)()
+
+ // Get the result of the first aggregate in the last aggregate.
+ val b = AggregateExpression2(
+ aggregate.First(evalWithinGroup(regularGroupId, a.toAttribute), Literal(true)),
+ mode = Complete,
+ isDistinct = false)
+
+ // Some aggregate functions (COUNT) have the special property that they can return a
+ // non-null result without any input. We need to make sure we return a result in this case.
+ val c = af.defaultResult match {
+ case Some(lit) => Coalesce(Seq(b, lit))
+ case None => b
+ }
+
+ (e, a, c)
+ }
+
+ // Construct the regular aggregate input projection only if we need one.
+ val regularAggProjection = if (regularAggExprs.nonEmpty) {
+ Seq(a.groupingExpressions ++
+ distinctAggChildren.map(nullify) ++
+ Seq(regularGroupId) ++
+ regularAggChildren)
+ } else {
+ Seq.empty[Seq[Expression]]
+ }
+
+ // Construct the distinct aggregate input projections.
+ val regularAggNulls = regularAggChildren.map(nullify)
+ val distinctAggProjections = distinctAggOperatorMap.map {
+ case (projection, _) =>
+ a.groupingExpressions ++
+ projection ++
+ regularAggNulls
+ }
+
+ // Construct the expand operator.
+ val expand = Expand(
+ regularAggProjection ++ distinctAggProjections,
+ groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.values.toSeq,
+ a.child)
+
+ // Construct the first aggregate operator. This de-duplicates all the children of the
+ // distinct operators, and applies the regular aggregate operators.
+ val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid
+ val firstAggregate = Aggregate(
+ firstAggregateGroupBy,
+ firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2),
+ expand)
+
+ // Construct the second aggregate
+ val transformations: Map[Expression, Expression] =
+ (distinctAggOperatorMap.flatMap(_._2) ++
+ regularAggOperatorMap.map(e => (e._1, e._3))).toMap
+
+ val patchedAggExpressions = a.aggregateExpressions.map { e =>
+ e.transformDown {
+ case e: Expression =>
+ // The same GROUP BY clauses can have different forms (different names for instance) in
+ // the groupBy and aggregate expressions of an aggregate. This makes a map lookup
+ // tricky. So we do a linear search for a semantically equal group by expression.
+ groupByMap
+ .find(ge => e.semanticEquals(ge._1))
+ .map(_._2)
+ .getOrElse(transformations.getOrElse(e, e))
+ }.asInstanceOf[NamedExpression]
+ }
+ Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate)
+ } else {
+ a
+ }
+ }
+
+ private def nullify(e: Expression) = Literal.create(null, e.dataType)
+
+ private def expressionAttributePair(e: Expression) =
+ // We are creating a new reference here instead of reusing the attribute in case of a
+ // NamedExpression. This is done to prevent collisions between distinct and regular aggregate
+ // children; in this case attribute reuse causes the input of the regular aggregate to be bound to
+ // the (nulled out) input of the distinct aggregate.
+ e -> new AttributeReference(e.prettyName, e.dataType, true)()
+}
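The heart of the rule is the masking performed by `evalWithinGroup`: an aggregate sees its input only when the row belongs to its expand group, mirroring `If(EqualTo(gid, id), e, nullify(e))` above. A standalone sketch of that behaviour, reusing the helper name for clarity:

```scala
// Plain-Scala analogue of evalWithinGroup: an input is visible only to the
// aggregate that owns the current expand group; in any other group it is null.
def evalWithinGroup[T](gid: Int, id: Int, e: Option[T]): Option[T] =
  if (gid == id) e else None

assert(evalWithinGroup(gid = 1, id = 1, e = Some("a")).contains("a")) // own group
assert(evalWithinGroup(gid = 2, id = 1, e = Some("a")).isEmpty)       // nulled out
```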
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index a2fab258fc..5c5b3d1ccd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -133,6 +133,12 @@ sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInp
*/
def supportsPartial: Boolean = true
+ /**
+ * Result of the aggregate function when the input is empty. This is currently only used for the
+ * proper rewriting of distinct aggregate functions.
+ */
+ def defaultResult: Option[Literal] = None
+
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
}
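In the rewriter above, `defaultResult` is consumed by wrapping the second-level `First` in a `Coalesce` (`Coalesce(Seq(b, lit))` in Utils.scala), so COUNT still yields 0 when its group produced no rows. A hedged plain-Scala analogue with an illustrative helper:

```scala
// Combine the second-level result with the function's default, as the
// Coalesce in the rewriter does. Names here are illustrative only.
def combine(firstResult: Option[Long], defaultResult: Option[Long]): Option[Long] =
  firstResult.orElse(defaultResult)

assert(combine(None, Some(0L)) == Some(0L)) // COUNT over an empty group -> 0
assert(combine(None, None).isEmpty)         // SUM over an empty group -> NULL
```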
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 338c5193cb..d222dfa33a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -200,9 +200,9 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
*/
object ColumnPruning extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case a @ Aggregate(_, _, e @ Expand(_, groupByExprs, _, child))
- if (child.outputSet -- AttributeSet(groupByExprs) -- a.references).nonEmpty =>
- a.copy(child = e.copy(child = prunedChild(child, AttributeSet(groupByExprs) ++ a.references)))
+ case a @ Aggregate(_, _, e @ Expand(_, _, child))
+ if (child.outputSet -- AttributeSet(e.output) -- a.references).nonEmpty =>
+ a.copy(child = e.copy(child = prunedChild(child, AttributeSet(e.output) ++ a.references)))
// Eliminate attributes that are not needed to calculate the specified aggregates.
case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 4cb67aacf3..fb963e2f8f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -235,33 +235,17 @@ case class Window(
projectList ++ windowExpressions.map(_.toAttribute)
}
-/**
- * Apply the all of the GroupExpressions to every input row, hence we will get
- * multiple output rows for a input row.
- * @param bitmasks The bitmask set represents the grouping sets
- * @param groupByExprs The grouping by expressions
- * @param child Child operator
- */
-case class Expand(
- bitmasks: Seq[Int],
- groupByExprs: Seq[Expression],
- gid: Attribute,
- child: LogicalPlan) extends UnaryNode {
- override def statistics: Statistics = {
- val sizeInBytes = child.statistics.sizeInBytes * projections.length
- Statistics(sizeInBytes = sizeInBytes)
- }
-
- val projections: Seq[Seq[Expression]] = expand()
-
+private[sql] object Expand {
/**
- * Extract attribute set according to the grouping id
+ * Extract attribute set according to the grouping id.
+ *
* @param bitmask bitmask to represent the selected of the attribute sequence
* @param exprs the attributes in sequence
* @return the attributes of non selected specified via bitmask (with the bit set to 1)
*/
- private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression])
- : OpenHashSet[Expression] = {
+ private def buildNonSelectExprSet(
+ bitmask: Int,
+ exprs: Seq[Expression]): OpenHashSet[Expression] = {
val set = new OpenHashSet[Expression](2)
var bit = exprs.length - 1
@@ -274,18 +258,28 @@ case class Expand(
}
/**
- * Create an array of Projections for the child projection, and replace the projections'
- * expressions which equal GroupBy expressions with Literal(null), if those expressions
- * are not set for this grouping set (according to the bit mask).
+ * Apply all of the GroupExpressions to every input row; hence we will get
+ * multiple output rows for an input row.
+ *
+ * @param bitmasks The bitmask set represents the grouping sets
+ * @param groupByExprs The grouping by expressions
+ * @param gid Attribute of the grouping id
+ * @param child Child operator
*/
- private[this] def expand(): Seq[Seq[Expression]] = {
- val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]]
-
- bitmasks.foreach { bitmask =>
+ def apply(
+ bitmasks: Seq[Int],
+ groupByExprs: Seq[Expression],
+ gid: Attribute,
+ child: LogicalPlan): Expand = {
+ // Create an array of Projections for the child projection, and replace the projections'
+ // expressions which equal GroupBy expressions with Literal(null), if those expressions
+ // are not set for this grouping set (according to the bit mask).
+ val projections = bitmasks.map { bitmask =>
// get the non selected grouping attributes according to the bit mask
val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, groupByExprs)
- val substitution = (child.output :+ gid).map(expr => expr transformDown {
+ (child.output :+ gid).map(expr => expr transformDown {
+ // TODO this causes a problem when a column is used both for grouping and aggregation.
case x: Expression if nonSelectedGroupExprSet.contains(x) =>
// if the input attribute in the Invalid Grouping Expression set of for this group
// replace it with constant null
@@ -294,15 +288,29 @@ case class Expand(
// replace the groupingId with concrete value (the bit mask)
Literal.create(bitmask, IntegerType)
})
-
- result += substitution
}
-
- result.toSeq
+ Expand(projections, child.output :+ gid, child)
}
+}
- override def output: Seq[Attribute] = {
- child.output :+ gid
+/**
+ * Apply a number of projections to every input row; hence we will get multiple output rows for
+ * an input row.
+ *
+ * @param projections to apply
+ * @param output of all projections.
+ * @param child operator.
+ */
+case class Expand(
+ projections: Seq[Seq[Expression]],
+ output: Seq[Attribute],
+ child: LogicalPlan) extends UnaryNode {
+
+ override def statistics: Statistics = {
+ // TODO shouldn't we factor in the size of the projection versus the size of the backing child
+ // row?
+ val sizeInBytes = child.statistics.sizeInBytes * projections.length
+ Statistics(sizeInBytes = sizeInBytes)
}
}
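The relocated `statistics` backs the memory-pressure caveat in the commit message: every input row is emitted once per projection, so the size estimate scales linearly with the number of projections. A quick sanity check with made-up numbers:

```scala
// Expand's size estimate is the child size times the number of projections.
// With two distinct clauses the rewriter emits 3 projections (2 + 1 regular),
// so a 1 GiB child is estimated at 3 GiB after Expand.
val childSizeInBytes = BigInt(1) << 30
val numProjections = 3
println(childSizeInBytes * numProjections) // 3221225472
```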
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index f4464e0b91..dd3bb33c57 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -420,7 +420,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
case logical.Filter(condition, child) =>
execution.Filter(condition, planLater(child)) :: Nil
- case e @ logical.Expand(_, _, _, child) =>
+ case e @ logical.Expand(_, _, child) =>
execution.Expand(e.projections, e.output, planLater(child)) :: Nil
case a @ logical.Aggregate(group, agg, child) => {
val useNewAggregation = sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled