From 14c9238aa7173ba663a999ef320d8cffb73306c4 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 7 Apr 2014 18:38:44 -0700 Subject: [sql] Rename execution/aggregates.scala Aggregate.scala, and added a bunch of private[this] to variables. Author: Reynold Xin Closes #348 from rxin/aggregate and squashes the following commits: f4bc36f [Reynold Xin] Rename execution/aggregates.scala Aggregate.scala, and added a bunch of private[this] to variables. --- .../org/apache/spark/sql/execution/Aggregate.scala | 202 +++++++++++++++++++++ .../apache/spark/sql/execution/aggregates.scala | 202 --------------------- 2 files changed, 202 insertions(+), 202 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala (limited to 'sql') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala new file mode 100644 index 0000000000..3a4f071eeb --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.util.HashMap + +import org.apache.spark.SparkContext +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical._ + +/** + * Groups input data by `groupingExpressions` and computes the `aggregateExpressions` for each + * group. + * + * @param partial if true then aggregation is done partially on local data without shuffling to + * ensure all values where `groupingExpressions` are equal are present. + * @param groupingExpressions expressions that are evaluated to determine grouping. + * @param aggregateExpressions expressions that are computed for each group. + * @param child the input data source. + */ +case class Aggregate( + partial: Boolean, + groupingExpressions: Seq[Expression], + aggregateExpressions: Seq[NamedExpression], + child: SparkPlan)(@transient sc: SparkContext) + extends UnaryNode with NoBind { + + override def requiredChildDistribution = + if (partial) { + UnspecifiedDistribution :: Nil + } else { + if (groupingExpressions == Nil) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingExpressions) :: Nil + } + } + + override def otherCopyArgs = sc :: Nil + + // HACK: Generators don't correctly preserve their output through serializations so we grab + // out child's output attributes statically here. + private[this] val childOutput = child.output + + override def output = aggregateExpressions.map(_.toAttribute) + + /** + * An aggregate that needs to be computed for each row in a group. + * + * @param unbound Unbound version of this aggregate, used for result substitution. + * @param aggregate A bound copy of this aggregate used to create a new aggregation buffer. + * @param resultAttribute An attribute used to refer to the result of this aggregate in the final + * output. + */ + case class ComputedAggregate( + unbound: AggregateExpression, + aggregate: AggregateExpression, + resultAttribute: AttributeReference) + + /** A list of aggregates that need to be computed for each group. */ + @transient + private[this] lazy val computedAggregates = aggregateExpressions.flatMap { agg => + agg.collect { + case a: AggregateExpression => + ComputedAggregate( + a, + BindReferences.bindReference(a, childOutput).asInstanceOf[AggregateExpression], + AttributeReference(s"aggResult:$a", a.dataType, nullable = true)()) + } + }.toArray + + /** The schema of the result of all aggregate evaluations */ + @transient + private[this] lazy val computedSchema = computedAggregates.map(_.resultAttribute) + + /** Creates a new aggregate buffer for a group. */ + private[this] def newAggregateBuffer(): Array[AggregateFunction] = { + val buffer = new Array[AggregateFunction](computedAggregates.length) + var i = 0 + while (i < computedAggregates.length) { + buffer(i) = computedAggregates(i).aggregate.newInstance() + i += 1 + } + buffer + } + + /** Named attributes used to substitute grouping attributes into the final result. */ + @transient + private[this] lazy val namedGroups = groupingExpressions.map { + case ne: NamedExpression => ne -> ne.toAttribute + case e => e -> Alias(e, s"groupingExpr:$e")().toAttribute + } + + /** + * A map of substitutions that are used to insert the aggregate expressions and grouping + * expression into the final result expression. + */ + @transient + private[this] lazy val resultMap = + (computedAggregates.map { agg => agg.unbound -> agg.resultAttribute} ++ namedGroups).toMap + + /** + * Substituted version of aggregateExpressions expressions which are used to compute final + * output rows given a group and the result of all aggregate computations. + */ + @transient + private[this] lazy val resultExpressions = aggregateExpressions.map { agg => + agg.transform { + case e: Expression if resultMap.contains(e) => resultMap(e) + } + } + + override def execute() = attachTree(this, "execute") { + if (groupingExpressions.isEmpty) { + child.execute().mapPartitions { iter => + val buffer = newAggregateBuffer() + var currentRow: Row = null + while (iter.hasNext) { + currentRow = iter.next() + var i = 0 + while (i < buffer.length) { + buffer(i).update(currentRow) + i += 1 + } + } + val resultProjection = new Projection(resultExpressions, computedSchema) + val aggregateResults = new GenericMutableRow(computedAggregates.length) + + var i = 0 + while (i < buffer.length) { + aggregateResults(i) = buffer(i).eval(EmptyRow) + i += 1 + } + + Iterator(resultProjection(aggregateResults)) + } + } else { + child.execute().mapPartitions { iter => + val hashTable = new HashMap[Row, Array[AggregateFunction]] + val groupingProjection = new MutableProjection(groupingExpressions, childOutput) + + var currentRow: Row = null + while (iter.hasNext) { + currentRow = iter.next() + val currentGroup = groupingProjection(currentRow) + var currentBuffer = hashTable.get(currentGroup) + if (currentBuffer == null) { + currentBuffer = newAggregateBuffer() + hashTable.put(currentGroup.copy(), currentBuffer) + } + + var i = 0 + while (i < currentBuffer.length) { + currentBuffer(i).update(currentRow) + i += 1 + } + } + + new Iterator[Row] { + private[this] val hashTableIter = hashTable.entrySet().iterator() + private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length) + private[this] val resultProjection = + new MutableProjection(resultExpressions, computedSchema ++ namedGroups.map(_._2)) + private[this] val joinedRow = new JoinedRow + + override final def hasNext: Boolean = hashTableIter.hasNext + + override final def next(): Row = { + val currentEntry = hashTableIter.next() + val currentGroup = currentEntry.getKey + val currentBuffer = currentEntry.getValue + + var i = 0 + while (i < currentBuffer.length) { + // Evaluating an aggregate buffer returns the result. No row is required since we + // already added all rows in the group using update. + aggregateResults(i) = currentBuffer(i).eval(EmptyRow) + i += 1 + } + resultProjection(joinedRow(aggregateResults, currentGroup)) + } + } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala deleted file mode 100644 index 0890faa33b..0000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala +++ /dev/null @@ -1,202 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import java.util.HashMap - -import org.apache.spark.SparkContext -import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical._ - -/** - * Groups input data by `groupingExpressions` and computes the `aggregateExpressions` for each - * group. - * - * @param partial if true then aggregation is done partially on local data without shuffling to - * ensure all values where `groupingExpressions` are equal are present. - * @param groupingExpressions expressions that are evaluated to determine grouping. - * @param aggregateExpressions expressions that are computed for each group. - * @param child the input data source. - */ -case class Aggregate( - partial: Boolean, - groupingExpressions: Seq[Expression], - aggregateExpressions: Seq[NamedExpression], - child: SparkPlan)(@transient sc: SparkContext) - extends UnaryNode with NoBind { - - override def requiredChildDistribution = - if (partial) { - UnspecifiedDistribution :: Nil - } else { - if (groupingExpressions == Nil) { - AllTuples :: Nil - } else { - ClusteredDistribution(groupingExpressions) :: Nil - } - } - - override def otherCopyArgs = sc :: Nil - - // HACK: Generators don't correctly preserve their output through serializations so we grab - // out child's output attributes statically here. - val childOutput = child.output - - def output = aggregateExpressions.map(_.toAttribute) - - /** - * An aggregate that needs to be computed for each row in a group. - * - * @param unbound Unbound version of this aggregate, used for result substitution. - * @param aggregate A bound copy of this aggregate used to create a new aggregation buffer. - * @param resultAttribute An attribute used to refer to the result of this aggregate in the final - * output. - */ - case class ComputedAggregate( - unbound: AggregateExpression, - aggregate: AggregateExpression, - resultAttribute: AttributeReference) - - /** A list of aggregates that need to be computed for each group. */ - @transient - lazy val computedAggregates = aggregateExpressions.flatMap { agg => - agg.collect { - case a: AggregateExpression => - ComputedAggregate( - a, - BindReferences.bindReference(a, childOutput).asInstanceOf[AggregateExpression], - AttributeReference(s"aggResult:$a", a.dataType, nullable = true)()) - } - }.toArray - - /** The schema of the result of all aggregate evaluations */ - @transient - lazy val computedSchema = computedAggregates.map(_.resultAttribute) - - /** Creates a new aggregate buffer for a group. */ - def newAggregateBuffer(): Array[AggregateFunction] = { - val buffer = new Array[AggregateFunction](computedAggregates.length) - var i = 0 - while (i < computedAggregates.length) { - buffer(i) = computedAggregates(i).aggregate.newInstance() - i += 1 - } - buffer - } - - /** Named attributes used to substitute grouping attributes into the final result. */ - @transient - lazy val namedGroups = groupingExpressions.map { - case ne: NamedExpression => ne -> ne.toAttribute - case e => e -> Alias(e, s"groupingExpr:$e")().toAttribute - } - - /** - * A map of substitutions that are used to insert the aggregate expressions and grouping - * expression into the final result expression. - */ - @transient - lazy val resultMap = - (computedAggregates.map { agg => agg.unbound -> agg.resultAttribute} ++ namedGroups).toMap - - /** - * Substituted version of aggregateExpressions expressions which are used to compute final - * output rows given a group and the result of all aggregate computations. - */ - @transient - lazy val resultExpressions = aggregateExpressions.map { agg => - agg.transform { - case e: Expression if resultMap.contains(e) => resultMap(e) - } - } - - def execute() = attachTree(this, "execute") { - if (groupingExpressions.isEmpty) { - child.execute().mapPartitions { iter => - val buffer = newAggregateBuffer() - var currentRow: Row = null - while (iter.hasNext) { - currentRow = iter.next() - var i = 0 - while (i < buffer.length) { - buffer(i).update(currentRow) - i += 1 - } - } - val resultProjection = new Projection(resultExpressions, computedSchema) - val aggregateResults = new GenericMutableRow(computedAggregates.length) - - var i = 0 - while (i < buffer.length) { - aggregateResults(i) = buffer(i).eval(EmptyRow) - i += 1 - } - - Iterator(resultProjection(aggregateResults)) - } - } else { - child.execute().mapPartitions { iter => - val hashTable = new HashMap[Row, Array[AggregateFunction]] - val groupingProjection = new MutableProjection(groupingExpressions, childOutput) - - var currentRow: Row = null - while (iter.hasNext) { - currentRow = iter.next() - val currentGroup = groupingProjection(currentRow) - var currentBuffer = hashTable.get(currentGroup) - if (currentBuffer == null) { - currentBuffer = newAggregateBuffer() - hashTable.put(currentGroup.copy(), currentBuffer) - } - - var i = 0 - while (i < currentBuffer.length) { - currentBuffer(i).update(currentRow) - i += 1 - } - } - - new Iterator[Row] { - private[this] val hashTableIter = hashTable.entrySet().iterator() - private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length) - private[this] val resultProjection = - new MutableProjection(resultExpressions, computedSchema ++ namedGroups.map(_._2)) - private[this] val joinedRow = new JoinedRow - - override final def hasNext: Boolean = hashTableIter.hasNext - - override final def next(): Row = { - val currentEntry = hashTableIter.next() - val currentGroup = currentEntry.getKey - val currentBuffer = currentEntry.getValue - - var i = 0 - while (i < currentBuffer.length) { - // Evaluating an aggregate buffer returns the result. No row is required since we - // already added all rows in the group using update. - aggregateResults(i) = currentBuffer(i).eval(EmptyRow) - i += 1 - } - resultProjection(joinedRow(aggregateResults, currentGroup)) - } - } - } - } - } -} -- cgit v1.2.3