Diffstat (limited to 'sql/core/src/main/scala')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/rdd/PartitionLocalRDDFunctions.scala  100
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala            2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala         183
3 files changed, 136 insertions, 149 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/rdd/PartitionLocalRDDFunctions.scala b/sql/core/src/main/scala/org/apache/spark/rdd/PartitionLocalRDDFunctions.scala
deleted file mode 100644
index f1230e7526..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/rdd/PartitionLocalRDDFunctions.scala
+++ /dev/null
@@ -1,100 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rdd
-
-import scala.language.implicitConversions
-
-import scala.reflect._
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark.{Aggregator, InterruptibleIterator, Logging}
-import org.apache.spark.util.collection.AppendOnlyMap
-
-/* Implicit conversions */
-import org.apache.spark.SparkContext._
-
-/**
- * Extra functions on RDDs that perform only local operations. These can be used when data has
- * already been partitioned correctly.
- */
-private[spark] class PartitionLocalRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
- extends Logging
- with Serializable {
-
- /**
- * Cogroup corresponding partitions of `this` and `other`. These two RDDs should have
- * the same number of partitions. Partitions of these two RDDs are cogrouped
- * according to the indexes of partitions. If we have two RDDs and
- * each of them has n partitions, we will cogroup the partition i from `this`
- * with the partition i from `other`.
- * This function will not introduce a shuffling operation.
- */
- def cogroupLocally[W](other: RDD[(K, W)]): RDD[(K, (Seq[V], Seq[W]))] = {
- val cg = self.zipPartitions(other)((iter1:Iterator[(K, V)], iter2:Iterator[(K, W)]) => {
- val map = new AppendOnlyMap[K, Seq[ArrayBuffer[Any]]]
-
- val update: (Boolean, Seq[ArrayBuffer[Any]]) => Seq[ArrayBuffer[Any]] = (hadVal, oldVal) => {
- if (hadVal) oldVal else Array.fill(2)(new ArrayBuffer[Any])
- }
-
- val getSeq = (k: K) => {
- map.changeValue(k, update)
- }
-
- iter1.foreach { kv => getSeq(kv._1)(0) += kv._2 }
- iter2.foreach { kv => getSeq(kv._1)(1) += kv._2 }
-
- map.iterator
- }).mapValues { case Seq(vs, ws) => (vs.asInstanceOf[Seq[V]], ws.asInstanceOf[Seq[W]])}
-
- cg
- }
-
- /**
- * Group the values for each key within a partition of the RDD into a single sequence.
- * This function will not introduce a shuffling operation.
- */
- def groupByKeyLocally(): RDD[(K, Seq[V])] = {
- def createCombiner(v: V) = ArrayBuffer(v)
- def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
- val aggregator = new Aggregator[K, V, ArrayBuffer[V]](createCombiner, mergeValue, _ ++ _)
- val bufs = self.mapPartitionsWithContext((context, iter) => {
- new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context))
- }, preservesPartitioning = true)
- bufs.asInstanceOf[RDD[(K, Seq[V])]]
- }
-
- /**
- * Join corresponding partitions of `this` and `other`.
- * If we have two RDDs and each of them has n partitions,
- * we will join the partition i from `this` with the partition i from `other`.
- * This function will not introduce a shuffling operation.
- */
- def joinLocally[W](other: RDD[(K, W)]): RDD[(K, (V, W))] = {
- cogroupLocally(other).flatMapValues {
- case (vs, ws) => for (v <- vs.iterator; w <- ws.iterator) yield (v, w)
- }
- }
-}
-
-private[spark] object PartitionLocalRDDFunctions {
- implicit def rddToPartitionLocalRDDFunctions[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]) =
- new PartitionLocalRDDFunctions(rdd)
-}
-
-
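For reference, the deleted helpers were reached through the implicit conversion in the companion object above. The following is a minimal usage sketch against a build that still contains this file; the data, partitioner choice, and object name are illustrative only:

import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
import org.apache.spark.rdd.PartitionLocalRDDFunctions._

object LocalOpsSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sketch"))
    // Both sides must already share a partitioner: the partition-local variants
    // never shuffle, so keys living in mismatched partitions would be missed.
    val a = sc.parallelize(Seq("x" -> 1, "x" -> 2, "y" -> 3)).partitionBy(new HashPartitioner(2))
    val b = sc.parallelize(Seq("x" -> 10L, "y" -> 20L)).partitionBy(new HashPartitioner(2))
    a.groupByKeyLocally().collect().foreach(println) // groups within each partition, no shuffle
    a.joinLocally(b).collect().foreach(println)      // joins partition i with partition i
    sc.stop()
  }
}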
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 869673b1fe..450c142c0b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -76,7 +76,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
*/
object AddExchange extends Rule[SparkPlan] {
// TODO: Determine the number of partitions.
- val numPartitions = 8
+ val numPartitions = 150
def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
case operator: SparkPlan =>
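Only the partition-count constant changes here, but for context, transformUp rewrites a physical plan bottom-up, which is how AddExchange gets the chance to insert Exchange operators. A toy model of that traversal pattern follows; the Plan hierarchy is hypothetical and far simpler than Catalyst's SparkPlan, and the rewrite rule stands in for the elided body of apply:

object TransformUpSketch {
  sealed trait Plan { def children: Seq[Plan] }
  case class Scan(table: String) extends Plan { def children = Nil }
  case class Agg(child: Plan) extends Plan { def children = Seq(child) }
  case class Exch(numPartitions: Int, child: Plan) extends Plan { def children = Seq(child) }

  def transformUp(p: Plan)(rule: PartialFunction[Plan, Plan]): Plan = {
    // Rewrite children first, then offer the updated node itself to the rule.
    val rewritten = p match {
      case s: Scan    => s
      case Agg(c)     => Agg(transformUp(c)(rule))
      case Exch(n, c) => Exch(n, transformUp(c)(rule))
    }
    rule.applyOrElse(rewritten, identity[Plan])
  }

  def main(args: Array[String]): Unit = {
    // Insert a 150-partition exchange beneath every aggregate.
    val plan = transformUp(Agg(Scan("t"))) { case Agg(c) => Agg(Exch(150, c)) }
    println(plan) // Agg(Exch(150,Scan(t)))
  }
}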
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala
index 8515a18f18..2a4f7b5670 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala
@@ -17,14 +17,13 @@
package org.apache.spark.sql.execution
+import java.util.HashMap
+
import org.apache.spark.SparkContext
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
-/* Implicit conversions */
-import org.apache.spark.rdd.PartitionLocalRDDFunctions._
-
/**
* Groups input data by `groupingExpressions` and computes the `aggregateExpressions` for each
* group.
@@ -40,7 +39,7 @@ case class Aggregate(
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[NamedExpression],
child: SparkPlan)(@transient sc: SparkContext)
- extends UnaryNode {
+ extends UnaryNode with NoBind {
override def requiredChildDistribution =
if (partial) {
@@ -55,61 +54,149 @@ case class Aggregate(
override def otherCopyArgs = sc :: Nil
+ // HACK: Generators don't correctly preserve their output through serialization, so we grab
+ // our child's output attributes statically here.
+ val childOutput = child.output
+
def output = aggregateExpressions.map(_.toAttribute)
- /* Replace all aggregate expressions with spark functions that will compute the result. */
- def createAggregateImplementations() = aggregateExpressions.map { agg =>
- val impl = agg transform {
- case a: AggregateExpression => a.newInstance
+ /**
+ * An aggregate that needs to be computed for each row in a group.
+ *
+ * @param unbound Unbound version of this aggregate, used for result substitution.
+ * @param aggregate A bound copy of this aggregate used to create a new aggregation buffer.
+ * @param resultAttribute An attribute used to refer to the result of this aggregate in the final
+ * output.
+ */
+ case class ComputedAggregate(
+ unbound: AggregateExpression,
+ aggregate: AggregateExpression,
+ resultAttribute: AttributeReference)
+
+ /** A list of aggregates that need to be computed for each group. */
+ @transient
+ lazy val computedAggregates = aggregateExpressions.flatMap { agg =>
+ agg.collect {
+ case a: AggregateExpression =>
+ ComputedAggregate(
+ a,
+ BindReferences.bindReference(a, childOutput).asInstanceOf[AggregateExpression],
+ AttributeReference(s"aggResult:$a", a.dataType, nullable = true)())
}
+ }.toArray
+
+ /** The schema of the result of all aggregate evaluations */
+ @transient
+ lazy val computedSchema = computedAggregates.map(_.resultAttribute)
+
+ /** Creates a new aggregate buffer for a group. */
+ def newAggregateBuffer(): Array[AggregateFunction] = {
+ val buffer = new Array[AggregateFunction](computedAggregates.length)
+ var i = 0
+ while (i < computedAggregates.length) {
+ buffer(i) = computedAggregates(i).aggregate.newInstance()
+ i += 1
+ }
+ buffer
+ }
- val remainingAttributes = impl.collect { case a: Attribute => a }
- // If any references exist that are not inside agg functions then they must be grouping exprs
- // in this case we must rebind them to the grouping tuple.
- if (remainingAttributes.nonEmpty) {
- val unaliasedAggregateExpr = agg transform { case Alias(c, _) => c }
-
- // An exact match with a grouping expression
- val exactGroupingExpr = groupingExpressions.indexOf(unaliasedAggregateExpr) match {
- case -1 => None
- case ordinal => Some(BoundReference(ordinal, Alias(impl, "AGGEXPR")().toAttribute))
- }
+ /** Named attributes used to substitute grouping attributes into the final result. */
+ @transient
+ lazy val namedGroups = groupingExpressions.map {
+ case ne: NamedExpression => ne -> ne.toAttribute
+ case e => e -> Alias(e, s"groupingExpr:$e")().toAttribute
+ }
- exactGroupingExpr.getOrElse(
- sys.error(s"$agg is not in grouping expressions: $groupingExpressions"))
- } else {
- impl
+ /**
+ * A map of substitutions that are used to insert the aggregate expressions and grouping
+ * expressions into the final result expressions.
+ */
+ @transient
+ lazy val resultMap =
+ (computedAggregates.map { agg => agg.unbound -> agg.resultAttribute} ++ namedGroups).toMap
+
+ /**
+ * Substituted version of aggregateExpressions, which is used to compute final
+ * output rows given a group and the result of all aggregate computations.
+ */
+ @transient
+ lazy val resultExpressions = aggregateExpressions.map { agg =>
+ agg.transform {
+ case e: Expression if resultMap.contains(e) => resultMap(e)
}
}
def execute() = attachTree(this, "execute") {
- // TODO: If the child of it is an [[catalyst.execution.Exchange]],
- // do not evaluate the groupingExpressions again since we have evaluated them
- // in the [[catalyst.execution.Exchange]].
- val grouped = child.execute().mapPartitions { iter =>
- val buildGrouping = new Projection(groupingExpressions)
- iter.map(row => (buildGrouping(row), row.copy()))
- }.groupByKeyLocally()
-
- val result = grouped.map { case (group, rows) =>
- val aggImplementations = createAggregateImplementations()
-
- // Pull out all the functions so we can feed each row into them.
- val aggFunctions = aggImplementations.flatMap(_ collect { case f: AggregateFunction => f })
-
- rows.foreach { row =>
- aggFunctions.foreach(_.update(row))
+ if (groupingExpressions.isEmpty) {
+ child.execute().mapPartitions { iter =>
+ val buffer = newAggregateBuffer()
+ var currentRow: Row = null
+ while (iter.hasNext) {
+ currentRow = iter.next()
+ var i = 0
+ while (i < buffer.length) {
+ buffer(i).update(currentRow)
+ i += 1
+ }
+ }
+ val resultProjection = new Projection(resultExpressions, computedSchema)
+ val aggregateResults = new GenericMutableRow(computedAggregates.length)
+
+ var i = 0
+ while (i < buffer.length) {
+ aggregateResults(i) = buffer(i).apply(EmptyRow)
+ i += 1
+ }
+
+ Iterator(resultProjection(aggregateResults))
}
- buildRow(aggImplementations.map(_.apply(group)))
- }
-
- // TODO: THIS BREAKS PIPELINING, DOUBLE COMPUTES THE ANSWER, AND USES TOO MUCH MEMORY...
- if (groupingExpressions.isEmpty && result.count == 0) {
- // When there is no output to the Aggregate operator, we still output an empty row.
- val aggImplementations = createAggregateImplementations()
- sc.makeRDD(buildRow(aggImplementations.map(_.apply(null))) :: Nil)
} else {
- result
+ child.execute().mapPartitions { iter =>
+ val hashTable = new HashMap[Row, Array[AggregateFunction]]
+ val groupingProjection = new MutableProjection(groupingExpressions, childOutput)
+
+ var currentRow: Row = null
+ while (iter.hasNext) {
+ currentRow = iter.next()
+ val currentGroup = groupingProjection(currentRow)
+ var currentBuffer = hashTable.get(currentGroup)
+ if (currentBuffer == null) {
+ currentBuffer = newAggregateBuffer()
+ hashTable.put(currentGroup.copy(), currentBuffer)
+ }
+
+ var i = 0
+ while (i < currentBuffer.length) {
+ currentBuffer(i).update(currentRow)
+ i += 1
+ }
+ }
+
+ new Iterator[Row] {
+ private[this] val hashTableIter = hashTable.entrySet().iterator()
+ private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length)
+ private[this] val resultProjection =
+ new MutableProjection(resultExpressions, computedSchema ++ namedGroups.map(_._2))
+ private[this] val joinedRow = new JoinedRow
+
+ override final def hasNext: Boolean = hashTableIter.hasNext
+
+ override final def next(): Row = {
+ val currentEntry = hashTableIter.next()
+ val currentGroup = currentEntry.getKey
+ val currentBuffer = currentEntry.getValue
+
+ var i = 0
+ while (i < currentBuffer.length) {
+ // Evaluating an aggregate buffer returns the result. No row is required since we
+ // already added all rows in the group using update.
+ aggregateResults(i) = currentBuffer(i).apply(EmptyRow)
+ i += 1
+ }
+ resultProjection(joinedRow(aggregateResults, currentGroup))
+ }
+ }
+ }
}
}
}
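Taken together, the grouped path above makes one pass that folds each input row into a per-group buffer, then a second lazy pass that projects results straight out of the hash table. A simplified, self-contained model of that shape, with plain Scala types standing in for Catalyst rows and SumBuffer as a hypothetical stand-in for one aggregate buffer slot:

import java.util.HashMap

object HashAggregateSketch {
  final class SumBuffer { var sum = 0L; def update(v: Int): Unit = sum += v }

  def hashAggregate(rows: Iterator[(String, Int)]): Iterator[(String, Long)] = {
    val table = new HashMap[String, SumBuffer]
    // Pass 1: fold every row into its group's buffer, creating buffers lazily.
    while (rows.hasNext) {
      val (group, value) = rows.next()
      var buffer = table.get(group)
      if (buffer == null) {
        buffer = new SumBuffer
        // Unlike Catalyst's mutable rows, String keys need no defensive copy().
        table.put(group, buffer)
      }
      buffer.update(value)
    }
    // Pass 2: expose finished buffers lazily, mirroring the Iterator[Row] above.
    val entries = table.entrySet().iterator()
    new Iterator[(String, Long)] {
      override def hasNext: Boolean = entries.hasNext
      override def next(): (String, Long) = {
        val e = entries.next()
        (e.getKey, e.getValue.sum)
      }
    }
  }

  def main(args: Array[String]): Unit = {
    hashAggregate(Iterator("a" -> 1, "b" -> 2, "a" -> 3)).foreach(println) // (a,4), (b,2)
  }
}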