-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala | 3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 24
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala | 5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala | 3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala | 292
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala | 206
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala | 100
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala | 21
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala | 1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala | 5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala | 35
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala | 12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala | 11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala | 100
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala | 173
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala | 749
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala | 364
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala | 280
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala | 26
-rw-r--r--  sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala | 1
-rw-r--r--  sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala | 7
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala | 1
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 7
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala | 8
-rw-r--r--  sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java | 107
-rw-r--r--  sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java | 100
-rw-r--r--  sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada | 1
-rw-r--r--  sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47 | 1
-rw-r--r--  sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4 | 14
-rw-r--r--  sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e | 1
-rw-r--r--  sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3 | 1
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala | 507
39 files changed, 3087 insertions, 100 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index d4ef04c229..c04bd6cd85 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -266,11 +266,12 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
}
}
| ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^
- { case udfName ~ exprs => UnresolvedFunction(udfName, exprs) }
+ { case udfName ~ exprs => UnresolvedFunction(udfName, exprs, isDistinct = false) }
| ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^ { case udfName ~ exprs =>
lexical.normalizeKeyword(udfName) match {
case "sum" => SumDistinct(exprs.head)
case "count" => CountDistinct(exprs)
+ case name => UnresolvedFunction(name, exprs, isDistinct = true)
case _ => throw new AnalysisException(s"function $udfName does not support DISTINCT")
}
}
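With this change, a DISTINCT call to any function other than sum/count no longer fails at parse time; it is parsed into an UnresolvedFunction that records the DISTINCT flag and is resolved later by the Analyzer. A minimal sketch of the parsed result (illustrative only, not part of the patch; the function name is hypothetical):

    import org.apache.spark.sql.catalyst.SqlParser

    val plan = new SqlParser().parse("SELECT mydoubleavg(DISTINCT value) FROM t")
    // The aggregate call stays unresolved but now carries the flag:
    //   UnresolvedFunction("mydoubleavg", Seq('value), isDistinct = true)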
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e58f3f6494..8cadbc57e8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2}
import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -277,7 +278,7 @@ class Analyzer(
Project(
projectList.flatMap {
case s: Star => s.expand(child.output, resolver)
- case UnresolvedAlias(f @ UnresolvedFunction(_, args)) if containsStar(args) =>
+ case UnresolvedAlias(f @ UnresolvedFunction(_, args, _)) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child.output, resolver)
case o => o :: Nil
@@ -517,9 +518,26 @@ class Analyzer(
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan =>
q transformExpressions {
- case u @ UnresolvedFunction(name, children) =>
+ case u @ UnresolvedFunction(name, children, isDistinct) =>
withPosition(u) {
- registry.lookupFunction(name, children)
+ registry.lookupFunction(name, children) match {
+ // We get an aggregate function built on the AggregateFunction2 interface,
+ // so we wrap it in AggregateExpression2.
+ case agg2: AggregateFunction2 => AggregateExpression2(agg2, Complete, isDistinct)
+ // Currently, our old aggregate function interface supports SUM(DISTINCT ...)
+ // and COUNT(DISTINCT ...).
+ case sumDistinct: SumDistinct => sumDistinct
+ case countDistinct: CountDistinct => countDistinct
+ // DISTINCT is not meaningful with Max and Min.
+ case max: Max if isDistinct => max
+ case min: Min if isDistinct => min
+ // For other aggregate functions, the DISTINCT keyword is not supported for now.
+ // Once we have converted to the new code path, we will allow the DISTINCT keyword.
+ case other if isDistinct =>
+ failAnalysis(s"$name does not support DISTINCT keyword.")
+ // If it does not have DISTINCT keyword, we will return it as is.
+ case other => other
+ }
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index c7f9713344..c203fcecf2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 0daee1990a..03da45b09f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -73,7 +73,10 @@ object UnresolvedAttribute {
def quoted(name: String): UnresolvedAttribute = new UnresolvedAttribute(Seq(name))
}
-case class UnresolvedFunction(name: String, children: Seq[Expression])
+case class UnresolvedFunction(
+ name: String,
+ children: Seq[Expression],
+ isDistinct: Boolean)
extends Expression with Unevaluable {
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index b09aea0331..b10a3c8774 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.types._
case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
extends LeafExpression with NamedExpression {
- override def toString: String = s"input[$ordinal]"
+ override def toString: String = s"input[$ordinal, $dataType]"
override def eval(input: InternalRow): Any = input(ordinal)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index aada25276a..29ae47e842 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -96,7 +96,8 @@ abstract class Expression extends TreeNode[Expression] {
val primitive = ctx.freshName("primitive")
val ve = GeneratedExpressionCode("", isNull, primitive)
ve.code = genCode(ctx, ve)
- ve
+ // Add `this` in the comment.
+ ve.copy(s"/* $this */\n" + ve.code)
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
new file mode 100644
index 0000000000..b924af4cc8
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
@@ -0,0 +1,292 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types._
+
+case class Average(child: Expression) extends AlgebraicAggregate {
+
+ override def children: Seq[Expression] = child :: Nil
+
+ override def nullable: Boolean = true
+
+ // Return data type.
+ override def dataType: DataType = resultType
+
+ // Expected input data type.
+ // TODO: Once we remove the old code path, we can use our analyzer to cast NullType
+ // to the default data type of the NumericType.
+ override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))
+
+ private val resultType = child.dataType match {
+ case DecimalType.Fixed(precision, scale) =>
+ DecimalType(precision + 4, scale + 4)
+ case DecimalType.Unlimited => DecimalType.Unlimited
+ case _ => DoubleType
+ }
+
+ private val sumDataType = child.dataType match {
+ case _ @ DecimalType() => DecimalType.Unlimited
+ case _ => DoubleType
+ }
+
+ private val currentSum = AttributeReference("currentSum", sumDataType)()
+ private val currentCount = AttributeReference("currentCount", LongType)()
+
+ override val bufferAttributes = currentSum :: currentCount :: Nil
+
+ override val initialValues = Seq(
+ /* currentSum = */ Cast(Literal(0), sumDataType),
+ /* currentCount = */ Literal(0L)
+ )
+
+ override val updateExpressions = Seq(
+ /* currentSum = */
+ Add(
+ currentSum,
+ Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)),
+ /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L)
+ )
+
+ override val mergeExpressions = Seq(
+ /* currentSum = */ currentSum.left + currentSum.right,
+ /* currentCount = */ currentCount.left + currentCount.right
+ )
+
+ // If all input are nulls, currentCount will be 0 and we will get null after the division.
+ override val evaluateExpression = Cast(currentSum, resultType) / Cast(currentCount, resultType)
+}
+
+case class Count(child: Expression) extends AlgebraicAggregate {
+ override def children: Seq[Expression] = child :: Nil
+
+ override def nullable: Boolean = false
+
+ // Return data type.
+ override def dataType: DataType = LongType
+
+ // Expected input data type.
+ override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+
+ private val currentCount = AttributeReference("currentCount", LongType)()
+
+ override val bufferAttributes = currentCount :: Nil
+
+ override val initialValues = Seq(
+ /* currentCount = */ Literal(0L)
+ )
+
+ override val updateExpressions = Seq(
+ /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L)
+ )
+
+ override val mergeExpressions = Seq(
+ /* currentCount = */ currentCount.left + currentCount.right
+ )
+
+ override val evaluateExpression = Cast(currentCount, LongType)
+}
+
+case class First(child: Expression) extends AlgebraicAggregate {
+
+ override def children: Seq[Expression] = child :: Nil
+
+ override def nullable: Boolean = true
+
+ // First is not a deterministic function.
+ override def deterministic: Boolean = false
+
+ // Return data type.
+ override def dataType: DataType = child.dataType
+
+ // Expected input data type.
+ override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+
+ private val first = AttributeReference("first", child.dataType)()
+
+ override val bufferAttributes = first :: Nil
+
+ override val initialValues = Seq(
+ /* first = */ Literal.create(null, child.dataType)
+ )
+
+ override val updateExpressions = Seq(
+ /* first = */ If(IsNull(first), child, first)
+ )
+
+ override val mergeExpressions = Seq(
+ /* first = */ If(IsNull(first.left), first.right, first.left)
+ )
+
+ override val evaluateExpression = first
+}
+
+case class Last(child: Expression) extends AlgebraicAggregate {
+
+ override def children: Seq[Expression] = child :: Nil
+
+ override def nullable: Boolean = true
+
+ // Last is not a deterministic function.
+ override def deterministic: Boolean = false
+
+ // Return data type.
+ override def dataType: DataType = child.dataType
+
+ // Expected input data type.
+ override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+
+ private val last = AttributeReference("last", child.dataType)()
+
+ override val bufferAttributes = last :: Nil
+
+ override val initialValues = Seq(
+ /* last = */ Literal.create(null, child.dataType)
+ )
+
+ override val updateExpressions = Seq(
+ /* last = */ If(IsNull(child), last, child)
+ )
+
+ override val mergeExpressions = Seq(
+ /* last = */ If(IsNull(last.right), last.left, last.right)
+ )
+
+ override val evaluateExpression = last
+}
+
+case class Max(child: Expression) extends AlgebraicAggregate {
+
+ override def children: Seq[Expression] = child :: Nil
+
+ override def nullable: Boolean = true
+
+ // Return data type.
+ override def dataType: DataType = child.dataType
+
+ // Expected input data type.
+ override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+
+ private val max = AttributeReference("max", child.dataType)()
+
+ override val bufferAttributes = max :: Nil
+
+ override val initialValues = Seq(
+ /* max = */ Literal.create(null, child.dataType)
+ )
+
+ override val updateExpressions = Seq(
+ /* max = */ If(IsNull(child), max, If(IsNull(max), child, Greatest(Seq(max, child))))
+ )
+
+ override val mergeExpressions = {
+ val greatest = Greatest(Seq(max.left, max.right))
+ Seq(
+ /* max = */ If(IsNull(max.right), max.left, If(IsNull(max.left), max.right, greatest))
+ )
+ }
+
+ override val evaluateExpression = max
+}
+
+case class Min(child: Expression) extends AlgebraicAggregate {
+
+ override def children: Seq[Expression] = child :: Nil
+
+ override def nullable: Boolean = true
+
+ // Return data type.
+ override def dataType: DataType = child.dataType
+
+ // Expected input data type.
+ override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+
+ private val min = AttributeReference("min", child.dataType)()
+
+ override val bufferAttributes = min :: Nil
+
+ override val initialValues = Seq(
+ /* min = */ Literal.create(null, child.dataType)
+ )
+
+ override val updateExpressions = Seq(
+ /* min = */ If(IsNull(child), min, If(IsNull(min), child, Least(Seq(min, child))))
+ )
+
+ override val mergeExpressions = {
+ val least = Least(Seq(min.left, min.right))
+ Seq(
+ /* min = */ If(IsNull(min.right), min.left, If(IsNull(min.left), min.right, least))
+ )
+ }
+
+ override val evaluateExpression = min
+}
+
+case class Sum(child: Expression) extends AlgebraicAggregate {
+
+ override def children: Seq[Expression] = child :: Nil
+
+ override def nullable: Boolean = true
+
+ // Return data type.
+ override def dataType: DataType = resultType
+
+ // Expected input data type.
+ override def inputTypes: Seq[AbstractDataType] =
+ Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType))
+
+ private val resultType = child.dataType match {
+ case DecimalType.Fixed(precision, scale) =>
+ DecimalType(precision + 4, scale + 4)
+ case DecimalType.Unlimited => DecimalType.Unlimited
+ case _ => child.dataType
+ }
+
+ private val sumDataType = child.dataType match {
+ case _ @ DecimalType() => DecimalType.Unlimited
+ case _ => child.dataType
+ }
+
+ private val currentSum = AttributeReference("currentSum", sumDataType)()
+
+ private val zero = Cast(Literal(0), sumDataType)
+
+ override val bufferAttributes = currentSum :: Nil
+
+ override val initialValues = Seq(
+ /* currentSum = */ Literal.create(null, sumDataType)
+ )
+
+ override val updateExpressions = Seq(
+ /* currentSum = */
+ Coalesce(Seq(Add(Coalesce(Seq(currentSum, zero)), Cast(child, sumDataType)), currentSum))
+ )
+
+ override val mergeExpressions = {
+ val add = Add(Coalesce(Seq(currentSum.left, zero)), Cast(currentSum.right, sumDataType))
+ Seq(
+ /* currentSum = */
+ Coalesce(Seq(add, currentSum.left))
+ )
+ }
+
+ override val evaluateExpression = Cast(currentSum, resultType)
+}
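Each function above follows the same AlgebraicAggregate contract: declare buffer attributes, then supply initial, update, merge, and evaluate expressions. As a sketch only (hypothetical, not part of the patch), another aggregate could be added the same way:

    case class SumOfSquares(child: Expression) extends AlgebraicAggregate {
      override def children: Seq[Expression] = child :: Nil
      override def nullable: Boolean = true
      override def dataType: DataType = DoubleType
      override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType)

      private val sumSq = AttributeReference("sumSq", DoubleType)()

      override val bufferAttributes = sumSq :: Nil

      override val initialValues = Seq(/* sumSq = */ Literal(0.0))

      // Null inputs contribute 0 to the running sum of squares.
      override val updateExpressions = Seq(
        /* sumSq = */ Add(sumSq, Coalesce(Multiply(child, child) :: Literal(0.0) :: Nil)))

      override val mergeExpressions = Seq(/* sumSq = */ sumSq.left + sumSq.right)

      override val evaluateExpression = sumSq
    }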
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
new file mode 100644
index 0000000000..577ede73cb
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types._
+
+/** The mode of an [[AggregateFunction2]]. */
+private[sql] sealed trait AggregateMode
+
+/**
+ * An [[AggregateFunction2]] with [[Partial]] mode is used for partial aggregation.
+ * This function updates the given aggregation buffer with the original input of this
+ * function. When it has processed all input rows, the aggregation buffer is returned.
+ */
+private[sql] case object Partial extends AggregateMode
+
+/**
+ * An [[AggregateFunction2]] with [[PartialMerge]] mode is used to merge aggregation buffers
+ * containing intermediate results for this function.
+ * This function updates the given aggregation buffer by merging multiple aggregation buffers.
+ * When it has processed all input rows, the aggregation buffer is returned.
+ */
+private[sql] case object PartialMerge extends AggregateMode
+
+/**
+ * An [[AggregateFunction2]] with [[Final]] mode is used to merge aggregation buffers
+ * containing intermediate results for this function and then generate the final result.
+ * This function updates the given aggregation buffer by merging multiple aggregation buffers.
+ * When it has processed all input rows, the final result of this function is returned.
+ */
+private[sql] case object Final extends AggregateMode
+
+/**
+ * An [[AggregateFunction2]] with [[Complete]] mode is used to evaluate this function directly
+ * from original input rows without any partial aggregation.
+ * This function updates the given aggregation buffer with the original input of this
+ * function. When it has processed all input rows, the final result of this function is returned.
+ */
+private[sql] case object Complete extends AggregateMode
+
+/**
+ * A placeholder expression used in code generation; it does not change the corresponding value
+ * in the row.
+ */
+private[sql] case object NoOp extends Expression with Unevaluable {
+ override def nullable: Boolean = true
+ override def eval(input: InternalRow): Any = {
+ throw new TreeNodeException(
+ this, s"No function to evaluate expression. type: ${this.nodeName}")
+ }
+ override def dataType: DataType = NullType
+ override def children: Seq[Expression] = Nil
+}
+
+/**
+ * A container for an [[AggregateFunction2]] with its [[AggregateMode]] and a field
+ * (`isDistinct`) indicating whether the DISTINCT keyword is specified for this function.
+ * @param aggregateFunction the [[AggregateFunction2]] being wrapped
+ * @param mode the [[AggregateMode]] in which the function is evaluated
+ * @param isDistinct whether the DISTINCT keyword was specified for this function
+ */
+private[sql] case class AggregateExpression2(
+ aggregateFunction: AggregateFunction2,
+ mode: AggregateMode,
+ isDistinct: Boolean) extends AggregateExpression {
+
+ override def children: Seq[Expression] = aggregateFunction :: Nil
+ override def dataType: DataType = aggregateFunction.dataType
+ override def foldable: Boolean = false
+ override def nullable: Boolean = aggregateFunction.nullable
+
+ override def references: AttributeSet = {
+ val childReferences = mode match {
+ case Partial | Complete => aggregateFunction.references.toSeq
+ case PartialMerge | Final => aggregateFunction.bufferAttributes
+ }
+
+ AttributeSet(childReferences)
+ }
+
+ override def toString: String = s"(${aggregateFunction}2,mode=$mode,isDistinct=$isDistinct)"
+}
+
+abstract class AggregateFunction2
+ extends Expression with ImplicitCastInputTypes {
+
+ self: Product =>
+
+ /** An aggregate function is not foldable. */
+ override def foldable: Boolean = false
+
+ /**
+ * The offset of this function's buffer in the underlying buffer shared with other functions.
+ */
+ var bufferOffset: Int = 0
+
+ /** The schema of the aggregation buffer. */
+ def bufferSchema: StructType
+
+ /** Attributes of fields in bufferSchema. */
+ def bufferAttributes: Seq[AttributeReference]
+
+ /** Clones bufferAttributes. */
+ def cloneBufferAttributes: Seq[Attribute]
+
+ /**
+ * Initializes its aggregation buffer located in `buffer`.
+ * It will use bufferOffset to find the starting point of
+ * its buffer in the given `buffer` shared with other functions.
+ */
+ def initialize(buffer: MutableRow): Unit
+
+ /**
+ * Updates its aggregation buffer located in `buffer` based on the given `input`.
+ * It will use bufferOffset to find the starting point of its buffer in the given `buffer`
+ * shared with other functions.
+ */
+ def update(buffer: MutableRow, input: InternalRow): Unit
+
+ /**
+ * Updates its aggregation buffer located in `buffer1` by combining intermediate results
+ * in the current buffer and intermediate results from another buffer `buffer2`.
+ * It will use bufferOffset to find the starting point of its buffer in the given `buffer1`
+ * and `buffer2`.
+ */
+ def merge(buffer1: MutableRow, buffer2: InternalRow): Unit
+
+ override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
+ throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
+}
+
+/**
+ * A helper class for aggregate functions that can be implemented in terms of catalyst expressions.
+ */
+abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable {
+ self: Product =>
+
+ val initialValues: Seq[Expression]
+ val updateExpressions: Seq[Expression]
+ val mergeExpressions: Seq[Expression]
+ val evaluateExpression: Expression
+
+ override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance())
+
+ /**
+ * A helper class for representing an attribute used in merging two
+ * aggregation buffers. When merging two buffers, `bufferLeft` and `bufferRight`,
+ * we merge buffer values and then update bufferLeft. A [[RichAttribute]]
+ * of an [[AttributeReference]] `a` has two functions `left` and `right`,
+ * which represent `a` in `bufferLeft` and `bufferRight`, respectively.
+ * @param a
+ */
+ implicit class RichAttribute(a: AttributeReference) {
+ /** Represents this attribute at the mutable buffer side. */
+ def left: AttributeReference = a
+
+ /** Represents this attribute at the input buffer side (the data value is read-only). */
+ def right: AttributeReference = cloneBufferAttributes(bufferAttributes.indexOf(a))
+ }
+
+ /** An AlgebraicAggregate's bufferSchema is derived from bufferAttributes. */
+ override def bufferSchema: StructType = StructType.fromAttributes(bufferAttributes)
+
+ override def initialize(buffer: MutableRow): Unit = {
+ var i = 0
+ while (i < bufferAttributes.size) {
+ buffer(i + bufferOffset) = initialValues(i).eval()
+ i += 1
+ }
+ }
+
+ override def update(buffer: MutableRow, input: InternalRow): Unit = {
+ throw new UnsupportedOperationException(
+ "AlgebraicAggregate's update should not be called directly")
+ }
+
+ override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
+ throw new UnsupportedOperationException(
+ "AlgebraicAggregate's merge should not be called directly")
+ }
+
+ override def eval(buffer: InternalRow): Any = {
+ throw new UnsupportedOperationException(
+ "AlgebraicAggregate's eval should not be called directly")
+ }
+}
+
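For contrast with the expression-based AlgebraicAggregate, a minimal sketch (hypothetical, not part of this patch) of an AggregateFunction2 that maintains its buffer imperatively, using bufferOffset to locate its single slot in the shared buffer row:

    case class ImperativeCount(child: Expression) extends AggregateFunction2 {
      override def children: Seq[Expression] = child :: Nil
      override def nullable: Boolean = false
      override def dataType: DataType = LongType
      override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)

      private val count = AttributeReference("count", LongType)()
      override def bufferAttributes: Seq[AttributeReference] = count :: Nil
      override def cloneBufferAttributes: Seq[Attribute] = bufferAttributes.map(_.newInstance())
      override def bufferSchema: StructType = StructType.fromAttributes(bufferAttributes)

      override def initialize(buffer: MutableRow): Unit =
        buffer.setLong(bufferOffset, 0L)

      override def update(buffer: MutableRow, input: InternalRow): Unit = {
        // Count only non-null input values, as the declarative Count above does.
        if (child.eval(input) != null) {
          buffer.setLong(bufferOffset, buffer.getLong(bufferOffset) + 1L)
        }
      }

      override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit =
        buffer1.setLong(bufferOffset, buffer1.getLong(bufferOffset) + buffer2.getLong(bufferOffset))

      override def eval(buffer: InternalRow): Any = buffer.getLong(bufferOffset)
    }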
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index d705a12860..e07c920a41 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -27,7 +27,9 @@ import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashSet
-trait AggregateExpression extends Expression with Unevaluable {
+trait AggregateExpression extends Expression with Unevaluable
+
+trait AggregateExpression1 extends AggregateExpression {
/**
* Aggregate expressions should not be foldable.
@@ -38,7 +40,7 @@ trait AggregateExpression extends Expression with Unevaluable {
* Creates a new instance that can be used to compute this aggregate expression for a group
* of input rows/
*/
- def newInstance(): AggregateFunction
+ def newInstance(): AggregateFunction1
}
/**
@@ -54,10 +56,10 @@ case class SplitEvaluation(
partialEvaluations: Seq[NamedExpression])
/**
- * An [[AggregateExpression]] that can be partially computed without seeing all relevant tuples.
+ * An [[AggregateExpression1]] that can be partially computed without seeing all relevant tuples.
* These partial evaluations can then be combined to compute the actual answer.
*/
-trait PartialAggregate extends AggregateExpression {
+trait PartialAggregate1 extends AggregateExpression1 {
/**
* Returns a [[SplitEvaluation]] that computes this aggregation using partial aggregation.
@@ -67,13 +69,13 @@ trait PartialAggregate extends AggregateExpression {
/**
* A specific implementation of an aggregate function. Used to wrap a generic
- * [[AggregateExpression]] with an algorithm that will be used to compute one specific result.
+ * [[AggregateExpression1]] with an algorithm that will be used to compute one specific result.
*/
-abstract class AggregateFunction
- extends LeafExpression with AggregateExpression with Serializable {
+abstract class AggregateFunction1
+ extends LeafExpression with AggregateExpression1 with Serializable {
/** Base should return the generic aggregate expression that this function is computing */
- val base: AggregateExpression
+ val base: AggregateExpression1
override def nullable: Boolean = base.nullable
override def dataType: DataType = base.dataType
@@ -81,12 +83,12 @@ abstract class AggregateFunction
def update(input: InternalRow): Unit
// Do we really need this?
- override def newInstance(): AggregateFunction = {
+ override def newInstance(): AggregateFunction1 = {
makeCopy(productIterator.map { case a: AnyRef => a }.toArray)
}
}
-case class Min(child: Expression) extends UnaryExpression with PartialAggregate {
+case class Min(child: Expression) extends UnaryExpression with PartialAggregate1 {
override def nullable: Boolean = true
override def dataType: DataType = child.dataType
@@ -102,7 +104,7 @@ case class Min(child: Expression) extends UnaryExpression with PartialAggregate
TypeUtils.checkForOrderingExpr(child.dataType, "function min")
}
-case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
+case class MinFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
val currentMin: MutableLiteral = MutableLiteral(null, expr.dataType)
@@ -119,7 +121,7 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends Aggr
override def eval(input: InternalRow): Any = currentMin.value
}
-case class Max(child: Expression) extends UnaryExpression with PartialAggregate {
+case class Max(child: Expression) extends UnaryExpression with PartialAggregate1 {
override def nullable: Boolean = true
override def dataType: DataType = child.dataType
@@ -135,7 +137,7 @@ case class Max(child: Expression) extends UnaryExpression with PartialAggregate
TypeUtils.checkForOrderingExpr(child.dataType, "function max")
}
-case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
+case class MaxFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
val currentMax: MutableLiteral = MutableLiteral(null, expr.dataType)
@@ -152,7 +154,7 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr
override def eval(input: InternalRow): Any = currentMax.value
}
-case class Count(child: Expression) extends UnaryExpression with PartialAggregate {
+case class Count(child: Expression) extends UnaryExpression with PartialAggregate1 {
override def nullable: Boolean = false
override def dataType: LongType.type = LongType
@@ -165,7 +167,7 @@ case class Count(child: Expression) extends UnaryExpression with PartialAggregat
override def newInstance(): CountFunction = new CountFunction(child, this)
}
-case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
+case class CountFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
var count: Long = _
@@ -180,7 +182,7 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag
override def eval(input: InternalRow): Any = count
}
-case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate {
+case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate1 {
def this() = this(null)
override def children: Seq[Expression] = expressions
@@ -200,8 +202,8 @@ case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate
case class CountDistinctFunction(
@transient expr: Seq[Expression],
- @transient base: AggregateExpression)
- extends AggregateFunction {
+ @transient base: AggregateExpression1)
+ extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
@@ -220,7 +222,7 @@ case class CountDistinctFunction(
override def eval(input: InternalRow): Any = seen.size.toLong
}
-case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression {
+case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression1 {
def this() = this(null)
override def children: Seq[Expression] = expressions
@@ -233,8 +235,8 @@ case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpress
case class CollectHashSetFunction(
@transient expr: Seq[Expression],
- @transient base: AggregateExpression)
- extends AggregateFunction {
+ @transient base: AggregateExpression1)
+ extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
@@ -255,7 +257,7 @@ case class CollectHashSetFunction(
}
}
-case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression {
+case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression1 {
def this() = this(null)
override def children: Seq[Expression] = inputSet :: Nil
@@ -269,8 +271,8 @@ case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression
case class CombineSetsAndCountFunction(
@transient inputSet: Expression,
- @transient base: AggregateExpression)
- extends AggregateFunction {
+ @transient base: AggregateExpression1)
+ extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
@@ -305,7 +307,7 @@ private[sql] case object HyperLogLogUDT extends UserDefinedType[HyperLogLog] {
}
case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
- extends UnaryExpression with AggregateExpression {
+ extends UnaryExpression with AggregateExpression1 {
override def nullable: Boolean = false
override def dataType: DataType = HyperLogLogUDT
@@ -317,9 +319,9 @@ case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
case class ApproxCountDistinctPartitionFunction(
expr: Expression,
- base: AggregateExpression,
+ base: AggregateExpression1,
relativeSD: Double)
- extends AggregateFunction {
+ extends AggregateFunction1 {
def this() = this(null, null, 0) // Required for serialization.
private val hyperLogLog = new HyperLogLog(relativeSD)
@@ -335,7 +337,7 @@ case class ApproxCountDistinctPartitionFunction(
}
case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
- extends UnaryExpression with AggregateExpression {
+ extends UnaryExpression with AggregateExpression1 {
override def nullable: Boolean = false
override def dataType: LongType.type = LongType
@@ -347,9 +349,9 @@ case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
case class ApproxCountDistinctMergeFunction(
expr: Expression,
- base: AggregateExpression,
+ base: AggregateExpression1,
relativeSD: Double)
- extends AggregateFunction {
+ extends AggregateFunction1 {
def this() = this(null, null, 0) // Required for serialization.
private val hyperLogLog = new HyperLogLog(relativeSD)
@@ -363,7 +365,7 @@ case class ApproxCountDistinctMergeFunction(
}
case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
- extends UnaryExpression with PartialAggregate {
+ extends UnaryExpression with PartialAggregate1 {
override def nullable: Boolean = false
override def dataType: LongType.type = LongType
@@ -381,7 +383,7 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
override def newInstance(): CountDistinctFunction = new CountDistinctFunction(child :: Nil, this)
}
-case class Average(child: Expression) extends UnaryExpression with PartialAggregate {
+case class Average(child: Expression) extends UnaryExpression with PartialAggregate1 {
override def prettyName: String = "avg"
@@ -427,8 +429,8 @@ case class Average(child: Expression) extends UnaryExpression with PartialAggreg
TypeUtils.checkForNumericExpr(child.dataType, "function average")
}
-case class AverageFunction(expr: Expression, base: AggregateExpression)
- extends AggregateFunction {
+case class AverageFunction(expr: Expression, base: AggregateExpression1)
+ extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
@@ -474,7 +476,7 @@ case class AverageFunction(expr: Expression, base: AggregateExpression)
}
}
-case class Sum(child: Expression) extends UnaryExpression with PartialAggregate {
+case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1 {
override def nullable: Boolean = true
@@ -509,7 +511,7 @@ case class Sum(child: Expression) extends UnaryExpression with PartialAggregate
TypeUtils.checkForNumericExpr(child.dataType, "function sum")
}
-case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
+case class SumFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
private val calcType =
@@ -554,7 +556,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr
* <-- null <-- no data
* null <-- null <-- no data
*/
-case class CombineSum(child: Expression) extends AggregateExpression {
+case class CombineSum(child: Expression) extends AggregateExpression1 {
def this() = this(null)
override def children: Seq[Expression] = child :: Nil
@@ -564,8 +566,8 @@ case class CombineSum(child: Expression) extends AggregateExpression {
override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this)
}
-case class CombineSumFunction(expr: Expression, base: AggregateExpression)
- extends AggregateFunction {
+case class CombineSumFunction(expr: Expression, base: AggregateExpression1)
+ extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
@@ -601,7 +603,7 @@ case class CombineSumFunction(expr: Expression, base: AggregateExpression)
}
}
-case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate {
+case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate1 {
def this() = this(null)
override def nullable: Boolean = true
@@ -627,8 +629,8 @@ case class SumDistinct(child: Expression) extends UnaryExpression with PartialAg
TypeUtils.checkForNumericExpr(child.dataType, "function sumDistinct")
}
-case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
- extends AggregateFunction {
+case class SumDistinctFunction(expr: Expression, base: AggregateExpression1)
+ extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
@@ -653,7 +655,7 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
}
}
-case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression {
+case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression1 {
def this() = this(null, null)
override def children: Seq[Expression] = inputSet :: Nil
@@ -667,8 +669,8 @@ case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends Agg
case class CombineSetsAndSumFunction(
@transient inputSet: Expression,
- @transient base: AggregateExpression)
- extends AggregateFunction {
+ @transient base: AggregateExpression1)
+ extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
@@ -695,7 +697,7 @@ case class CombineSetsAndSumFunction(
}
}
-case class First(child: Expression) extends UnaryExpression with PartialAggregate {
+case class First(child: Expression) extends UnaryExpression with PartialAggregate1 {
override def nullable: Boolean = true
override def dataType: DataType = child.dataType
override def toString: String = s"FIRST($child)"
@@ -709,7 +711,7 @@ case class First(child: Expression) extends UnaryExpression with PartialAggregat
override def newInstance(): FirstFunction = new FirstFunction(child, this)
}
-case class FirstFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
+case class FirstFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
var result: Any = null
@@ -723,7 +725,7 @@ case class FirstFunction(expr: Expression, base: AggregateExpression) extends Ag
override def eval(input: InternalRow): Any = result
}
-case class Last(child: Expression) extends UnaryExpression with PartialAggregate {
+case class Last(child: Expression) extends UnaryExpression with PartialAggregate1 {
override def references: AttributeSet = child.references
override def nullable: Boolean = true
override def dataType: DataType = child.dataType
@@ -738,7 +740,7 @@ case class Last(child: Expression) extends UnaryExpression with PartialAggregate
override def newInstance(): LastFunction = new LastFunction(child, this)
}
-case class LastFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
+case class LastFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
var result: Any = null
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index 03b4b3c216..d838268f46 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
import scala.collection.mutable.ArrayBuffer
@@ -38,15 +39,17 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
protected def create(expressions: Seq[Expression]): (() => MutableProjection) = {
val ctx = newCodeGenContext()
- val projectionCode = expressions.zipWithIndex.map { case (e, i) =>
- val evaluationCode = e.gen(ctx)
- evaluationCode.code +
- s"""
- if(${evaluationCode.isNull})
- mutableRow.setNullAt($i);
- else
- ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)};
- """
+ val projectionCode = expressions.zipWithIndex.map {
+ case (NoOp, _) => ""
+ case (e, i) =>
+ val evaluationCode = e.gen(ctx)
+ evaluationCode.code +
+ s"""
+ if(${evaluationCode.isNull})
+ mutableRow.setNullAt($i);
+ else
+ ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)};
+ """
}
// collect projections into blocks as function has 64kb codesize limit in JVM
val projectionBlocks = new ArrayBuffer[String]()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 179a348d5b..b8e3b0d53a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -129,10 +129,10 @@ object PartialAggregation {
case logical.Aggregate(groupingExpressions, aggregateExpressions, child) =>
// Collect all aggregate expressions.
val allAggregates =
- aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a})
+ aggregateExpressions.flatMap(_ collect { case a: AggregateExpression1 => a})
// Collect all aggregate expressions that can be computed partially.
val partialAggregates =
- aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p})
+ aggregateExpressions.flatMap(_ collect { case p: PartialAggregate1 => p})
// Only do partial aggregation if supported by all aggregate expressions.
if (allAggregates.size == partialAggregates.size) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 986c315b31..6aefa9f675 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashSet
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 78c780bdc5..1474b170ba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -402,6 +402,9 @@ private[spark] object SQLConf {
defaultValue = Some(true),
isPublic = false)
+ val USE_SQL_AGGREGATE2 = booleanConf("spark.sql.useAggregate2",
+ defaultValue = Some(true), doc = "<TODO>")
+
val USE_SQL_SERIALIZER2 = booleanConf(
"spark.sql.useSerializer2",
defaultValue = Some(true), isPublic = false)
@@ -473,6 +476,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED)
+ private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2)
+
private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2)
private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD)
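The flag only gates planning, so it can be flipped per context. Illustrative usage (not part of the patch):

    sqlContext.setConf("spark.sql.useAggregate2", "false")  // fall back to the old aggregate operators
    sqlContext.sql("SET spark.sql.useAggregate2=true")      // re-enable the new code path from SQL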
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 8b4528b5d5..49bfe74b68 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -285,6 +285,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
@transient
val udf: UDFRegistration = new UDFRegistration(this)
+ @transient
+ val udaf: UDAFRegistration = new UDAFRegistration(this)
+
/**
* Returns true if the table is currently cached in-memory.
* @group cachemgmt
@@ -863,6 +866,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
DDLStrategy ::
TakeOrderedAndProject ::
HashAggregation ::
+ Aggregation ::
LeftSemiJoin ::
HashJoin ::
InMemoryScans ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala
new file mode 100644
index 0000000000..5b872f5e3e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.expressions.aggregate.{ScalaUDAF, UserDefinedAggregateFunction}
+
+class UDAFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
+
+ private val functionRegistry = sqlContext.functionRegistry
+
+ def register(
+ name: String,
+ func: UserDefinedAggregateFunction): UserDefinedAggregateFunction = {
+ def builder(children: Seq[Expression]) = ScalaUDAF(children, func)
+ functionRegistry.registerFunction(name, builder)
+ func
+ }
+}
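A sketch of how this registration point is meant to be used (the table name is hypothetical; MyDoubleAvg is the Java test UDAF added later in this patch):

    sqlContext.udaf.register("mydoubleavg", new MyDoubleAvg)
    sqlContext.sql("SELECT key, mydoubleavg(value) FROM src GROUP BY key")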
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
index 3cd60a2aa5..c2c945321d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
@@ -68,14 +68,14 @@ case class Aggregate(
* output.
*/
case class ComputedAggregate(
- unbound: AggregateExpression,
- aggregate: AggregateExpression,
+ unbound: AggregateExpression1,
+ aggregate: AggregateExpression1,
resultAttribute: AttributeReference)
/** A list of aggregates that need to be computed for each group. */
private[this] val computedAggregates = aggregateExpressions.flatMap { agg =>
agg.collect {
- case a: AggregateExpression =>
+ case a: AggregateExpression1 =>
ComputedAggregate(
a,
BindReferences.bindReference(a, child.output),
@@ -87,8 +87,8 @@ case class Aggregate(
private[this] val computedSchema = computedAggregates.map(_.resultAttribute)
/** Creates a new aggregate buffer for a group. */
- private[this] def newAggregateBuffer(): Array[AggregateFunction] = {
- val buffer = new Array[AggregateFunction](computedAggregates.length)
+ private[this] def newAggregateBuffer(): Array[AggregateFunction1] = {
+ val buffer = new Array[AggregateFunction1](computedAggregates.length)
var i = 0
while (i < computedAggregates.length) {
buffer(i) = computedAggregates(i).aggregate.newInstance()
@@ -146,7 +146,7 @@ case class Aggregate(
}
} else {
child.execute().mapPartitions { iter =>
- val hashTable = new HashMap[InternalRow, Array[AggregateFunction]]
+ val hashTable = new HashMap[InternalRow, Array[AggregateFunction1]]
val groupingProjection = new InterpretedMutableProjection(groupingExpressions, child.output)
var currentRow: InternalRow = null
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 2750053594..d31e265a29 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -247,8 +247,15 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
}
def addSortIfNecessary(child: SparkPlan): SparkPlan = {
- if (rowOrdering.nonEmpty && child.outputOrdering != rowOrdering) {
- sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child)
+
+ if (rowOrdering.nonEmpty) {
+ // If child.outputOrdering is [a, b] and rowOrdering is [a], we do not need to sort.
+ val minSize = Seq(rowOrdering.size, child.outputOrdering.size).min
+ if (minSize == 0 || rowOrdering.take(minSize) != child.outputOrdering.take(minSize)) {
+ sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child)
+ } else {
+ child
+ }
} else {
child
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index ecde9c5713..0e63f2fe29 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -69,7 +69,7 @@ case class GeneratedAggregate(
protected override def doExecute(): RDD[InternalRow] = {
val aggregatesToCompute = aggregateExpressions.flatMap { a =>
- a.collect { case agg: AggregateExpression => agg}
+ a.collect { case agg: AggregateExpression1 => agg}
}
// If you add any new function support, please add tests in org.apache.spark.sql.SQLQuerySuite
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 8cef7f200d..f54aa2027f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.{SQLContext, Strategy, execution}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan}
@@ -148,7 +149,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
if canBeCodeGened(
allAggregates(partialComputation) ++
allAggregates(rewrittenAggregateExpressions)) &&
- codegenEnabled =>
+ codegenEnabled &&
+ !canBeConvertedToNewAggregation(plan) =>
execution.GeneratedAggregate(
partial = false,
namedGroupingAttributes,
@@ -167,7 +169,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
rewrittenAggregateExpressions,
groupingExpressions,
partialComputation,
- child) =>
+ child) if !canBeConvertedToNewAggregation(plan) =>
execution.Aggregate(
partial = false,
namedGroupingAttributes,
@@ -181,7 +183,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case _ => Nil
}
- def canBeCodeGened(aggs: Seq[AggregateExpression]): Boolean = !aggs.exists {
+ def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = {
+ aggregate.Utils.tryConvert(
+ plan,
+ sqlContext.conf.useSqlAggregate2,
+ sqlContext.conf.codegenEnabled).isDefined
+ }
+
+ def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = !aggs.exists {
case _: CombineSum | _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => false
// The generated set implementation is pretty limited ATM.
case CollectHashSet(exprs) if exprs.size == 1 &&
@@ -189,10 +198,74 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case _ => true
}
- def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression] =
- exprs.flatMap(_.collect { case a: AggregateExpression => a })
+ def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression1] =
+ exprs.flatMap(_.collect { case a: AggregateExpression1 => a })
}
+ /**
+ * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface.
+ */
+ object Aggregation extends Strategy {
+ def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case p: logical.Aggregate =>
+ val converted =
+ aggregate.Utils.tryConvert(
+ p,
+ sqlContext.conf.useSqlAggregate2,
+ sqlContext.conf.codegenEnabled)
+ converted match {
+ case None => Nil // Cannot convert to new aggregation code path.
+ case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) =>
+ // Extracts all distinct aggregate expressions from the resultExpressions.
+ val aggregateExpressions = resultExpressions.flatMap { expr =>
+ expr.collect {
+ case agg: AggregateExpression2 => agg
+ }
+ }.toSet.toSeq
+ // For those distinct aggregate expressions, we create a map from the
+ // aggregate function to the corresponding attribute of the function.
+ val aggregateFunctionMap = aggregateExpressions.map { agg =>
+ val aggregateFunction = agg.aggregateFunction
+ (aggregateFunction, agg.isDistinct) ->
+ Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
+ }.toMap
+
+ val (functionsWithDistinct, functionsWithoutDistinct) =
+ aggregateExpressions.partition(_.isDistinct)
+ if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
+ // This is a sanity check. We should not reach here when we have multiple distinct
+ // column sets (aggregate.Utils.tryConvert will return None).
+ sys.error(
+ "Multiple distinct column sets are not supported by the new aggregation" +
+ "code path.")
+ }
+
+ val aggregateOperator =
+ if (functionsWithDistinct.isEmpty) {
+ aggregate.Utils.planAggregateWithoutDistinct(
+ groupingExpressions,
+ aggregateExpressions,
+ aggregateFunctionMap,
+ resultExpressions,
+ planLater(child))
+ } else {
+ aggregate.Utils.planAggregateWithOneDistinct(
+ groupingExpressions,
+ functionsWithDistinct,
+ functionsWithoutDistinct,
+ aggregateFunctionMap,
+ resultExpressions,
+ planLater(child))
+ }
+
+ aggregateOperator
+ }
+
+ case _ => Nil
+ }
+ }
+
+
object BroadcastNestedLoopJoin extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Join(left, right, joinType, condition) =>
@@ -336,8 +409,21 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.Filter(condition, planLater(child)) :: Nil
case e @ logical.Expand(_, _, _, child) =>
execution.Expand(e.projections, e.output, planLater(child)) :: Nil
- case logical.Aggregate(group, agg, child) =>
- execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
+ case a @ logical.Aggregate(group, agg, child) => {
+ val useNewAggregation =
+ aggregate.Utils.tryConvert(
+ a,
+ sqlContext.conf.useSqlAggregate2,
+ sqlContext.conf.codegenEnabled).isDefined
+ if (useNewAggregation) {
+ // If this logical.Aggregate can be planned to use new aggregation code path
+ // (i.e. it can be planned by the Strategy Aggregation), we will not use the old
+ // aggregation code path.
+ Nil
+ } else {
+ execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
+ }
+ }
case logical.Window(projectList, windowExpressions, spec, child) =>
execution.Window(projectList, windowExpressions, spec, planLater(child)) :: Nil
case logical.Sample(lb, ub, withReplacement, seed, child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala
new file mode 100644
index 0000000000..0c9082897f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution}
+import org.apache.spark.sql.execution.{SparkPlan, UnaryNode}
+
+case class Aggregate2Sort(
+ requiredChildDistributionExpressions: Option[Seq[Expression]],
+ groupingExpressions: Seq[NamedExpression],
+ aggregateExpressions: Seq[AggregateExpression2],
+ aggregateAttributes: Seq[Attribute],
+ resultExpressions: Seq[NamedExpression],
+ child: SparkPlan)
+ extends UnaryNode {
+
+ override def canProcessUnsafeRows: Boolean = true
+
+ override def references: AttributeSet = {
+ val referencesInResults =
+ AttributeSet(resultExpressions.flatMap(_.references)) -- AttributeSet(aggregateAttributes)
+
+ AttributeSet(
+ groupingExpressions.flatMap(_.references) ++
+ aggregateExpressions.flatMap(_.references) ++
+ referencesInResults)
+ }
+
+ override def requiredChildDistribution: List[Distribution] = {
+ requiredChildDistributionExpressions match {
+ case Some(exprs) if exprs.length == 0 => AllTuples :: Nil
+ case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil
+ case None => UnspecifiedDistribution :: Nil
+ }
+ }
+
+ override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
+ // TODO: We should not sort the input rows if they are just in reversed order.
+ groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
+ }
+
+ override def outputOrdering: Seq[SortOrder] = {
+ // It is possible that the child.outputOrdering starts with the required
+ // ordering expressions (e.g. we require [a] as the sort expression and the
+ // child's outputOrdering is [a, b]). We can only guarantee the output rows
+ // are sorted by values of groupingExpressions.
+ groupingExpressions.map(SortOrder(_, Ascending))
+ }
+
+ override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+ protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
+ child.execute().mapPartitions { iter =>
+ if (aggregateExpressions.length == 0) {
+ new GroupingIterator(
+ groupingExpressions,
+ resultExpressions,
+ newMutableProjection,
+ child.output,
+ iter)
+ } else {
+ val aggregationIterator: SortAggregationIterator = {
+ aggregateExpressions.map(_.mode).distinct.toList match {
+ case Partial :: Nil =>
+ new PartialSortAggregationIterator(
+ groupingExpressions,
+ aggregateExpressions,
+ newMutableProjection,
+ child.output,
+ iter)
+ case PartialMerge :: Nil =>
+ new PartialMergeSortAggregationIterator(
+ groupingExpressions,
+ aggregateExpressions,
+ newMutableProjection,
+ child.output,
+ iter)
+ case Final :: Nil =>
+ new FinalSortAggregationIterator(
+ groupingExpressions,
+ aggregateExpressions,
+ aggregateAttributes,
+ resultExpressions,
+ newMutableProjection,
+ child.output,
+ iter)
+ case other =>
+ sys.error(
+ s"Could not evaluate ${aggregateExpressions} because we do not support evaluate " +
+ s"modes $other in this operator.")
+ }
+ }
+
+ aggregationIterator
+ }
+ }
+ }
+}
+
+case class FinalAndCompleteAggregate2Sort(
+ previousGroupingExpressions: Seq[NamedExpression],
+ groupingExpressions: Seq[NamedExpression],
+ finalAggregateExpressions: Seq[AggregateExpression2],
+ finalAggregateAttributes: Seq[Attribute],
+ completeAggregateExpressions: Seq[AggregateExpression2],
+ completeAggregateAttributes: Seq[Attribute],
+ resultExpressions: Seq[NamedExpression],
+ child: SparkPlan)
+ extends UnaryNode {
+ override def references: AttributeSet = {
+ val referencesInResults =
+ AttributeSet(resultExpressions.flatMap(_.references)) --
+ AttributeSet(finalAggregateExpressions) --
+ AttributeSet(completeAggregateExpressions)
+
+ AttributeSet(
+ groupingExpressions.flatMap(_.references) ++
+ finalAggregateExpressions.flatMap(_.references) ++
+ completeAggregateExpressions.flatMap(_.references) ++
+ referencesInResults)
+ }
+
+ override def requiredChildDistribution: List[Distribution] = {
+ if (groupingExpressions.isEmpty) {
+ AllTuples :: Nil
+ } else {
+ ClusteredDistribution(groupingExpressions) :: Nil
+ }
+ }
+
+ override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+ groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
+
+ override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+ protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
+ child.execute().mapPartitions { iter =>
+
+ new FinalAndCompleteSortAggregationIterator(
+ previousGroupingExpressions.length,
+ groupingExpressions,
+ finalAggregateExpressions,
+ finalAggregateAttributes,
+ completeAggregateExpressions,
+ completeAggregateAttributes,
+ resultExpressions,
+ newMutableProjection,
+ child.output,
+ iter)
+ }
+ }
+
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
new file mode 100644
index 0000000000..ce1cbdc9cb
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
@@ -0,0 +1,749 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.types.NullType
+
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * An iterator used to evaluate aggregate functions. It assumes that input rows
+ * are already grouped by values of `groupingExpressions`.
+ */
+private[sql] abstract class SortAggregationIterator(
+ groupingExpressions: Seq[NamedExpression],
+ aggregateExpressions: Seq[AggregateExpression2],
+ newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+ inputAttributes: Seq[Attribute],
+ inputIter: Iterator[InternalRow])
+ extends Iterator[InternalRow] {
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Static fields for this iterator
+ ///////////////////////////////////////////////////////////////////////////
+
+ protected val aggregateFunctions: Array[AggregateFunction2] = {
+ var bufferOffset = initialBufferOffset
+ val functions = new Array[AggregateFunction2](aggregateExpressions.length)
+ var i = 0
+ while (i < aggregateExpressions.length) {
+ val func = aggregateExpressions(i).aggregateFunction
+ val funcWithBoundReferences = aggregateExpressions(i).mode match {
+ case Partial | Complete if !func.isInstanceOf[AlgebraicAggregate] =>
+ // We need to create BoundReferences if the function is not an
+ // AlgebraicAggregate (it does not support code-gen) and the mode of
+ // this function is Partial or Complete because we will call eval of this
+ // function's children in the update method of this aggregate function.
+ // Those eval calls require BoundReferences to work.
+ BindReferences.bindReference(func, inputAttributes)
+ case _ => func
+ }
+ // Set bufferOffset for this function. It is important that setting bufferOffset
+ // happens after all potential bindReference operations because bindReference
+ // will create a new instance of the function.
+ funcWithBoundReferences.bufferOffset = bufferOffset
+ bufferOffset += funcWithBoundReferences.bufferSchema.length
+ functions(i) = funcWithBoundReferences
+ i += 1
+ }
+ functions
+ }
+
+ // All non-algebraic aggregate functions.
+ protected val nonAlgebraicAggregateFunctions: Array[AggregateFunction2] = {
+ aggregateFunctions.collect {
+ case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
+ }.toArray
+ }
+
+ // Positions of those non-algebraic aggregate functions in aggregateFunctions.
+ // For example, we have func1, func2, func3, func4 in aggregateFunctions, and
+ // func2 and func3 are non-algebraic aggregate functions.
+ // nonAlgebraicAggregateFunctionPositions will be [1, 2].
+ protected val nonAlgebraicAggregateFunctionPositions: Array[Int] = {
+ val positions = new ArrayBuffer[Int]()
+ var i = 0
+ while (i < aggregateFunctions.length) {
+ aggregateFunctions(i) match {
+ case agg: AlgebraicAggregate =>
+ case _ => positions += i
+ }
+ i += 1
+ }
+ positions.toArray
+ }
+
+ // This is used to project expressions for the grouping expressions.
+ protected val groupGenerator =
+ newMutableProjection(groupingExpressions, inputAttributes)()
+
+ // The underlying buffer shared by all aggregate functions.
+ protected val buffer: MutableRow = {
+ // The number of elements of the underlying buffer of this operator.
+ // All aggregate functions share this underlying buffer and find their
+ // buffer values through bufferOffset.
+ var size = initialBufferOffset
+ var i = 0
+ while (i < aggregateFunctions.length) {
+ size += aggregateFunctions(i).bufferSchema.length
+ i += 1
+ }
+ new GenericMutableRow(size)
+ }
+
+ protected val joinedRow = new JoinedRow4
+
+ protected val placeholderExpressions = Seq.fill(initialBufferOffset)(NoOp)
+
+ // This projection is used to initialize buffer values for all AlgebraicAggregates.
+ protected val algebraicInitialProjection = {
+ val initExpressions = placeholderExpressions ++ aggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.initialValues
+ case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ }
+ newMutableProjection(initExpressions, Nil)().target(buffer)
+ }
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Mutable states
+ ///////////////////////////////////////////////////////////////////////////
+
+ // The grouping key of the current group.
+ protected var currentGroupingKey: InternalRow = _
+ // The grouping key of the next group.
+ protected var nextGroupingKey: InternalRow = _
+ // The first row of the next group.
+ protected var firstRowInNextGroup: InternalRow = _
+ // Indicates if we have a new group of rows to process.
+ protected var hasNewGroup: Boolean = true
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Private methods
+ ///////////////////////////////////////////////////////////////////////////
+
+ /** Initializes buffer values for all aggregate functions. */
+ protected def initializeBuffer(): Unit = {
+ algebraicInitialProjection(EmptyRow)
+ var i = 0
+ while (i < nonAlgebraicAggregateFunctions.length) {
+ nonAlgebraicAggregateFunctions(i).initialize(buffer)
+ i += 1
+ }
+ }
+
+ protected def initialize(): Unit = {
+ if (inputIter.hasNext) {
+ initializeBuffer()
+ val currentRow = inputIter.next().copy()
+ // groupGenerator is a mutable projection. Since we need to track nextGroupingKey,
+ // we make a copy of it here.
+ nextGroupingKey = groupGenerator(currentRow).copy()
+ firstRowInNextGroup = currentRow
+ } else {
+ // The input iterator is empty.
+ hasNewGroup = false
+ }
+ }
+
+ /** Processes rows in the current group. It will stop when it finds a new group. */
+ private def processCurrentGroup(): Unit = {
+ currentGroupingKey = nextGroupingKey
+ // Now, we will start to find all rows belonging to this group.
+ // We create a variable to track if we see the next group.
+ var findNextPartition = false
+ // firstRowInNextGroup is the first row of this group. We first process it.
+ processRow(firstRowInNextGroup)
+ // The search will stop when we see the next group or there is no
+ // input row left in the iter.
+ while (inputIter.hasNext && !findNextPartition) {
+ val currentRow = inputIter.next()
+ // Get the grouping key based on the grouping expressions.
+ // For the comparison below, we do not need to make a copy of groupingKey.
+ val groupingKey = groupGenerator(currentRow)
+ // Check if the current row belongs to the current group.
+
+ if (currentGroupingKey == groupingKey) {
+ processRow(currentRow)
+ } else {
+ // We found a new group.
+ findNextPartition = true
+ nextGroupingKey = groupingKey.copy()
+ firstRowInNextGroup = currentRow.copy()
+ }
+ }
+ // We have not seen a new group. It means there are no remaining input rows.
+ // The current group is the last group of the iterator.
+ if (!findNextPartition) {
+ hasNewGroup = false
+ }
+ }
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Public methods
+ ///////////////////////////////////////////////////////////////////////////
+
+ override final def hasNext: Boolean = hasNewGroup
+
+ override final def next(): InternalRow = {
+ if (hasNext) {
+ // Process the current group.
+ processCurrentGroup()
+ // Generate output row for the current group.
+ val outputRow = generateOutput()
+ // Initialize buffer values for the next group.
+ initializeBuffer()
+
+ outputRow
+ } else {
+ // no more result
+ throw new NoSuchElementException
+ }
+ }
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Methods that need to be implemented
+ ///////////////////////////////////////////////////////////////////////////
+
+ protected def initialBufferOffset: Int
+
+ protected def processRow(row: InternalRow): Unit
+
+ protected def generateOutput(): InternalRow
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Initialize this iterator
+ ///////////////////////////////////////////////////////////////////////////
+
+ initialize()
+}
+
+/**
+ * An iterator only used to group input rows according to values of `groupingExpressions`.
+ * It assumes that input rows are already grouped by values of `groupingExpressions`.
+ */
+class GroupingIterator(
+ groupingExpressions: Seq[NamedExpression],
+ resultExpressions: Seq[NamedExpression],
+ newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+ inputAttributes: Seq[Attribute],
+ inputIter: Iterator[InternalRow])
+ extends SortAggregationIterator(
+ groupingExpressions,
+ Nil,
+ newMutableProjection,
+ inputAttributes,
+ inputIter) {
+
+ private val resultProjection =
+ newMutableProjection(resultExpressions, groupingExpressions.map(_.toAttribute))()
+
+ override protected def initialBufferOffset: Int = 0
+
+ override protected def processRow(row: InternalRow): Unit = {
+ // Since we only do grouping, there is nothing to do here.
+ }
+
+ override protected def generateOutput(): InternalRow = {
+ resultProjection(currentGroupingKey)
+ }
+}
+
+/**
+ * An iterator used to do partial aggregations (for those aggregate functions with mode Partial).
+ * It assumes that input rows are already grouped by values of `groupingExpressions`.
+ * The format of its output rows is:
+ * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
+ */
+class PartialSortAggregationIterator(
+ groupingExpressions: Seq[NamedExpression],
+ aggregateExpressions: Seq[AggregateExpression2],
+ newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+ inputAttributes: Seq[Attribute],
+ inputIter: Iterator[InternalRow])
+ extends SortAggregationIterator(
+ groupingExpressions,
+ aggregateExpressions,
+ newMutableProjection,
+ inputAttributes,
+ inputIter) {
+
+ // This projection is used to update buffer values for all AlgebraicAggregates.
+ private val algebraicUpdateProjection = {
+ val bufferSchema = aggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.bufferAttributes
+ case agg: AggregateFunction2 => agg.bufferAttributes
+ }
+ val updateExpressions = aggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.updateExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ }
+ newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)().target(buffer)
+ }
+
+ override protected def initialBufferOffset: Int = 0
+
+ override protected def processRow(row: InternalRow): Unit = {
+ // Process all algebraic aggregate functions.
+ algebraicUpdateProjection(joinedRow(buffer, row))
+ // Process all non-algebraic aggregate functions.
+ var i = 0
+ while (i < nonAlgebraicAggregateFunctions.length) {
+ nonAlgebraicAggregateFunctions(i).update(buffer, row)
+ i += 1
+ }
+ }
+
+ override protected def generateOutput(): InternalRow = {
+ // We just output the grouping expressions and the underlying buffer.
+ joinedRow(currentGroupingKey, buffer).copy()
+ }
+}
+
+/**
+ * An iterator used to do partial merge aggregations (for those aggregate functions with mode
+ * PartialMerge). It assumes that input rows are already grouped by values of
+ * `groupingExpressions`.
+ * The format of its input rows is:
+ * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
+ *
+ * The format of its internal buffer is:
+ * |placeholder1|...|placeholderN|aggregationBuffer1|...|aggregationBufferN|
+ * Every placeholder is for a grouping expression.
+ * The actual buffers are stored after placeholderN.
+ * The reason we have placeholders here is to make our underlying buffer have the same
+ * length as an input row.
+ *
+ * The format of its output rows is:
+ * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
+ */
+class PartialMergeSortAggregationIterator(
+ groupingExpressions: Seq[NamedExpression],
+ aggregateExpressions: Seq[AggregateExpression2],
+ newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+ inputAttributes: Seq[Attribute],
+ inputIter: Iterator[InternalRow])
+ extends SortAggregationIterator(
+ groupingExpressions,
+ aggregateExpressions,
+ newMutableProjection,
+ inputAttributes,
+ inputIter) {
+
+ private val placeholderAttributes =
+ Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)())
+
+ // This projection is used to merge buffer values for all AlgebraicAggregates.
+ private val algebraicMergeProjection = {
+ val bufferSchemata =
+ placeholderAttributes ++ aggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.bufferAttributes
+ case agg: AggregateFunction2 => agg.bufferAttributes
+ } ++ placeholderAttributes ++ aggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.cloneBufferAttributes
+ case agg: AggregateFunction2 => agg.cloneBufferAttributes
+ }
+ val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.mergeExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ }
+
+ newMutableProjection(mergeExpressions, bufferSchemata)()
+ }
+
+ // This projection is used to extract aggregation buffers from the underlying buffer.
+ // We need it because the underlying buffer has placeholders at its beginning.
+ private val extractsBufferValues = {
+ val expressions = aggregateFunctions.flatMap {
+ case agg => agg.bufferAttributes
+ }
+
+ newMutableProjection(expressions, inputAttributes)()
+ }
+
+ override protected def initialBufferOffset: Int = groupingExpressions.length
+
+ override protected def processRow(row: InternalRow): Unit = {
+ // Process all algebraic aggregate functions.
+ algebraicMergeProjection.target(buffer)(joinedRow(buffer, row))
+ // Process all non-algebraic aggregate functions.
+ var i = 0
+ while (i < nonAlgebraicAggregateFunctions.length) {
+ nonAlgebraicAggregateFunctions(i).merge(buffer, row)
+ i += 1
+ }
+ }
+
+ override protected def generateOutput(): InternalRow = {
+ // We output grouping expressions and aggregation buffers.
+ joinedRow(currentGroupingKey, extractsBufferValues(buffer))
+ }
+}
+
+/**
+ * An iterator used to do final aggregations (for those aggregate functions with mode
+ * Final). It assumes that input rows are already grouped by values of
+ * `groupingExpressions`.
+ * The format of its input rows is:
+ * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
+ *
+ * The format of its internal buffer is:
+ * |placeholder1|...|placeholderN|aggregationBuffer1|...|aggregationBufferN|
+ * Every placeholder is for a grouping expression.
+ * The actual buffers are stored after placeholderN.
+ * The reason we have placeholders here is to make our underlying buffer have the same
+ * length as an input row.
+ *
+ * The format of its output rows is represented by the schema of `resultExpressions`.
+ */
+class FinalSortAggregationIterator(
+ groupingExpressions: Seq[NamedExpression],
+ aggregateExpressions: Seq[AggregateExpression2],
+ aggregateAttributes: Seq[Attribute],
+ resultExpressions: Seq[NamedExpression],
+ newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+ inputAttributes: Seq[Attribute],
+ inputIter: Iterator[InternalRow])
+ extends SortAggregationIterator(
+ groupingExpressions,
+ aggregateExpressions,
+ newMutableProjection,
+ inputAttributes,
+ inputIter) {
+
+ // The result of aggregate functions.
+ private val aggregateResult: MutableRow = new GenericMutableRow(aggregateAttributes.length)
+
+ // The projection used to generate the output rows of this operator.
+ // This is only used when we are generating final results of aggregate functions.
+ private val resultProjection =
+ newMutableProjection(
+ resultExpressions, groupingExpressions.map(_.toAttribute) ++ aggregateAttributes)()
+
+ private val offsetAttributes =
+ Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)())
+
+ // This projection is used to merge buffer values for all AlgebraicAggregates.
+ private val algebraicMergeProjection = {
+ val bufferSchemata =
+ offsetAttributes ++ aggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.bufferAttributes
+ case agg: AggregateFunction2 => agg.bufferAttributes
+ } ++ offsetAttributes ++ aggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.cloneBufferAttributes
+ case agg: AggregateFunction2 => agg.cloneBufferAttributes
+ }
+ val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.mergeExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ }
+
+ newMutableProjection(mergeExpressions, bufferSchemata)()
+ }
+
+ // This projection is used to evaluate all AlgebraicAggregates.
+ private val algebraicEvalProjection = {
+ val bufferSchemata =
+ offsetAttributes ++ aggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.bufferAttributes
+ case agg: AggregateFunction2 => agg.bufferAttributes
+ } ++ offsetAttributes ++ aggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.cloneBufferAttributes
+ case agg: AggregateFunction2 => agg.cloneBufferAttributes
+ }
+ val evalExpressions = aggregateFunctions.map {
+ case ae: AlgebraicAggregate => ae.evaluateExpression
+ case agg: AggregateFunction2 => NoOp
+ }
+
+ newMutableProjection(evalExpressions, bufferSchemata)()
+ }
+
+ override protected def initialBufferOffset: Int = groupingExpressions.length
+
+ override def initialize(): Unit = {
+ if (inputIter.hasNext) {
+ initializeBuffer()
+ val currentRow = inputIter.next().copy()
+ // groupGenerator is a mutable projection. Since we need to track nextGroupingKey,
+ // we make a copy of it here.
+ nextGroupingKey = groupGenerator(currentRow).copy()
+ firstRowInNextGroup = currentRow
+ } else {
+ if (groupingExpressions.isEmpty) {
+ // If there is no grouping expression, we need to generate a single row as the output.
+ initializeBuffer()
+ // Right now, the buffer only contains initial buffer values. Because merging
+ // two buffers that hold initial values still produces a row with initial values,
+ // we set currentRow to a copy of the current buffer.
+ val currentRow = buffer.copy()
+ nextGroupingKey = groupGenerator(currentRow).copy()
+ firstRowInNextGroup = currentRow
+ } else {
+ // The input iterator is empty.
+ hasNewGroup = false
+ }
+ }
+ }
+
+ override protected def processRow(row: InternalRow): Unit = {
+ // Process all algebraic aggregate functions.
+ algebraicMergeProjection.target(buffer)(joinedRow(buffer, row))
+ // Process all non-algebraic aggregate functions.
+ var i = 0
+ while (i < nonAlgebraicAggregateFunctions.length) {
+ nonAlgebraicAggregateFunctions(i).merge(buffer, row)
+ i += 1
+ }
+ }
+
+ override protected def generateOutput(): InternalRow = {
+ // Generate results for all algebraic aggregate functions.
+ algebraicEvalProjection.target(aggregateResult)(buffer)
+ // Generate results for all non-algebraic aggregate functions.
+ var i = 0
+ while (i < nonAlgebraicAggregateFunctions.length) {
+ aggregateResult.update(
+ nonAlgebraicAggregateFunctionPositions(i),
+ nonAlgebraicAggregateFunctions(i).eval(buffer))
+ i += 1
+ }
+ resultProjection(joinedRow(currentGroupingKey, aggregateResult))
+ }
+}
+
+/**
+ * An iterator used to do both final aggregations (for those aggregate functions with mode
+ * Final) and complete aggregations (for those aggregate functions with mode Complete).
+ * It assumes that input rows are already grouped by values of `groupingExpressions`.
+ * The format of its input rows is:
+ * |groupingExpr1|...|groupingExprN|col1|...|colM|aggregationBuffer1|...|aggregationBufferN|
+ * col1 to colM are columns used by aggregate functions with Complete mode.
+ * aggregationBuffer1 to aggregationBufferN are buffers used by aggregate functions with
+ * Final mode.
+ *
+ * The format of its internal buffer is:
+ * |placeholder1|...|placeholder(N+M)|aggregationBuffer1|...|aggregationBuffer(N+M)|
+ * The first N placeholders represent slots of grouping expressions.
+ * Then, next M placeholders represent slots of col1 to colM.
+ * For aggregation buffers, first N aggregation buffers are used by N aggregate functions with
+ * mode Final. Then, the last M aggregation buffers are used by M aggregate functions with mode
+ * Complete. The reason we have placeholders here is to make our underlying buffer
+ * have the same length as an input row.
+ *
+ * The format of its output rows is represented by the schema of `resultExpressions`.
+ */
+class FinalAndCompleteSortAggregationIterator(
+ override protected val initialBufferOffset: Int,
+ groupingExpressions: Seq[NamedExpression],
+ finalAggregateExpressions: Seq[AggregateExpression2],
+ finalAggregateAttributes: Seq[Attribute],
+ completeAggregateExpressions: Seq[AggregateExpression2],
+ completeAggregateAttributes: Seq[Attribute],
+ resultExpressions: Seq[NamedExpression],
+ newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+ inputAttributes: Seq[Attribute],
+ inputIter: Iterator[InternalRow])
+ extends SortAggregationIterator(
+ groupingExpressions,
+ // TODO: document the ordering
+ finalAggregateExpressions ++ completeAggregateExpressions,
+ newMutableProjection,
+ inputAttributes,
+ inputIter) {
+
+ // The result of aggregate functions.
+ private val aggregateResult: MutableRow =
+ new GenericMutableRow(completeAggregateAttributes.length + finalAggregateAttributes.length)
+
+ // The projection used to generate the output rows of this operator.
+ // This is only used when we are generating final results of aggregate functions.
+ private val resultProjection = {
+ val inputSchema =
+ groupingExpressions.map(_.toAttribute) ++
+ finalAggregateAttributes ++
+ completeAggregateAttributes
+ newMutableProjection(resultExpressions, inputSchema)()
+ }
+
+ private val offsetAttributes =
+ Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)())
+
+ // All aggregate functions with mode Final.
+ private val finalAggregateFunctions: Array[AggregateFunction2] = {
+ val functions = new Array[AggregateFunction2](finalAggregateExpressions.length)
+ var i = 0
+ while (i < finalAggregateExpressions.length) {
+ functions(i) = aggregateFunctions(i)
+ i += 1
+ }
+ functions
+ }
+
+ // All non-algebraic aggregate functions with mode Final.
+ private val finalNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = {
+ finalAggregateFunctions.collect {
+ case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
+ }.toArray
+ }
+
+ // All aggregate functions with mode Complete.
+ private val completeAggregateFunctions: Array[AggregateFunction2] = {
+ val functions = new Array[AggregateFunction2](completeAggregateExpressions.length)
+ var i = 0
+ while (i < completeAggregateExpressions.length) {
+ functions(i) = aggregateFunctions(finalAggregateFunctions.length + i)
+ i += 1
+ }
+ functions
+ }
+
+ // All non-algebraic aggregate functions with mode Complete.
+ private val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = {
+ completeAggregateFunctions.collect {
+ case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
+ }.toArray
+ }
+
+ // This projection is used to merge buffer values for all AlgebraicAggregates with mode
+ // Final.
+ private val finalAlgebraicMergeProjection = {
+ val numCompleteOffsetAttributes =
+ completeAggregateFunctions.map(_.bufferAttributes.length).sum
+ val completeOffsetAttributes =
+ Seq.fill(numCompleteOffsetAttributes)(AttributeReference("placeholder", NullType)())
+ val completeOffsetExpressions = Seq.fill(numCompleteOffsetAttributes)(NoOp)
+
+ val bufferSchemata =
+ offsetAttributes ++ finalAggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.bufferAttributes
+ case agg: AggregateFunction2 => agg.bufferAttributes
+ } ++ completeOffsetAttributes ++ offsetAttributes ++ finalAggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.cloneBufferAttributes
+ case agg: AggregateFunction2 => agg.cloneBufferAttributes
+ } ++ completeOffsetAttributes
+ val mergeExpressions =
+ placeholderExpressions ++ finalAggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.mergeExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ } ++ completeOffsetExpressions
+
+ newMutableProjection(mergeExpressions, bufferSchemata)()
+ }
+
+ // This projection is used to update buffer values for all AlgebraicAggregates with mode
+ // Complete.
+ private val completeAlgebraicUpdateProjection = {
+ val numFinalOffsetAttributes = finalAggregateFunctions.map(_.bufferAttributes.length).sum
+ val finalOffsetAttributes =
+ Seq.fill(numFinalOffsetAttributes)(AttributeReference("placeholder", NullType)())
+ val finalOffsetExpressions = Seq.fill(numFinalOffsetAttributes)(NoOp)
+
+ val bufferSchema =
+ offsetAttributes ++ finalOffsetAttributes ++ completeAggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.bufferAttributes
+ case agg: AggregateFunction2 => agg.bufferAttributes
+ }
+ val updateExpressions =
+ placeholderExpressions ++ finalOffsetExpressions ++ completeAggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.updateExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ }
+ newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)().target(buffer)
+ }
+
+ // This projection is used to evaluate all AlgebraicAggregates.
+ private val algebraicEvalProjection = {
+ val bufferSchemata =
+ offsetAttributes ++ aggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.bufferAttributes
+ case agg: AggregateFunction2 => agg.bufferAttributes
+ } ++ offsetAttributes ++ aggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.cloneBufferAttributes
+ case agg: AggregateFunction2 => agg.cloneBufferAttributes
+ }
+ val evalExpressions = aggregateFunctions.map {
+ case ae: AlgebraicAggregate => ae.evaluateExpression
+ case agg: AggregateFunction2 => NoOp
+ }
+
+ newMutableProjection(evalExpressions, bufferSchemata)()
+ }
+
+ override def initialize(): Unit = {
+ if (inputIter.hasNext) {
+ initializeBuffer()
+ val currentRow = inputIter.next().copy()
+ // groupGenerator is a mutable projection. Since we need to track nextGroupingKey,
+ // we make a copy of it here.
+ nextGroupingKey = groupGenerator(currentRow).copy()
+ firstRowInNextGroup = currentRow
+ } else {
+ if (groupingExpressions.isEmpty) {
+ // If there is no grouping expression, we need to generate a single row as the output.
+ initializeBuffer()
+ // Right now, the buffer only contains initial buffer values. Because merging
+ // two buffers that hold initial values still produces a row with initial values,
+ // we set currentRow to a copy of the current buffer.
+ val currentRow = buffer.copy()
+ nextGroupingKey = groupGenerator(currentRow).copy()
+ firstRowInNextGroup = currentRow
+ } else {
+ // The input iterator is empty.
+ hasNewGroup = false
+ }
+ }
+ }
+
+ override protected def processRow(row: InternalRow): Unit = {
+ val input = joinedRow(buffer, row)
+ // For all aggregate functions with mode Complete, update buffers.
+ completeAlgebraicUpdateProjection(input)
+ var i = 0
+ while (i < completeNonAlgebraicAggregateFunctions.length) {
+ completeNonAlgebraicAggregateFunctions(i).update(buffer, row)
+ i += 1
+ }
+
+ // For all aggregate functions with mode Final, merge buffers.
+ finalAlgebraicMergeProjection.target(buffer)(input)
+ i = 0
+ while (i < finalNonAlgebraicAggregateFunctions.length) {
+ finalNonAlgebraicAggregateFunctions(i).merge(buffer, row)
+ i += 1
+ }
+ }
+
+ override protected def generateOutput(): InternalRow = {
+ // Generate results for all algebraic aggregate functions.
+ algebraicEvalProjection.target(aggregateResult)(buffer)
+ // Generate results for all non-algebraic aggregate functions.
+ var i = 0
+ while (i < nonAlgebraicAggregateFunctions.length) {
+ aggregateResult.update(
+ nonAlgebraicAggregateFunctionPositions(i),
+ nonAlgebraicAggregateFunctions(i).eval(buffer))
+ i += 1
+ }
+
+ resultProjection(joinedRow(currentGroupingKey, aggregateResult))
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
new file mode 100644
index 0000000000..1cb27710e0
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -0,0 +1,364 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.types.{StructType, MapType, ArrayType}
+
+/**
+ * Utility functions used by the query planner to convert our plan to new aggregation code path.
+ */
+object Utils {
+ // Right now, we do not support complex types in the grouping key schema.
+ private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = {
+ val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists {
+ case array: ArrayType => true
+ case map: MapType => true
+ case struct: StructType => true
+ case _ => false
+ }
+
+ !hasComplexTypes
+ }
+
+ private def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
+ case p: Aggregate if supportsGroupingKeySchema(p) =>
+ val converted = p.transformExpressionsDown {
+ case expressions.Average(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Average(child),
+ mode = aggregate.Complete,
+ isDistinct = false)
+
+ case expressions.Count(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Count(child),
+ mode = aggregate.Complete,
+ isDistinct = false)
+
+ // We do not support multiple COUNT DISTINCT columns for now.
+ case expressions.CountDistinct(children) if children.length == 1 =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Count(children.head),
+ mode = aggregate.Complete,
+ isDistinct = true)
+
+ case expressions.First(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.First(child),
+ mode = aggregate.Complete,
+ isDistinct = false)
+
+ case expressions.Last(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Last(child),
+ mode = aggregate.Complete,
+ isDistinct = false)
+
+ case expressions.Max(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Max(child),
+ mode = aggregate.Complete,
+ isDistinct = false)
+
+ case expressions.Min(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Min(child),
+ mode = aggregate.Complete,
+ isDistinct = false)
+
+ case expressions.Sum(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Sum(child),
+ mode = aggregate.Complete,
+ isDistinct = false)
+
+ case expressions.SumDistinct(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Sum(child),
+ mode = aggregate.Complete,
+ isDistinct = true)
+ }
+ // Check if there is any expressions.AggregateExpression1 left.
+ // If so, we cannot convert this plan.
+ val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr =>
+ // For every expression, check if it contains AggregateExpression1.
+ expr.find {
+ case agg: expressions.AggregateExpression1 => true
+ case other => false
+ }.isDefined
+ }
+
+ // Check if there are multiple distinct columns.
+ val aggregateExpressions = converted.aggregateExpressions.flatMap { expr =>
+ expr.collect {
+ case agg: AggregateExpression2 => agg
+ }
+ }.toSet.toSeq
+ val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct)
+ val hasMultipleDistinctColumnSets =
+ functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1
+
+ if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None
+
+ case other => None
+ }
+
+ private def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = {
+ // If the plan cannot be converted, we do a final round of checks to see if the original
+ // logical.Aggregate contains both AggregateExpression1 and AggregateExpression2. If so,
+ // we need to throw an exception.
+ val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr =>
+ expr.collect {
+ case agg: AggregateExpression2 => agg.aggregateFunction
+ }
+ }.distinct
+ if (aggregateFunction2s.nonEmpty) {
+ // For functions implemented based on the new interface, prepare a list of function names.
+ val invalidFunctions = {
+ if (aggregateFunction2s.length > 1) {
+ s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " +
+ s"and ${aggregateFunction2s.head.nodeName} are"
+ } else {
+ s"${aggregateFunction2s.head.nodeName} is"
+ }
+ }
+ val errorMessage =
+ s"${invalidFunctions} implemented based on the new Aggregate Function " +
+ s"interface and it cannot be used with functions implemented based on " +
+ s"the old Aggregate Function interface."
+ throw new AnalysisException(errorMessage)
+ }
+ }
+
+ def tryConvert(
+ plan: LogicalPlan,
+ useNewAggregation: Boolean,
+ codeGenEnabled: Boolean): Option[Aggregate] = plan match {
+ case p: Aggregate if useNewAggregation && codeGenEnabled =>
+ val converted = tryConvert(p)
+ if (converted.isDefined) {
+ converted
+ } else {
+ checkInvalidAggregateFunction2(p)
+ None
+ }
+ case p: Aggregate =>
+ checkInvalidAggregateFunction2(p)
+ None
+ case other => None
+ }
+
+ def planAggregateWithoutDistinct(
+ groupingExpressions: Seq[Expression],
+ aggregateExpressions: Seq[AggregateExpression2],
+ aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute],
+ resultExpressions: Seq[NamedExpression],
+ child: SparkPlan): Seq[SparkPlan] = {
+ // 1. Create an Aggregate Operator for partial aggregations.
+ val namedGroupingExpressions = groupingExpressions.map {
+ case ne: NamedExpression => ne -> ne
+ // If the expression is not a NamedExpression, we add an alias.
+ // So, when we generate the result of the operator, the Aggregate Operator
+ // can directly get the Seq of attributes representing the grouping expressions.
+ case other =>
+ val withAlias = Alias(other, other.toString)()
+ other -> withAlias
+ }
+ val groupExpressionMap = namedGroupingExpressions.toMap
+ val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
+ val partialAggregateExpressions = aggregateExpressions.map {
+ case AggregateExpression2(aggregateFunction, mode, isDistinct) =>
+ AggregateExpression2(aggregateFunction, Partial, isDistinct)
+ }
+ val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
+ agg.aggregateFunction.bufferAttributes
+ }
+ val partialAggregate =
+ Aggregate2Sort(
+ None: Option[Seq[Expression]],
+ namedGroupingExpressions.map(_._2),
+ partialAggregateExpressions,
+ partialAggregateAttributes,
+ namedGroupingAttributes ++ partialAggregateAttributes,
+ child)
+
+ // 2. Create an Aggregate Operator for final aggregations.
+ val finalAggregateExpressions = aggregateExpressions.map {
+ case AggregateExpression2(aggregateFunction, mode, isDistinct) =>
+ AggregateExpression2(aggregateFunction, Final, isDistinct)
+ }
+ val finalAggregateAttributes =
+ finalAggregateExpressions.map {
+ expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
+ }
+ val rewrittenResultExpressions = resultExpressions.map { expr =>
+ expr.transformDown {
+ case agg: AggregateExpression2 =>
+ aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute
+ case expression =>
+ // We do not rely on the equality check here since attributes may
+ // differ cosmetically. Instead, we use semanticEquals.
+ groupExpressionMap.collectFirst {
+ case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+ }.getOrElse(expression)
+ }.asInstanceOf[NamedExpression]
+ }
+ val finalAggregate = Aggregate2Sort(
+ Some(namedGroupingAttributes),
+ namedGroupingAttributes,
+ finalAggregateExpressions,
+ finalAggregateAttributes,
+ rewrittenResultExpressions,
+ partialAggregate)
+
+ finalAggregate :: Nil
+ }
+
+ def planAggregateWithOneDistinct(
+ groupingExpressions: Seq[Expression],
+ functionsWithDistinct: Seq[AggregateExpression2],
+ functionsWithoutDistinct: Seq[AggregateExpression2],
+ aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute],
+ resultExpressions: Seq[NamedExpression],
+ child: SparkPlan): Seq[SparkPlan] = {
+
+ // 1. Create an Aggregate Operator for partial aggregations.
+ // The grouping expressions are the original groupingExpressions plus the
+ // distinct columns. For example, for avg(distinct value) ... group by key,
+ // the grouping expressions of this Aggregate Operator will be [key, value].
+ val namedGroupingExpressions = groupingExpressions.map {
+ case ne: NamedExpression => ne -> ne
+ // If the expression is not a NamedExpression, we add an alias.
+ // So, when we generate the result of the operator, the Aggregate Operator
+ // can directly get the Seq of attributes representing the grouping expressions.
+ case other =>
+ val withAlias = Alias(other, other.toString)()
+ other -> withAlias
+ }
+ val groupExpressionMap = namedGroupingExpressions.toMap
+ val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
+
+ // It is safe to call head here since functionsWithDistinct has at least one
+ // AggregateExpression2.
+ val distinctColumnExpressions =
+ functionsWithDistinct.head.aggregateFunction.children
+ val namedDistinctColumnExpressions = distinctColumnExpressions.map {
+ case ne: NamedExpression => ne -> ne
+ case other =>
+ val withAlias = Alias(other, other.toString)()
+ other -> withAlias
+ }
+ val distinctColumnExpressionMap = namedDistinctColumnExpressions.toMap
+ val distinctColumnAttributes = namedDistinctColumnExpressions.map(_._2.toAttribute)
+
+ val partialAggregateExpressions = functionsWithoutDistinct.map {
+ case AggregateExpression2(aggregateFunction, mode, _) =>
+ AggregateExpression2(aggregateFunction, Partial, false)
+ }
+ val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
+ agg.aggregateFunction.bufferAttributes
+ }
+ val partialAggregate =
+ Aggregate2Sort(
+ None: Option[Seq[Expression]],
+ (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2),
+ partialAggregateExpressions,
+ partialAggregateAttributes,
+ namedGroupingAttributes ++ distinctColumnAttributes ++ partialAggregateAttributes,
+ child)
+
+ // 2. Create an Aggregate Operator for partial merge aggregations.
+ val partialMergeAggregateExpressions = functionsWithoutDistinct.map {
+ case AggregateExpression2(aggregateFunction, mode, _) =>
+ AggregateExpression2(aggregateFunction, PartialMerge, false)
+ }
+ val partialMergeAggregateAttributes =
+ partialMergeAggregateExpressions.map {
+ expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
+ }
+ val partialMergeAggregate =
+ Aggregate2Sort(
+ Some(namedGroupingAttributes),
+ namedGroupingAttributes ++ distinctColumnAttributes,
+ partialMergeAggregateExpressions,
+ partialMergeAggregateAttributes,
+ namedGroupingAttributes ++ distinctColumnAttributes ++ partialMergeAggregateAttributes,
+ partialAggregate)
+
+ // 3. Create an Aggregate Operator for the final and complete aggregations.
+ val finalAggregateExpressions = functionsWithoutDistinct.map {
+ case AggregateExpression2(aggregateFunction, mode, _) =>
+ AggregateExpression2(aggregateFunction, Final, false)
+ }
+ val finalAggregateAttributes =
+ finalAggregateExpressions.map {
+ expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
+ }
+ val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map {
+ // Children of an AggregateFunction with the DISTINCT keyword have already
+ // been evaluated. Here, we need to replace the original children
+ // with AttributeReferences.
+ case agg @ AggregateExpression2(aggregateFunction, mode, isDistinct) =>
+ val rewrittenAggregateFunction = aggregateFunction.transformDown {
+ case expr if distinctColumnExpressionMap.contains(expr) =>
+ distinctColumnExpressionMap(expr).toAttribute
+ }.asInstanceOf[AggregateFunction2]
+ // We rewrite the aggregate function to a non-distinct aggregation because
+ // its input will have distinct arguments.
+ val rewrittenAggregateExpression =
+ AggregateExpression2(rewrittenAggregateFunction, Complete, false)
+
+ val aggregateFunctionAttribute = aggregateFunctionMap(agg.aggregateFunction, isDistinct)
+ (rewrittenAggregateExpression -> aggregateFunctionAttribute)
+ }.unzip
+
+ val rewrittenResultExpressions = resultExpressions.map { expr =>
+ expr.transform {
+ case agg: AggregateExpression2 =>
+ aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute
+ case expression =>
+ // We do not rely on the equality check here since attributes may
+ // differ cosmetically. Instead, we use semanticEquals.
+ groupExpressionMap.collectFirst {
+ case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+ }.getOrElse(expression)
+ }.asInstanceOf[NamedExpression]
+ }
+ val finalAndCompleteAggregate = FinalAndCompleteAggregate2Sort(
+ namedGroupingAttributes ++ distinctColumnAttributes,
+ namedGroupingAttributes,
+ finalAggregateExpressions,
+ finalAggregateAttributes,
+ completeAggregateExpressions,
+ completeAggregateAttributes,
+ rewrittenResultExpressions,
+ partialMergeAggregate)
+
+ finalAndCompleteAggregate :: Nil
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala
new file mode 100644
index 0000000000..6c49a906c8
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala
@@ -0,0 +1,280 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions.aggregate
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
+import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction2
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.Row
+
+/**
+ * The abstract class for implementing user-defined aggregate function.
+ */
+abstract class UserDefinedAggregateFunction extends Serializable {
+
+ /**
+ * A [[StructType]] representing the data types of the input arguments of this aggregate function.
+ * For example, if a [[UserDefinedAggregateFunction]] expects two input arguments
+ * with type of [[DoubleType]] and [[LongType]], the returned [[StructType]] will look like
+ *
+ * ```
+ * StructType(Seq(StructField("doubleInput", DoubleType), StructField("longInput", LongType)))
+ * ```
+ *
+ * The name of a field of this [[StructType]] is only used to identify the corresponding
+ * input argument. Users can choose names to identify the input arguments.
+ */
+ def inputSchema: StructType
+
+ /**
+ * A [[StructType]] representing the data types of the values in the aggregation buffer.
+ * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values
+ * (i.e. two intermediate values) with type of [[DoubleType]] and [[LongType]],
+ * the returned [[StructType]] will look like
+ *
+ * ```
+ * StructType(Seq(StructField("doubleInput", DoubleType), StructField("longInput", LongType)))
+ * ```
+ *
+ * The name of a field of this [[StructType]] is only used to identify the corresponding
+ * buffer value. Users can choose names to identify the buffer values.
+ */
+ def bufferSchema: StructType
+
+ /**
+ * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]].
+ */
+ def returnDataType: DataType
+
+ /** Indicates if this function is deterministic. */
+ def deterministic: Boolean
+
+ /**
+ * Initializes the given aggregation buffer. Initial values set by this method should satisfy
+ * the condition that when merging two buffers with initial values, the new buffer should
+ * still store initial values.
+ */
+ def initialize(buffer: MutableAggregationBuffer): Unit
+
+ /** Updates the given aggregation buffer `buffer` with new input data from `input`. */
+ def update(buffer: MutableAggregationBuffer, input: Row): Unit
+
+ /** Merges two aggregation buffers and stores the updated buffer values back in `buffer1`. */
+ def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit
+
+ /**
+ * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given
+ * aggregation buffer.
+ */
+ def evaluate(buffer: Row): Any
+}
+
+private[sql] abstract class AggregationBuffer(
+ toCatalystConverters: Array[Any => Any],
+ toScalaConverters: Array[Any => Any],
+ bufferOffset: Int)
+ extends Row {
+
+ override def length: Int = toCatalystConverters.length
+
+ protected val offsets: Array[Int] = {
+ val newOffsets = new Array[Int](length)
+ var i = 0
+ while (i < newOffsets.length) {
+ newOffsets(i) = bufferOffset + i
+ i += 1
+ }
+ newOffsets
+ }
+}
+
+/**
+ * A mutable [[Row]] representing a mutable aggregation buffer.
+ */
+class MutableAggregationBuffer private[sql] (
+ toCatalystConverters: Array[Any => Any],
+ toScalaConverters: Array[Any => Any],
+ bufferOffset: Int,
+ var underlyingBuffer: MutableRow)
+ extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) {
+
+ override def get(i: Int): Any = {
+ if (i >= length || i < 0) {
+ throw new IllegalArgumentException(
+ s"Could not access ${i}th value in this buffer because it only has $length values.")
+ }
+ toScalaConverters(i)(underlyingBuffer(offsets(i)))
+ }
+
+ def update(i: Int, value: Any): Unit = {
+ if (i >= length || i < 0) {
+ throw new IllegalArgumentException(
+ s"Could not update ${i}th value in this buffer because it only has $length values.")
+ }
+ underlyingBuffer.update(offsets(i), toCatalystConverters(i)(value))
+ }
+
+ override def copy(): MutableAggregationBuffer = {
+ new MutableAggregationBuffer(
+ toCatalystConverters,
+ toScalaConverters,
+ bufferOffset,
+ underlyingBuffer)
+ }
+}
+
+/**
+ * A [[Row]] representing an immutable aggregation buffer.
+ */
+class InputAggregationBuffer private[sql] (
+ toCatalystConverters: Array[Any => Any],
+ toScalaConverters: Array[Any => Any],
+ bufferOffset: Int,
+ var underlyingInputBuffer: Row)
+ extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) {
+
+ override def get(i: Int): Any = {
+ if (i >= length || i < 0) {
+ throw new IllegalArgumentException(
+ s"Could not access ${i}th value in this buffer because it only has $length values.")
+ }
+ toScalaConverters(i)(underlyingInputBuffer(offsets(i)))
+ }
+
+ override def copy(): InputAggregationBuffer = {
+ new InputAggregationBuffer(
+ toCatalystConverters,
+ toScalaConverters,
+ bufferOffset,
+ underlyingInputBuffer)
+ }
+}
+
+/**
+ * The internal wrapper used to hook a [[UserDefinedAggregateFunction]] `udaf` into the
+ * internal aggregation code path.
+ * @param children the input expressions of this UDAF
+ * @param udaf the user-defined aggregate function to be wrapped
+ */
+case class ScalaUDAF(
+ children: Seq[Expression],
+ udaf: UserDefinedAggregateFunction)
+ extends AggregateFunction2 with Logging {
+
+ require(
+ children.length == udaf.inputSchema.length,
+ s"$udaf only accepts ${udaf.inputSchema.length} arguments, " +
+ s"but ${children.length} are provided.")
+
+ override def nullable: Boolean = true
+
+ override def dataType: DataType = udaf.returnDataType
+
+ override def deterministic: Boolean = udaf.deterministic
+
+ override val inputTypes: Seq[DataType] = udaf.inputSchema.map(_.dataType)
+
+ override val bufferSchema: StructType = udaf.bufferSchema
+
+ override val bufferAttributes: Seq[AttributeReference] = bufferSchema.toAttributes
+
+ override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance())
+
+ val childrenSchema: StructType = {
+ val inputFields = children.zipWithIndex.map {
+ case (child, index) =>
+ StructField(s"input$index", child.dataType, child.nullable, Metadata.empty)
+ }
+ StructType(inputFields)
+ }
+
+ lazy val inputProjection = {
+ val inputAttributes = childrenSchema.toAttributes
+ log.debug(
+ s"Creating MutableProj: $children, inputSchema: $inputAttributes.")
+ try {
+ GenerateMutableProjection.generate(children, inputAttributes)()
+ } catch {
+ case e: Exception =>
+ log.error("Failed to generate mutable projection, falling back to interpreted projection", e)
+ new InterpretedMutableProjection(children, inputAttributes)
+ }
+ }
+
+ val inputToScalaConverters: Any => Any =
+ CatalystTypeConverters.createToScalaConverter(childrenSchema)
+
+ val bufferValuesToCatalystConverters: Array[Any => Any] = bufferSchema.fields.map { field =>
+ CatalystTypeConverters.createToCatalystConverter(field.dataType)
+ }
+
+ val bufferValuesToScalaConverters: Array[Any => Any] = bufferSchema.fields.map { field =>
+ CatalystTypeConverters.createToScalaConverter(field.dataType)
+ }
+
+ lazy val inputAggregateBuffer: InputAggregationBuffer =
+ new InputAggregationBuffer(
+ bufferValuesToCatalystConverters,
+ bufferValuesToScalaConverters,
+ bufferOffset,
+ null)
+
+ lazy val mutableAggregateBuffer: MutableAggregationBuffer =
+ new MutableAggregationBuffer(
+ bufferValuesToCatalystConverters,
+ bufferValuesToScalaConverters,
+ bufferOffset,
+ null)
+
+ override def initialize(buffer: MutableRow): Unit = {
+ mutableAggregateBuffer.underlyingBuffer = buffer
+
+ udaf.initialize(mutableAggregateBuffer)
+ }
+
+ override def update(buffer: MutableRow, input: InternalRow): Unit = {
+ mutableAggregateBuffer.underlyingBuffer = buffer
+
+ udaf.update(
+ mutableAggregateBuffer,
+ inputToScalaConverters(inputProjection(input)).asInstanceOf[Row])
+ }
+
+ override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
+ mutableAggregateBuffer.underlyingBuffer = buffer1
+ inputAggregateBuffer.underlyingInputBuffer = buffer2
+
+ udaf.merge(mutableAggregateBuffer, inputAggregateBuffer)
+ }
+
+ override def eval(buffer: InternalRow = null): Any = {
+ inputAggregateBuffer.underlyingInputBuffer = buffer
+
+ udaf.evaluate(inputAggregateBuffer)
+ }
+
+ override def toString: String = {
+ s"""${udaf.getClass.getSimpleName}(${children.mkString(",")})"""
+ }
+
+ override def nodeName: String = udaf.getClass.getSimpleName
+}
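
For orientation, a minimal Scala implementation against the UserDefinedAggregateFunction API added above might look like the sketch below. The class and field names are illustrative only (it mirrors the Java MyDoubleSum test UDAF added later in this patch) and are not part of the change itself.

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.expressions.aggregate.{MutableAggregationBuffer, UserDefinedAggregateFunction}
    import org.apache.spark.sql.types._

    // A nullable DOUBLE sum: the buffer holds a single running sum that starts as null.
    class ScalaDoubleSum extends UserDefinedAggregateFunction {
      def inputSchema: StructType = StructType(StructField("inputDouble", DoubleType) :: Nil)
      def bufferSchema: StructType = StructType(StructField("bufferDouble", DoubleType) :: Nil)
      def returnDataType: DataType = DoubleType
      def deterministic: Boolean = true

      def initialize(buffer: MutableAggregationBuffer): Unit = buffer.update(0, null)

      def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        if (!input.isNullAt(0)) {
          val current = if (buffer.isNullAt(0)) 0.0 else buffer.getDouble(0)
          buffer.update(0, current + input.getDouble(0))
        }
      }

      def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        if (!buffer2.isNullAt(0)) {
          val current = if (buffer1.isNullAt(0)) 0.0 else buffer1.getDouble(0)
          buffer1.update(0, current + buffer2.getDouble(0))
        }
      }

      def evaluate(buffer: Row): Any = if (buffer.isNullAt(0)) null else buffer.getDouble(0)
    }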
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 28159cbd5a..bfeecbe8b2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2420,7 +2420,7 @@ object functions {
* @since 1.5.0
*/
def callUDF(udfName: String, cols: Column*): Column = {
- UnresolvedFunction(udfName, cols.map(_.expr))
+ UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false)
}
/**
@@ -2449,7 +2449,7 @@ object functions {
exprs(i) = cols(i).expr
i += 1
}
- UnresolvedFunction(udfName, exprs)
+ UnresolvedFunction(udfName, exprs, isDistinct = false)
}
}
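
With the explicit isDistinct flag threaded through, callUDF continues to work for any function in the registry, including UDAFs registered through UDAFRegistration. A usage sketch follows; the DataFrame df, the sqlContext, and the name "mydoublesum" are assumptions borrowed from the test suite further down, not part of this change.

    import org.apache.spark.sql.functions._
    import sqlContext.implicits._

    // Assumes "mydoublesum" was registered via sqlContext.udaf.register("mydoublesum", new MyDoubleSum)
    // and that df has "key" and "value" columns.
    val result = df.groupBy($"key").agg(callUDF("mydoublesum", $"value"))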
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index beee10173f..ab8dce603c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -23,6 +23,7 @@ import java.sql.Timestamp
import org.apache.spark.sql.catalyst.DefaultParserDialect
import org.apache.spark.sql.catalyst.errors.DialectException
+import org.apache.spark.sql.execution.aggregate.Aggregate2Sort
import org.apache.spark.sql.execution.GeneratedAggregate
import org.apache.spark.sql.functions._
import org.apache.spark.sql.TestData._
@@ -204,6 +205,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
var hasGeneratedAgg = false
df.queryExecution.executedPlan.foreach {
case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true
+ case newAggregate: Aggregate2Sort => hasGeneratedAgg = true
case _ =>
}
if (!hasGeneratedAgg) {
@@ -285,7 +287,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
// Aggregate with Code generation handling all null values
testCodeGen(
"SELECT sum('a'), avg('a'), count(null) FROM testData",
- Row(0, null, 0) :: Nil)
+ Row(null, null, 0) :: Nil)
} finally {
sqlContext.dropTempTable("testData3x")
sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 3dd24130af..3d71deb13e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.TestSQLContext._
@@ -30,6 +31,20 @@ import org.apache.spark.sql.{Row, SQLConf, execution}
class PlannerSuite extends SparkFunSuite {
+ private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
+ val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption)
+ val planned =
+ plannedOption.getOrElse(
+ fail(s"Could not plan aggregation query $query. Is it an aggregation query?"))
+ val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
+
+ // For the new aggregation code path, there will be three aggregate operators for
+ // distinct aggregations.
+ assert(
+ aggregations.size == 2 || aggregations.size == 3,
+ s"The plan of query $query does not have partial aggregations.")
+ }
+
test("unions are collapsed") {
val query = testData.unionAll(testData).unionAll(testData).logicalPlan
val planned = BasicOperators(query).head
@@ -42,23 +57,18 @@ class PlannerSuite extends SparkFunSuite {
test("count is partially aggregated") {
val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed
- val planned = HashAggregation(query).head
- val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
-
- assert(aggregations.size === 2)
+ testPartialAggregationPlan(query)
}
test("count distinct is partially aggregated") {
val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed
- val planned = HashAggregation(query)
- assert(planned.nonEmpty)
+ testPartialAggregationPlan(query)
}
test("mixed aggregates are partially aggregated") {
val query =
testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed
- val planned = HashAggregation(query)
- assert(planned.nonEmpty)
+ testPartialAggregationPlan(query)
}
test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
index 31a49a3683..24a758f531 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
@@ -833,6 +833,7 @@ abstract class HiveWindowFunctionQueryFileBaseSuite
"windowing_adjust_rowcontainer_sz"
)
+ // Only run the query tests that are in the realWhiteList (do not try other ignored query files).
override def testCases: Seq[(String, File)] = super.testCases.filter {
case (name, _) => realWhiteList.contains(name)
}
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
index f458567e5d..1fe4fe9629 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.hive.execution
+import java.io.File
+
import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.hive.test.TestHive
@@ -159,4 +161,9 @@ class SortMergeCompatibilitySuite extends HiveCompatibilitySuite {
"join_reorder4",
"join_star"
)
+
+ // Only run the query tests that are in the realWhiteList (do not try other ignored query files).
+ override def testCases: Seq[(String, File)] = super.testCases.filter {
+ case (name, _) => realWhiteList.contains(name)
+ }
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index cec7685bb6..4cdb83c511 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -451,6 +451,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
DataSinks,
Scripts,
HashAggregation,
+ Aggregation,
LeftSemiJoin,
HashJoin,
BasicOperators,
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index f5574509b0..8518e333e8 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -1464,9 +1464,12 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
/* UDFs - Must be last otherwise will preempt built in functions */
case Token("TOK_FUNCTION", Token(name, Nil) :: args) =>
- UnresolvedFunction(name, args.map(nodeToExpr))
+ UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = false)
+ // Aggregate function with DISTINCT keyword.
+ case Token("TOK_FUNCTIONDI", Token(name, Nil) :: args) =>
+ UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = true)
case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) =>
- UnresolvedFunction(name, UnresolvedStar(None) :: Nil)
+ UnresolvedFunction(name, UnresolvedStar(None) :: Nil, isDistinct = false)
/* Literals */
case Token("TOK_NULL", Nil) => Literal.create(null, NullType)
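
With the new TOK_FUNCTIONDI case above, an aggregate call written with the DISTINCT keyword is now carried through the parser as an UnresolvedFunction whose isDistinct flag is set. Roughly, and with the attribute construction below being illustrative rather than the exact parser output:

    // HiveQL fragment: count(DISTINCT key)
    // is translated, approximately, into:
    UnresolvedFunction("count", UnresolvedAttribute("key") :: Nil, isDistinct = true)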
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 4d23c7035c..3259b50acc 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -409,7 +409,7 @@ private[hive] case class HiveWindowFunction(
private[hive] case class HiveGenericUDAF(
funcWrapper: HiveFunctionWrapper,
- children: Seq[Expression]) extends AggregateExpression
+ children: Seq[Expression]) extends AggregateExpression1
with HiveInspectors {
type UDFType = AbstractGenericUDAFResolver
@@ -441,7 +441,7 @@ private[hive] case class HiveGenericUDAF(
/** It is used as a wrapper for the hive functions which uses UDAF interface */
private[hive] case class HiveUDAF(
funcWrapper: HiveFunctionWrapper,
- children: Seq[Expression]) extends AggregateExpression
+ children: Seq[Expression]) extends AggregateExpression1
with HiveInspectors {
type UDFType = UDAF
@@ -550,9 +550,9 @@ private[hive] case class HiveGenericUDTF(
private[hive] case class HiveUDAFFunction(
funcWrapper: HiveFunctionWrapper,
exprs: Seq[Expression],
- base: AggregateExpression,
+ base: AggregateExpression1,
isUDAFBridgeRequired: Boolean = false)
- extends AggregateFunction
+ extends AggregateFunction1
with HiveInspectors {
def this() = this(null, null, null)
diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java
new file mode 100644
index 0000000000..5c9d0e97a9
--- /dev/null
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.hive.aggregate;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.expressions.aggregate.MutableAggregationBuffer;
+import org.apache.spark.sql.expressions.aggregate.UserDefinedAggregateFunction;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+public class MyDoubleAvg extends UserDefinedAggregateFunction {
+
+ private StructType _inputDataType;
+
+ private StructType _bufferSchema;
+
+ private DataType _returnDataType;
+
+ public MyDoubleAvg() {
+ List<StructField> inputfields = new ArrayList<StructField>();
+ inputfields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
+ _inputDataType = DataTypes.createStructType(inputfields);
+
+ List<StructField> bufferFields = new ArrayList<StructField>();
+ bufferFields.add(DataTypes.createStructField("bufferSum", DataTypes.DoubleType, true));
+ bufferFields.add(DataTypes.createStructField("bufferCount", DataTypes.LongType, true));
+ _bufferSchema = DataTypes.createStructType(bufferFields);
+
+ _returnDataType = DataTypes.DoubleType;
+ }
+
+ @Override public StructType inputSchema() {
+ return _inputDataType;
+ }
+
+ @Override public StructType bufferSchema() {
+ return _bufferSchema;
+ }
+
+ @Override public DataType returnDataType() {
+ return _returnDataType;
+ }
+
+ @Override public boolean deterministic() {
+ return true;
+ }
+
+ @Override public void initialize(MutableAggregationBuffer buffer) {
+ buffer.update(0, null);
+ buffer.update(1, 0L);
+ }
+
+ @Override public void update(MutableAggregationBuffer buffer, Row input) {
+ if (!input.isNullAt(0)) {
+ if (buffer.isNullAt(0)) {
+ buffer.update(0, input.getDouble(0));
+ buffer.update(1, 1L);
+ } else {
+ Double newValue = input.getDouble(0) + buffer.getDouble(0);
+ buffer.update(0, newValue);
+ buffer.update(1, buffer.getLong(1) + 1L);
+ }
+ }
+ }
+
+ @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
+ if (!buffer2.isNullAt(0)) {
+ if (buffer1.isNullAt(0)) {
+ buffer1.update(0, buffer2.getDouble(0));
+ buffer1.update(1, buffer2.getLong(1));
+ } else {
+ Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
+ buffer1.update(0, newValue);
+ buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1));
+ }
+ }
+ }
+
+ @Override public Object evaluate(Row buffer) {
+ if (buffer.isNullAt(0)) {
+ return null;
+ } else {
+ return buffer.getDouble(0) / buffer.getLong(1) + 100.0;
+ }
+ }
+}
+
diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java
new file mode 100644
index 0000000000..1d4587a27c
--- /dev/null
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.hive.aggregate;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.spark.sql.expressions.aggregate.MutableAggregationBuffer;
+import org.apache.spark.sql.expressions.aggregate.UserDefinedAggregateFunction;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.Row;
+
+public class MyDoubleSum extends UserDefinedAggregateFunction {
+
+ private StructType _inputDataType;
+
+ private StructType _bufferSchema;
+
+ private DataType _returnDataType;
+
+ public MyDoubleSum() {
+ List<StructField> inputfields = new ArrayList<StructField>();
+ inputfields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
+ _inputDataType = DataTypes.createStructType(inputfields);
+
+ List<StructField> bufferFields = new ArrayList<StructField>();
+ bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true));
+ _bufferSchema = DataTypes.createStructType(bufferFields);
+
+ _returnDataType = DataTypes.DoubleType;
+ }
+
+ @Override public StructType inputSchema() {
+ return _inputDataType;
+ }
+
+ @Override public StructType bufferSchema() {
+ return _bufferSchema;
+ }
+
+ @Override public DataType returnDataType() {
+ return _returnDataType;
+ }
+
+ @Override public boolean deterministic() {
+ return true;
+ }
+
+ @Override public void initialize(MutableAggregationBuffer buffer) {
+ buffer.update(0, null);
+ }
+
+ @Override public void update(MutableAggregationBuffer buffer, Row input) {
+ if (!input.isNullAt(0)) {
+ if (buffer.isNullAt(0)) {
+ buffer.update(0, input.getDouble(0));
+ } else {
+ Double newValue = input.getDouble(0) + buffer.getDouble(0);
+ buffer.update(0, newValue);
+ }
+ }
+ }
+
+ @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
+ if (!buffer2.isNullAt(0)) {
+ if (buffer1.isNullAt(0)) {
+ buffer1.update(0, buffer2.getDouble(0));
+ } else {
+ Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
+ buffer1.update(0, newValue);
+ }
+ }
+ }
+
+ @Override public Object evaluate(Row buffer) {
+ if (buffer.isNullAt(0)) {
+ return null;
+ } else {
+ return buffer.getDouble(0);
+ }
+ }
+}
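
The two Java test UDAFs above (MyDoubleAvg and MyDoubleSum) are exercised by the AggregationQuerySuite added below; in sketch form, the registration and SQL usage pattern from that suite is:

    // Register the test UDAFs, then call them from SQL by name (as done in AggregationQuerySuite).
    sqlContext.udaf.register("mydoublesum", new MyDoubleSum)
    sqlContext.udaf.register("mydoubleavg", new MyDoubleAvg)
    sqlContext.sql("SELECT key, mydoublesum(value), mydoubleavg(value) FROM agg1 GROUP BY key")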
diff --git a/sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada b/sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada
new file mode 100644
index 0000000000..573541ac97
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47 b/sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47
new file mode 100644
index 0000000000..44b2a42cc2
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47
@@ -0,0 +1 @@
+unhex(str) - Converts hexadecimal argument to binary
diff --git a/sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4 b/sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4
new file mode 100644
index 0000000000..97af3b812a
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4
@@ -0,0 +1,14 @@
+unhex(str) - Converts hexadecimal argument to binary
+Performs the inverse operation of HEX(str). That is, it interprets
+each pair of hexadecimal digits in the argument as a number and
+converts it to the byte representation of the number. The
+resulting characters are returned as a binary string.
+
+Example:
+> SELECT DECODE(UNHEX('4D7953514C'), 'UTF-8') from src limit 1;
+'MySQL'
+
+The characters in the argument string must be legal hexadecimal
+digits: '0' .. '9', 'A' .. 'F', 'a' .. 'f'. If UNHEX() encounters
+any nonhexadecimal digits in the argument, it returns NULL. Also,
+if there are an odd number of characters a leading 0 is appended.
diff --git a/sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e b/sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e
new file mode 100644
index 0000000000..b4a6f2b692
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e
@@ -0,0 +1 @@
+MySQL 1267 a -4
diff --git a/sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3 b/sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3
new file mode 100644
index 0000000000..3a67adaf0a
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3
@@ -0,0 +1 @@
+NULL NULL NULL
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
new file mode 100644
index 0000000000..0375eb79ad
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -0,0 +1,507 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import org.apache.spark.sql.execution.aggregate.Aggregate2Sort
+import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.scalatest.BeforeAndAfterAll
+import test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum}
+
+class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll {
+
+ override val sqlContext = TestHive
+ import sqlContext.implicits._
+
+ var originalUseAggregate2: Boolean = _
+
+ override def beforeAll(): Unit = {
+ originalUseAggregate2 = sqlContext.conf.useSqlAggregate2
+ sqlContext.sql("set spark.sql.useAggregate2=true")
+ val data1 = Seq[(Integer, Integer)](
+ (1, 10),
+ (null, -60),
+ (1, 20),
+ (1, 30),
+ (2, 0),
+ (null, -10),
+ (2, -1),
+ (2, null),
+ (2, null),
+ (null, 100),
+ (3, null),
+ (null, null),
+ (3, null)).toDF("key", "value")
+ data1.write.saveAsTable("agg1")
+
+ val data2 = Seq[(Integer, Integer, Integer)](
+ (1, 10, -10),
+ (null, -60, 60),
+ (1, 30, -30),
+ (1, 30, 30),
+ (2, 1, 1),
+ (null, -10, 10),
+ (2, -1, null),
+ (2, 1, 1),
+ (2, null, 1),
+ (null, 100, -10),
+ (3, null, 3),
+ (null, null, null),
+ (3, null, null)).toDF("key", "value1", "value2")
+ data2.write.saveAsTable("agg2")
+
+ val emptyDF = sqlContext.createDataFrame(
+ sqlContext.sparkContext.emptyRDD[Row],
+ StructType(StructField("key", StringType) :: StructField("value", IntegerType) :: Nil))
+ emptyDF.registerTempTable("emptyTable")
+
+ // Register UDAFs
+ sqlContext.udaf.register("mydoublesum", new MyDoubleSum)
+ sqlContext.udaf.register("mydoubleavg", new MyDoubleAvg)
+ }
+
+ override def afterAll(): Unit = {
+ sqlContext.sql("DROP TABLE IF EXISTS agg1")
+ sqlContext.sql("DROP TABLE IF EXISTS agg2")
+ sqlContext.dropTempTable("emptyTable")
+ sqlContext.sql(s"set spark.sql.useAggregate2=$originalUseAggregate2")
+ }
+
+ test("empty table") {
+ // If there is no GROUP BY clause and the table is empty, we will generate a single row.
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | AVG(value),
+ | COUNT(*),
+ | COUNT(key),
+ | COUNT(value),
+ | FIRST(key),
+ | LAST(value),
+ | MAX(key),
+ | MIN(value),
+ | SUM(key)
+ |FROM emptyTable
+ """.stripMargin),
+ Row(null, 0, 0, 0, null, null, null, null, null) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | AVG(value),
+ | COUNT(*),
+ | COUNT(key),
+ | COUNT(value),
+ | FIRST(key),
+ | LAST(value),
+ | MAX(key),
+ | MIN(value),
+ | SUM(key),
+ | COUNT(DISTINCT value)
+ |FROM emptyTable
+ """.stripMargin),
+ Row(null, 0, 0, 0, null, null, null, null, null, 0) :: Nil)
+
+ // If there is a GROUP BY clause and the table is empty, there is no output.
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | AVG(value),
+ | COUNT(*),
+ | COUNT(value),
+ | FIRST(value),
+ | LAST(value),
+ | MAX(value),
+ | MIN(value),
+ | SUM(value),
+ | COUNT(DISTINCT value)
+ |FROM emptyTable
+ |GROUP BY key
+ """.stripMargin),
+ Nil)
+ }
+
+ test("only do grouping") {
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT key
+ |FROM agg1
+ |GROUP BY key
+ """.stripMargin),
+ Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT DISTINCT value1, key
+ |FROM agg2
+ """.stripMargin),
+ Row(10, 1) ::
+ Row(-60, null) ::
+ Row(30, 1) ::
+ Row(1, 2) ::
+ Row(-10, null) ::
+ Row(-1, 2) ::
+ Row(null, 2) ::
+ Row(100, null) ::
+ Row(null, 3) ::
+ Row(null, null) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT value1, key
+ |FROM agg2
+ |GROUP BY key, value1
+ """.stripMargin),
+ Row(10, 1) ::
+ Row(-60, null) ::
+ Row(30, 1) ::
+ Row(1, 2) ::
+ Row(-10, null) ::
+ Row(-1, 2) ::
+ Row(null, 2) ::
+ Row(100, null) ::
+ Row(null, 3) ::
+ Row(null, null) :: Nil)
+ }
+
+ test("case insensitive resolution") {
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT avg(value), kEY - 100
+ |FROM agg1
+ |GROUP BY Key - 100
+ """.stripMargin),
+ Row(20.0, -99) :: Row(-0.5, -98) :: Row(null, -97) :: Row(10.0, null) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT sum(distinct value1), kEY - 100, count(distinct value1)
+ |FROM agg2
+ |GROUP BY Key - 100
+ """.stripMargin),
+ Row(40, -99, 2) :: Row(0, -98, 2) :: Row(null, -97, 0) :: Row(30, null, 3) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT valUe * key - 100
+ |FROM agg1
+ |GROUP BY vAlue * keY - 100
+ """.stripMargin),
+ Row(-90) ::
+ Row(-80) ::
+ Row(-70) ::
+ Row(-100) ::
+ Row(-102) ::
+ Row(null) :: Nil)
+ }
+
+ test("test average no key in output") {
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT avg(value)
+ |FROM agg1
+ |GROUP BY key
+ """.stripMargin),
+ Row(-0.5) :: Row(20.0) :: Row(null) :: Row(10.0) :: Nil)
+ }
+
+ test("test average") {
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT key, avg(value)
+ |FROM agg1
+ |GROUP BY key
+ """.stripMargin),
+ Row(1, 20.0) :: Row(2, -0.5) :: Row(3, null) :: Row(null, 10.0) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT avg(value), key
+ |FROM agg1
+ |GROUP BY key
+ """.stripMargin),
+ Row(20.0, 1) :: Row(-0.5, 2) :: Row(null, 3) :: Row(10.0, null) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT avg(value) + 1.5, key + 10
+ |FROM agg1
+ |GROUP BY key + 10
+ """.stripMargin),
+ Row(21.5, 11) :: Row(1.0, 12) :: Row(null, 13) :: Row(11.5, null) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT avg(value) FROM agg1
+ """.stripMargin),
+ Row(11.125) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT avg(null)
+ """.stripMargin),
+ Row(null) :: Nil)
+ }
+
+ test("udaf") {
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | key,
+ | mydoublesum(value + 1.5 * key),
+ | mydoubleavg(value),
+ | avg(value - key),
+ | mydoublesum(value - 1.5 * key),
+ | avg(value)
+ |FROM agg1
+ |GROUP BY key
+ """.stripMargin),
+ Row(1, 64.5, 120.0, 19.0, 55.5, 20.0) ::
+ Row(2, 5.0, 99.5, -2.5, -7.0, -0.5) ::
+ Row(3, null, null, null, null, null) ::
+ Row(null, null, 110.0, null, null, 10.0) :: Nil)
+ }
+
+ test("non-AlgebraicAggregate aggregate function") {
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT mydoublesum(value), key
+ |FROM agg1
+ |GROUP BY key
+ """.stripMargin),
+ Row(60.0, 1) :: Row(-1.0, 2) :: Row(null, 3) :: Row(30.0, null) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT mydoublesum(value) FROM agg1
+ """.stripMargin),
+ Row(89.0) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT mydoublesum(null)
+ """.stripMargin),
+ Row(null) :: Nil)
+ }
+
+ test("non-AlgebraicAggregate and AlgebraicAggregate aggregate functions") {
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT mydoublesum(value), key, avg(value)
+ |FROM agg1
+ |GROUP BY key
+ """.stripMargin),
+ Row(60.0, 1, 20.0) ::
+ Row(-1.0, 2, -0.5) ::
+ Row(null, 3, null) ::
+ Row(30.0, null, 10.0) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | mydoublesum(value + 1.5 * key),
+ | avg(value - key),
+ | key,
+ | mydoublesum(value - 1.5 * key),
+ | avg(value)
+ |FROM agg1
+ |GROUP BY key
+ """.stripMargin),
+ Row(64.5, 19.0, 1, 55.5, 20.0) ::
+ Row(5.0, -2.5, 2, -7.0, -0.5) ::
+ Row(null, null, 3, null, null) ::
+ Row(null, null, null, null, 10.0) :: Nil)
+ }
+
+ test("single distinct column set") {
+ // DISTINCT is not meaningful with Max and Min, so we just ignore the DISTINCT keyword.
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | min(distinct value1),
+ | sum(distinct value1),
+ | avg(value1),
+ | avg(value2),
+ | max(distinct value1)
+ |FROM agg2
+ """.stripMargin),
+ Row(-60, 70.0, 101.0/9.0, 5.6, 100.0))
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | mydoubleavg(distinct value1),
+ | avg(value1),
+ | avg(value2),
+ | key,
+ | mydoubleavg(value1 - 1),
+ | mydoubleavg(distinct value1) * 0.1,
+ | avg(value1 + value2)
+ |FROM agg2
+ |GROUP BY key
+ """.stripMargin),
+ Row(120.0, 70.0/3.0, -10.0/3.0, 1, 67.0/3.0 + 100.0, 12.0, 20.0) ::
+ Row(100.0, 1.0/3.0, 1.0, 2, -2.0/3.0 + 100.0, 10.0, 2.0) ::
+ Row(null, null, 3.0, 3, null, null, null) ::
+ Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | key,
+ | mydoubleavg(distinct value1),
+ | mydoublesum(value2),
+ | mydoublesum(distinct value1),
+ | mydoubleavg(distinct value1),
+ | mydoubleavg(value1)
+ |FROM agg2
+ |GROUP BY key
+ """.stripMargin),
+ Row(1, 120.0, -10.0, 40.0, 120.0, 70.0/3.0 + 100.0) ::
+ Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) ::
+ Row(3, null, 3.0, null, null, null) ::
+ Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil)
+ }
+
+ test("test count") {
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | count(value2),
+ | value1,
+ | count(*),
+ | count(1),
+ | key
+ |FROM agg2
+ |GROUP BY key, value1
+ """.stripMargin),
+ Row(1, 10, 1, 1, 1) ::
+ Row(1, -60, 1, 1, null) ::
+ Row(2, 30, 2, 2, 1) ::
+ Row(2, 1, 2, 2, 2) ::
+ Row(1, -10, 1, 1, null) ::
+ Row(0, -1, 1, 1, 2) ::
+ Row(1, null, 1, 1, 2) ::
+ Row(1, 100, 1, 1, null) ::
+ Row(1, null, 2, 2, 3) ::
+ Row(0, null, 1, 1, null) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | count(value2),
+ | value1,
+ | count(*),
+ | count(1),
+ | key,
+ | count(DISTINCT abs(value2))
+ |FROM agg2
+ |GROUP BY key, value1
+ """.stripMargin),
+ Row(1, 10, 1, 1, 1, 1) ::
+ Row(1, -60, 1, 1, null, 1) ::
+ Row(2, 30, 2, 2, 1, 1) ::
+ Row(2, 1, 2, 2, 2, 1) ::
+ Row(1, -10, 1, 1, null, 1) ::
+ Row(0, -1, 1, 1, 2, 0) ::
+ Row(1, null, 1, 1, 2, 1) ::
+ Row(1, 100, 1, 1, null, 1) ::
+ Row(1, null, 2, 2, 3, 1) ::
+ Row(0, null, 1, 1, null, 0) :: Nil)
+ }
+
+ test("error handling") {
+ sqlContext.sql(s"set spark.sql.useAggregate2=false")
+ var errorMessage = intercept[AnalysisException] {
+ sqlContext.sql(
+ """
+ |SELECT
+ | key,
+ | sum(value + 1.5 * key),
+ | mydoublesum(value),
+ | mydoubleavg(value)
+ |FROM agg1
+ |GROUP BY key
+ """.stripMargin).collect()
+ }.getMessage
+ assert(errorMessage.contains("implemented based on the new Aggregate Function interface"))
+
+ // TODO: once we support Hive UDAF in the new interface,
+ // we can remove the following two tests.
+ sqlContext.sql(s"set spark.sql.useAggregate2=true")
+ errorMessage = intercept[AnalysisException] {
+ sqlContext.sql(
+ """
+ |SELECT
+ | key,
+ | mydoublesum(value + 1.5 * key),
+ | stddev_samp(value)
+ |FROM agg1
+ |GROUP BY key
+ """.stripMargin).collect()
+ }.getMessage
+ assert(errorMessage.contains("implemented based on the new Aggregate Function interface"))
+
+ // This query will fall back to the old aggregation code path.
+ val newAggregateOperators = sqlContext.sql(
+ """
+ |SELECT
+ | key,
+ | sum(value + 1.5 * key),
+ | stddev_samp(value)
+ |FROM agg1
+ |GROUP BY key
+ """.stripMargin).queryExecution.executedPlan.collect {
+ case agg: Aggregate2Sort => agg
+ }
+ val message =
+ "We should fall back to the old aggregation code path if there is any aggregate function " +
+ "that cannot be converted to the new interface."
+ assert(newAggregateOperators.isEmpty, message)
+
+ sqlContext.sql(s"set spark.sql.useAggregate2=true")
+ }
+}
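
As the suite above shows, the new code path is guarded by a SQL configuration flag; a minimal way to toggle it (names taken directly from the tests) is:

    sqlContext.sql("set spark.sql.useAggregate2=true")   // enable the new aggregation path
    sqlContext.sql("set spark.sql.useAggregate2=false")  // revert to the old path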