author     Reynold Xin <rxin@databricks.com>  2015-08-05 21:50:14 -0700
committer  Reynold Xin <rxin@databricks.com>  2015-08-05 21:50:14 -0700
commit     9270bd06fd0b16892e3f37213b5bc7813ea11fdd (patch)
tree       cbb0dd37331c11e231706e7bdaf5f9d83caca908
parent     119b59053870df7be899bf5c1c0d321406af96f9 (diff)
[SPARK-9674][SQL] Remove GeneratedAggregate.
The new aggregate replaces the old GeneratedAggregate.

Author: Reynold Xin <rxin@databricks.com>

Closes #7983 from rxin/remove-generated-agg and squashes the following commits:

8334aae [Reynold Xin] [SPARK-9674][SQL] Remove GeneratedAggregate.
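For context on what is being removed: GeneratedAggregate described each supported aggregate function as an AggregateEvaluation, i.e. a buffer schema, initial values, an update expression applied once per input row, and a result expression (see the deleted file below). A minimal, dependency-free Scala sketch of that idea follows; the names here are illustrative only, not Spark's actual API.

case class AggEval[B, I, R](
    initial: B,            // initial aggregation buffer value
    update: (B, I) => B,   // fold one input row into the buffer
    result: B => R)        // extract the final answer from the buffer

object AggEvalDemo {
  // COUNT over nullable longs: None (null) is not counted, mirroring the
  // If(IsNotNull(toCount), Add(currentCount, 1), currentCount) update
  // expression in the deleted operator below.
  val count = AggEval[Long, Option[Long], Long](
    initial = 0L,
    update = (buf, in) => if (in.isDefined) buf + 1L else buf,
    result = identity)

  def main(args: Array[String]): Unit = {
    val rows = Seq(Some(3L), None, Some(5L))
    println(count.result(rows.foldLeft(count.initial)(count.update)))  // prints 2
  }
}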
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala  352
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala      34
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala                   5
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala       48
4 files changed, 2 insertions(+), 437 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
deleted file mode 100644
index bf4905dc1e..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ /dev/null
@@ -1,352 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution
-
-import java.io.IOException
-
-import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext}
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.catalyst.trees._
-import org.apache.spark.sql.types._
-
-case class AggregateEvaluation(
-    schema: Seq[Attribute],
-    initialValues: Seq[Expression],
-    update: Seq[Expression],
-    result: Expression)
-
-/**
- * :: DeveloperApi ::
- * Alternate version of aggregation that leverages projection and thus code generation.
- * Aggregations are converted into a set of projections from an aggregation buffer tuple back onto
- * itself. Currently only simple aggregations like SUM, COUNT, or AVERAGE are supported.
- *
- * @param partial if true then aggregation is done partially on local data without shuffling to
- *                ensure all values where `groupingExpressions` are equal are present.
- * @param groupingExpressions expressions that are evaluated to determine grouping.
- * @param aggregateExpressions expressions that are computed for each group.
- * @param unsafeEnabled whether to allow Unsafe-based aggregation buffers to be used.
- * @param child the input data source.
- */
-@DeveloperApi
-case class GeneratedAggregate(
-    partial: Boolean,
-    groupingExpressions: Seq[Expression],
-    aggregateExpressions: Seq[NamedExpression],
-    unsafeEnabled: Boolean,
-    child: SparkPlan)
-  extends UnaryNode {
-
-  override def requiredChildDistribution: Seq[Distribution] =
-    if (partial) {
-      UnspecifiedDistribution :: Nil
-    } else {
-      if (groupingExpressions == Nil) {
-        AllTuples :: Nil
-      } else {
-        ClusteredDistribution(groupingExpressions) :: Nil
-      }
-    }
-
-  override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
-
-  protected override def doExecute(): RDD[InternalRow] = {
-    val aggregatesToCompute = aggregateExpressions.flatMap { a =>
-      a.collect { case agg: AggregateExpression1 => agg}
-    }
-
-    // If you add any new function support, please add tests in org.apache.spark.sql.SQLQuerySuite
-    // (in test "aggregation with codegen").
-    val computeFunctions = aggregatesToCompute.map {
-      case c @ Count(expr) =>
-        // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its
-        // UnscaledValue will be null if and only if x is null; helps with Average on decimals
-        val toCount = expr match {
-          case UnscaledValue(e) => e
-          case _ => expr
-        }
-        val currentCount = AttributeReference("currentCount", LongType, nullable = false)()
-        val initialValue = Literal(0L)
-        val updateFunction = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount)
-        val result = currentCount
-
-        AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
-
-      case s @ Sum(expr) =>
-        val calcType =
-          expr.dataType match {
-            case DecimalType.Fixed(p, s) =>
-              DecimalType.bounded(p + 10, s)
-            case _ =>
-              expr.dataType
-          }
-
-        val currentSum = AttributeReference("currentSum", calcType, nullable = true)()
-        val initialValue = Literal.create(null, calcType)
-
-        // Coalesce avoids double calculation...
-        // but really, common sub expression elimination would be better....
-        val zero = Cast(Literal(0), calcType)
-        val updateFunction = Coalesce(
-          Add(
-            Coalesce(currentSum :: zero :: Nil),
-            Cast(expr, calcType)
-          ) :: currentSum :: Nil)
-        val result =
-          expr.dataType match {
-            case DecimalType.Fixed(_, _) =>
-              Cast(currentSum, s.dataType)
-            case _ => currentSum
-          }
-
-        AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
-
-      case m @ Max(expr) =>
-        val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)()
-        val initialValue = Literal.create(null, expr.dataType)
-        val updateMax = MaxOf(currentMax, expr)
-
-        AggregateEvaluation(
-          currentMax :: Nil,
-          initialValue :: Nil,
-          updateMax :: Nil,
-          currentMax)
-
-      case m @ Min(expr) =>
-        val currentMin = AttributeReference("currentMin", expr.dataType, nullable = true)()
-        val initialValue = Literal.create(null, expr.dataType)
-        val updateMin = MinOf(currentMin, expr)
-
-        AggregateEvaluation(
-          currentMin :: Nil,
-          initialValue :: Nil,
-          updateMin :: Nil,
-          currentMin)
-
-      case CollectHashSet(Seq(expr)) =>
-        val set =
-          AttributeReference("hashSet", new OpenHashSetUDT(expr.dataType), nullable = false)()
-        val initialValue = NewSet(expr.dataType)
-        val addToSet = AddItemToSet(expr, set)
-
-        AggregateEvaluation(
-          set :: Nil,
-          initialValue :: Nil,
-          addToSet :: Nil,
-          set)
-
-      case CombineSetsAndCount(inputSet) =>
-        val elementType = inputSet.dataType.asInstanceOf[OpenHashSetUDT].elementType
-        val set =
-          AttributeReference("hashSet", new OpenHashSetUDT(elementType), nullable = false)()
-        val initialValue = NewSet(elementType)
-        val collectSets = CombineSets(set, inputSet)
-
-        AggregateEvaluation(
-          set :: Nil,
-          initialValue :: Nil,
-          collectSets :: Nil,
-          CountSet(set))
-
-      case o => sys.error(s"$o can't be codegened.")
-    }
-
-    val computationSchema = computeFunctions.flatMap(_.schema)
-
-    val resultMap: Map[TreeNodeRef, Expression] =
-      aggregatesToCompute.zip(computeFunctions).map {
-        case (agg, func) => new TreeNodeRef(agg) -> func.result
-      }.toMap
-
-    val namedGroups = groupingExpressions.zipWithIndex.map {
-      case (ne: NamedExpression, _) => (ne, ne.toAttribute)
-      case (e, i) => (e, Alias(e, s"GroupingExpr$i")().toAttribute)
-    }
-
-    // The set of expressions that produce the final output given the aggregation buffer and the
-    // grouping expressions.
-    val resultExpressions = aggregateExpressions.map(_.transform {
-      case e: Expression if resultMap.contains(new TreeNodeRef(e)) => resultMap(new TreeNodeRef(e))
-      case e: Expression =>
-        namedGroups.collectFirst {
-          case (expr, attr) if expr semanticEquals e => attr
-        }.getOrElse(e)
-    })
-
-    val aggregationBufferSchema: StructType = StructType.fromAttributes(computationSchema)
-
-    val groupKeySchema: StructType = {
-      val fields = groupingExpressions.zipWithIndex.map { case (expr, idx) =>
-        // This is a dummy field name
-        StructField(idx.toString, expr.dataType, expr.nullable)
-      }
-      StructType(fields)
-    }
-
-    val schemaSupportsUnsafe: Boolean = {
-      UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) &&
-        UnsafeProjection.canSupport(groupKeySchema)
-    }
-
-    child.execute().mapPartitions { iter =>
-      // Builds a new custom class for holding the results of aggregation for a group.
-      val initialValues = computeFunctions.flatMap(_.initialValues)
-      val newAggregationBuffer = newProjection(initialValues, child.output)
-      log.info(s"Initial values: ${initialValues.mkString(",")}")
-
-      // A projection that computes the group given an input tuple.
-      val groupProjection = newProjection(groupingExpressions, child.output)
-      log.info(s"Grouping Projection: ${groupingExpressions.mkString(",")}")
-
-      // A projection that is used to update the aggregate values for a group given a new tuple.
-      // This projection should be targeted at the current values for the group and then applied
-      // to a joined row of the current values with the new input row.
-      val updateExpressions = computeFunctions.flatMap(_.update)
-      val updateSchema = computeFunctions.flatMap(_.schema) ++ child.output
-      val updateProjection = newMutableProjection(updateExpressions, updateSchema)()
-      log.info(s"Update Expressions: ${updateExpressions.mkString(",")}")
-
-      // A projection that produces the final result, given a computation.
-      val resultProjectionBuilder =
-        newMutableProjection(
-          resultExpressions,
-          namedGroups.map(_._2) ++ computationSchema)
-      log.info(s"Result Projection: ${resultExpressions.mkString(",")}")
-
-      val joinedRow = new JoinedRow
-
-      if (!iter.hasNext) {
-        // This is an empty input, so return early so that we do not allocate data structures
-        // that won't be cleaned up (see SPARK-8357).
-        if (groupingExpressions.isEmpty) {
-          // This is a global aggregate, so return an empty aggregation buffer.
-          val resultProjection = resultProjectionBuilder()
-          Iterator(resultProjection(newAggregationBuffer(EmptyRow)))
-        } else {
-          // This is a grouped aggregate, so return an empty iterator.
-          Iterator[InternalRow]()
-        }
-      } else if (groupingExpressions.isEmpty) {
-        // TODO: Codegening anything other than the updateProjection is probably overkill.
-        val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
-        var currentRow: InternalRow = null
-        updateProjection.target(buffer)
-
-        while (iter.hasNext) {
-          currentRow = iter.next()
-          updateProjection(joinedRow(buffer, currentRow))
-        }
-
-        val resultProjection = resultProjectionBuilder()
-        Iterator(resultProjection(buffer))
-
-      } else if (unsafeEnabled && schemaSupportsUnsafe) {
-        assert(iter.hasNext, "There should be at least one row for this path")
-        log.info("Using Unsafe-based aggregator")
-        val pageSizeBytes = SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m")
-        val taskContext = TaskContext.get()
-        val aggregationMap = new UnsafeFixedWidthAggregationMap(
-          newAggregationBuffer(EmptyRow),
-          aggregationBufferSchema,
-          groupKeySchema,
-          taskContext.taskMemoryManager(),
-          SparkEnv.get.shuffleMemoryManager,
-          1024 * 16, // initial capacity
-          pageSizeBytes,
-          false // disable tracking of performance metrics
-        )
-
-        while (iter.hasNext) {
-          val currentRow: InternalRow = iter.next()
-          val groupKey: InternalRow = groupProjection(currentRow)
-          val aggregationBuffer = aggregationMap.getAggregationBuffer(groupKey)
-          if (aggregationBuffer == null) {
-            throw new IOException("Could not allocate memory to grow aggregation buffer")
-          }
-          updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow))
-        }
-
-        // Record memory used in the process
-        taskContext.internalMetricsToAccumulators(
-          InternalAccumulator.PEAK_EXECUTION_MEMORY).add(aggregationMap.getMemoryUsage)
-
-        new Iterator[InternalRow] {
-          private[this] val mapIterator = aggregationMap.iterator()
-          private[this] val resultProjection = resultProjectionBuilder()
-          private[this] var _hasNext = mapIterator.next()
-
-          def hasNext: Boolean = _hasNext
-
-          def next(): InternalRow = {
-            if (_hasNext) {
-              val result = resultProjection(joinedRow(mapIterator.getKey, mapIterator.getValue))
-              _hasNext = mapIterator.next()
-              if (_hasNext) {
-                result
-              } else {
-                // This is the last element in the iterator, so let's free the buffer. Before we do,
-                // though, we need to make a defensive copy of the result so that we don't return an
-                // object that might contain dangling pointers to the freed memory.
-                val resultCopy = result.copy()
-                aggregationMap.free()
-                resultCopy
-              }
-            } else {
-              throw new java.util.NoSuchElementException
-            }
-          }
-        }
-      } else {
-        if (unsafeEnabled) {
-          log.info("Not using Unsafe-based aggregator because it is not supported for this schema")
-        }
-        val buffers = new java.util.HashMap[InternalRow, MutableRow]()
-
-        var currentRow: InternalRow = null
-        while (iter.hasNext) {
-          currentRow = iter.next()
-          val currentGroup = groupProjection(currentRow)
-          var currentBuffer = buffers.get(currentGroup)
-          if (currentBuffer == null) {
-            currentBuffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
-            buffers.put(currentGroup, currentBuffer)
-          }
-          // Target the projection at the current aggregation buffer and then project the updated
-          // values.
-          updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentRow))
-        }
-
-        new Iterator[InternalRow] {
-          private[this] val resultIterator = buffers.entrySet.iterator()
-          private[this] val resultProjection = resultProjectionBuilder()
-
-          def hasNext: Boolean = resultIterator.hasNext
-
-          def next(): InternalRow = {
-            val currentGroup = resultIterator.next()
-            resultProjection(joinedRow(currentGroup.getKey, currentGroup.getValue))
-          }
-        }
-      }
-    }
-  }
-}
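The last branch of the removed doExecute above (the java.util.HashMap path) is classic hash-based grouped aggregation: one mutable buffer per group key, updated row by row, then emitted through an iterator. A standalone Scala sketch of that loop, under the simplifying assumption of a grouped SUM over longs (illustrative only, not Spark code):

import scala.collection.mutable

object HashAggSketch {
  // The getOrElse(0L) default plays the role of newAggregationBuffer(EmptyRow)
  // the first time a group key is seen.
  def groupedSum[K](rows: Iterator[(K, Long)]): Iterator[(K, Long)] = {
    val buffers = mutable.HashMap.empty[K, Long]
    while (rows.hasNext) {
      val (key, value) = rows.next()
      buffers.update(key, buffers.getOrElse(key, 0L) + value)
    }
    buffers.iterator
  }

  def main(args: Array[String]): Unit = {
    groupedSum(Iterator(("a", 1L), ("b", 2L), ("a", 3L)))
      .foreach { case (k, s) => println(s"$k -> $s") }  // a -> 4, b -> 2
  }
}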
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 952ba7d45c..a730ffbb21 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -136,32 +136,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
  object HashAggregation extends Strategy {
    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
      // Aggregations that can be performed in two phases, before and after the shuffle.
-
-      // Cases where all aggregates can be codegened.
-      case PartialAggregation(
-          namedGroupingAttributes,
-          rewrittenAggregateExpressions,
-          groupingExpressions,
-          partialComputation,
-          child)
-          if canBeCodeGened(
-            allAggregates(partialComputation) ++
-              allAggregates(rewrittenAggregateExpressions)) &&
-            codegenEnabled &&
-            !canBeConvertedToNewAggregation(plan) =>
-        execution.GeneratedAggregate(
-          partial = false,
-          namedGroupingAttributes,
-          rewrittenAggregateExpressions,
-          unsafeEnabled,
-          execution.GeneratedAggregate(
-            partial = true,
-            groupingExpressions,
-            partialComputation,
-            unsafeEnabled,
-            planLater(child))) :: Nil
-
-      // Cases where some aggregate can not be codegened
      case PartialAggregation(
          namedGroupingAttributes,
          rewrittenAggregateExpressions,
@@ -192,14 +166,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
      case _ => false
    }

-    def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = aggs.forall {
-      case _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => true
-      // The generated set implementation is pretty limited ATM.
-      case CollectHashSet(exprs) if exprs.size == 1 &&
-        Seq(IntegerType, LongType).contains(exprs.head.dataType) => true
-      case _ => false
-    }
-
    def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression1] =
      exprs.flatMap(_.collect { case a: AggregateExpression1 => a })
  }
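The HashAggregation object trimmed above follows the general shape of a Catalyst planning strategy: pattern-match on a logical plan and return candidate physical plans, where Nil means the strategy does not apply and the planner moves on. A simplified sketch of that shape, using stand-in plan types rather than Catalyst's:

sealed trait LogicalPlan
final case class LogicalAggregate(groupings: Seq[String], child: LogicalPlan) extends LogicalPlan
final case class LogicalScan(table: String) extends LogicalPlan

sealed trait PhysicalPlan
final case class HashAggregateExec(groupings: Seq[String], child: PhysicalPlan) extends PhysicalPlan
final case class ScanExec(table: String) extends PhysicalPlan

object AggregationStrategy {
  // Returning Nil signals "no match"; guards on a case (like the deleted
  // canBeCodeGened check) decide which physical operator is eligible.
  def apply(plan: LogicalPlan): Seq[PhysicalPlan] = plan match {
    case LogicalAggregate(groupings, LogicalScan(table)) =>
      HashAggregateExec(groupings, ScanExec(table)) :: Nil
    case _ => Nil
  }

  def main(args: Array[String]): Unit =
    println(apply(LogicalAggregate(Seq("b"), LogicalScan("t"))))
}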
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 29dfcf2575..cef40dd324 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.DefaultParserDialect
import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.execution.aggregate
-import org.apache.spark.sql.execution.GeneratedAggregate
import org.apache.spark.sql.functions._
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.test.SQLTestUtils
@@ -263,7 +262,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
      val df = sql(sqlText)
      // First, check if we have GeneratedAggregate.
      val hasGeneratedAgg = df.queryExecution.executedPlan
-        .collect { case _: GeneratedAggregate | _: aggregate.Aggregate => true }
+        .collect { case _: aggregate.Aggregate => true }
        .nonEmpty
      if (!hasGeneratedAgg) {
        fail(
@@ -1603,7 +1602,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
      Row(new CalendarInterval(-(12 * 3 - 3), -(7L * MICROS_PER_WEEK + 123))))
  }

-  test("aggregation with codegen updates peak execution memory") {
+  ignore("aggregation with codegen updates peak execution memory") {
    withSQLConf(
      (SQLConf.CODEGEN_ENABLED.key, "true"),
      (SQLConf.USE_SQL_AGGREGATE2.key, "false")) {
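Switching test(...) to ignore(...) above uses ScalaTest's standard mechanism for temporarily disabling a test: the body still compiles, but the runner skips it and reports it as ignored. A minimal example of the pattern, assuming the 2015-era org.scalatest.FunSuite API used by this suite:

import org.scalatest.FunSuite

class ExampleSuite extends FunSuite {
  test("runs normally") {
    assert(1 + 1 === 2)
  }

  ignore("compiled but skipped at run time") {
    fail("never executed; the runner reports this test as ignored")
  }
}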
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala
deleted file mode 100644
index 20def6bef0..0000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala
+++ /dev/null
@@ -1,48 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution
-
-import org.apache.spark.sql.SQLConf
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.test.TestSQLContext
-
-class AggregateSuite extends SparkPlanTest {
-
- test("SPARK-8357 unsafe aggregation path should not leak memory with empty input") {
- val codegenDefault = TestSQLContext.getConf(SQLConf.CODEGEN_ENABLED)
- val unsafeDefault = TestSQLContext.getConf(SQLConf.UNSAFE_ENABLED)
- try {
- TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, true)
- TestSQLContext.setConf(SQLConf.UNSAFE_ENABLED, true)
- val df = Seq.empty[(Int, Int)].toDF("a", "b")
- checkAnswer(
- df,
- GeneratedAggregate(
- partial = true,
- Seq(df.col("b").expr),
- Seq(Alias(Count(df.col("a").expr), "cnt")()),
- unsafeEnabled = true,
- _: SparkPlan),
- Seq.empty
- )
- } finally {
- TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault)
- TestSQLContext.setConf(SQLConf.UNSAFE_ENABLED, unsafeDefault)
- }
- }
-}
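The deleted suite above pinned down the SPARK-8357 behavior quoted in GeneratedAggregate.doExecute: on empty input, a global aggregate must still emit exactly one row built from the initial buffer (e.g. COUNT = 0), while a grouped aggregate must emit nothing, and neither path should allocate per-group state that would leak. That decision, reduced to plain illustrative Scala:

object EmptyInputSketch {
  // What an aggregate operator should emit when its input iterator is empty.
  def onEmptyInput(grouped: Boolean): Iterator[Long] =
    if (grouped) Iterator.empty  // grouped aggregate: no groups, so no output rows
    else Iterator(0L)            // global aggregate: one row from the initial buffer (COUNT = 0)

  def main(args: Array[String]): Unit = {
    println(onEmptyInput(grouped = false).toList)  // List(0)
    println(onEmptyInput(grouped = true).toList)   // List()
  }
}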