author    Yin Huai <yhuai@databricks.com>    2015-07-21 23:26:11 -0700
committer Reynold Xin <rxin@databricks.com>  2015-07-21 23:26:11 -0700
commit    c03299a18b4e076cabb4b7833a1e7632c5c0dabe (patch)
tree      f35f52623f9abb16ff260a3e4832327dc242f40a
parent    f4785f5b82c57bce41d3dc26ed9e3c9e794c7558 (diff)
[SPARK-4233] [SPARK-4367] [SPARK-3947] [SPARK-3056] [SQL] Aggregation Improvement
This is the first PR for the aggregation improvement, which is tracked by https://issues.apache.org/jira/browse/SPARK-4366 (umbrella JIRA). This PR contains work for its subtasks SPARK-3056, SPARK-3947, SPARK-4233, and SPARK-4367.

This PR introduces a new code path for evaluating aggregate functions. This code path is guarded by `spark.sql.useAggregate2`, and by default the value of this flag is true. This new code path contains:

* A new aggregate function interface (`AggregateFunction2`) and 7 built-in aggregate functions based on this new interface (`AVG`, `COUNT`, `FIRST`, `LAST`, `MAX`, `MIN`, `SUM`).
* A UDAF interface (`UserDefinedAggregateFunction`) based on the new code path and two example UDAFs (`MyDoubleAvg` and `MyDoubleSum`).
* A sort-based aggregate operator (`Aggregate2Sort`) for the new aggregate function interface.
* A sort-based aggregate operator (`FinalAndCompleteAggregate2Sort`) for distinct aggregations (for distinct aggregations the query plan will use `Aggregate2Sort` and `FinalAndCompleteAggregate2Sort` together).

With this change, when `spark.sql.useAggregate2` is `true`, the flow of compiling an aggregation query is:

1. Our analyzer looks up functions and returns aggregate functions built based on the old aggregate function interface.
2. When our planner is compiling the physical plan, it tries to convert all aggregate functions to the ones built based on the new interface. The planner will fall back to the old code path if any of the following conditions is true:
   * code-gen is disabled.
   * there is any function that cannot be converted (right now, Hive UDAFs).
   * the schema of grouping expressions contains any complex data type.
   * there are multiple distinct columns.

Right now, the new code path handles a single distinct column in the query (you can have multiple aggregate functions using that distinct column). For a query having an aggregate function with DISTINCT and regular aggregate functions, the generated plan will do partial aggregations for those regular aggregate functions.

Thanks chenghao-intel for his initial work on it.

Author: Yin Huai <yhuai@databricks.com>
Author: Michael Armbrust <michael@databricks.com>

Closes #7458 from yhuai/UDAF and squashes the following commits:

7865f5e [Yin Huai] Put the catalyst expression in the comment of the generated code for it.
b04d6c8 [Yin Huai] Remove unnecessary change.
f1d5901 [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
35b0520 [Yin Huai] Use semanticEquals to replace grouping expressions in the output of the aggregate operator.
3b43b24 [Yin Huai] bug fix.
00eb298 [Yin Huai] Make it compile.
a3ca551 [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
e0afca3 [Yin Huai] Gracefully fallback to old aggregation code path.
8a8ac4a [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
88c7d4d [Yin Huai] Enable spark.sql.useAggregate2 by default for testing purpose.
dc96fd1 [Yin Huai] Many updates:
85c9c4b [Yin Huai] newline.
43de3de [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
c3614d7 [Yin Huai] Handle single distinct column.
68b8ee9 [Yin Huai] Support single distinct column set. WIP
3013579 [Yin Huai] Format.
d678aee [Yin Huai] Remove AggregateExpressionSuite.scala since our built-in aggregate functions will be based on AlgebraicAggregate and we need to have another way to test it.
e243ca6 [Yin Huai] Add aggregation iterators.
a101960 [Yin Huai] Change MyJavaUDAF to MyDoubleSum.
594cdf5 [Yin Huai] Change existing AggregateExpression to AggregateExpression1 and add an AggregateExpression as the common interface for both AggregateExpression1 and AggregateExpression2.
380880f [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
0a827b3 [Yin Huai] Add comments and doc. Move some classes to the right places.
a19fea6 [Yin Huai] Add UDAF interface.
262d4c4 [Yin Huai] Make it compile.
b2e358e [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
6edb5ac [Yin Huai] Format update.
70b169c [Yin Huai] Remove groupOrdering.
4721936 [Yin Huai] Add CheckAggregateFunction to extendedCheckRules.
d821a34 [Yin Huai] Cleanup.
32aea9c [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
5b46d41 [Yin Huai] Bug fix.
aff9534 [Yin Huai] Make Aggregate2Sort work with both algebraic AggregateFunctions and non-algebraic AggregateFunctions.
2857b55 [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
4435f20 [Yin Huai] Add ConvertAggregateFunction to HiveContext's analyzer.
1b490ed [Michael Armbrust] make hive test
8cfa6a9 [Michael Armbrust] add test
1b0bb3f [Yin Huai] Do not bind references in AlgebraicAggregate and use code gen for all places.
072209f [Yin Huai] Bug fix: Handle expressions in grouping columns that are not attribute references.
f7d9e54 [Michael Armbrust] Merge remote-tracking branch 'apache/master' into UDAF
39ee975 [Yin Huai] Code cleanup: Remove unnecesary AttributeReferences.
b7720ba [Yin Huai] Add an analysis rule to convert aggregate function to the new version.
5c00f3f [Michael Armbrust] First draft of codegen
6bbc6ba [Michael Armbrust] now with correct answers\!
f7996d0 [Michael Armbrust] Add AlgebraicAggregate
dded1c5 [Yin Huai] wip
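For orientation before the file-by-file diff, here is a minimal sketch (not part of this patch) of what an aggregate function looks like against the new AlgebraicAggregate helper introduced in aggregate/interfaces.scala below. The class name CountNonZero is hypothetical; the structure mirrors the built-in Count in aggregate/functions.scala:

// A hypothetical aggregate written against the new interface, assuming the
// same imports the built-in functions use.
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AlgebraicAggregate
import org.apache.spark.sql.types._

case class CountNonZero(child: Expression) extends AlgebraicAggregate {
  override def children: Seq[Expression] = child :: Nil
  override def nullable: Boolean = false
  override def dataType: DataType = LongType
  override def inputTypes: Seq[AbstractDataType] = Seq(LongType)

  private val count = AttributeReference("count", LongType)()
  override val bufferAttributes = count :: Nil

  // The buffer starts at zero and is bumped for every non-null, non-zero input.
  override val initialValues = Seq(Literal(0L))
  override val updateExpressions = Seq(
    If(Or(IsNull(child), EqualTo(child, Literal(0L))), count, count + 1L))
  // Merging two partial buffers just adds the two counters.
  override val mergeExpressions = Seq(count.left + count.right)
  override val evaluateExpression = count
}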
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala | 3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 24
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala | 5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala | 3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala | 292
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala | 206
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala | 100
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala | 21
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala | 1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala | 5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala | 35
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala | 12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala | 11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala | 100
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala | 173
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala | 749
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala | 364
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala | 280
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala | 26
-rw-r--r--  sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala | 1
-rw-r--r--  sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala | 7
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala | 1
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 7
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala | 8
-rw-r--r--  sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java | 107
-rw-r--r--  sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java | 100
-rw-r--r--  sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada | 1
-rw-r--r--  sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47 | 1
-rw-r--r--  sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4 | 14
-rw-r--r--  sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e | 1
-rw-r--r--  sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3 | 1
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala | 507
39 files changed, 3087 insertions, 100 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index d4ef04c229..c04bd6cd85 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -266,11 +266,12 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
}
}
| ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^
- { case udfName ~ exprs => UnresolvedFunction(udfName, exprs) }
+ { case udfName ~ exprs => UnresolvedFunction(udfName, exprs, isDistinct = false) }
| ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^ { case udfName ~ exprs =>
lexical.normalizeKeyword(udfName) match {
case "sum" => SumDistinct(exprs.head)
case "count" => CountDistinct(exprs)
+ case name => UnresolvedFunction(name, exprs, isDistinct = true)
case _ => throw new AnalysisException(s"function $udfName does not support DISTINCT")
}
}
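After this parser change, DISTINCT calls to functions other than the two special-cased ones survive parsing instead of failing immediately; the decision of whether DISTINCT is supported moves to the analyzer. A sketch of the resulting expression (column and table names hypothetical):

// SELECT avg(DISTINCT col) FROM t  -- avg(DISTINCT ...) now parses into:
// UnresolvedFunction("avg", Seq(UnresolvedAttribute("col")), isDistinct = true)
// and is handled by the analyzer rule shown in the next diff.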
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e58f3f6494..8cadbc57e8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2}
import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -277,7 +278,7 @@ class Analyzer(
Project(
projectList.flatMap {
case s: Star => s.expand(child.output, resolver)
- case UnresolvedAlias(f @ UnresolvedFunction(_, args)) if containsStar(args) =>
+ case UnresolvedAlias(f @ UnresolvedFunction(_, args, _)) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child.output, resolver)
case o => o :: Nil
@@ -517,9 +518,26 @@ class Analyzer(
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan =>
q transformExpressions {
- case u @ UnresolvedFunction(name, children) =>
+ case u @ UnresolvedFunction(name, children, isDistinct) =>
withPosition(u) {
- registry.lookupFunction(name, children)
+ registry.lookupFunction(name, children) match {
+ // We get an aggregate function built based on AggregateFunction2 interface.
+ // So, we wrap it in AggregateExpression2.
+ case agg2: AggregateFunction2 => AggregateExpression2(agg2, Complete, isDistinct)
+ // Currently, our old aggregate function interface supports SUM(DISTINCT ...)
+ // and COUNT(DISTINCT ...).
+ case sumDistinct: SumDistinct => sumDistinct
+ case countDistinct: CountDistinct => countDistinct
+ // DISTINCT is not meaningful with Max and Min.
+ case max: Max if isDistinct => max
+ case min: Min if isDistinct => min
+ // For other aggregate functions, DISTINCT keyword is not supported for now.
+ // Once we have converted to the new code path, we will allow using DISTINCT keyword.
+ case other if isDistinct =>
+ failAnalysis(s"$name does not support DISTINCT keyword.")
+ // If it does not have DISTINCT keyword, we will return it as is.
+ case other => other
+ }
}
}
}
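The new branch matters for functions whose registry builder yields an AggregateFunction2, such as a UDAF registered through the new UDAFRegistration below; built-ins still resolve through the old interface and are converted at planning time. A sketch of the resolution result (the UDAF instance is hypothetical):

// For a registered UDAF, lookupFunction returns a ScalaUDAF (an
// AggregateFunction2), which this rule wraps as:
// AggregateExpression2(ScalaUDAF(children, myUdaf), mode = Complete, isDistinct = false)
// An old-interface function used with DISTINCT (other than sum/count/max/min)
// still fails analysis: "avg does not support DISTINCT keyword."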
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index c7f9713344..c203fcecf2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 0daee1990a..03da45b09f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -73,7 +73,10 @@ object UnresolvedAttribute {
def quoted(name: String): UnresolvedAttribute = new UnresolvedAttribute(Seq(name))
}
-case class UnresolvedFunction(name: String, children: Seq[Expression])
+case class UnresolvedFunction(
+ name: String,
+ children: Seq[Expression],
+ isDistinct: Boolean)
extends Expression with Unevaluable {
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index b09aea0331..b10a3c8774 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.types._
case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
extends LeafExpression with NamedExpression {
- override def toString: String = s"input[$ordinal]"
+ override def toString: String = s"input[$ordinal, $dataType]"
override def eval(input: InternalRow): Any = input(ordinal)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index aada25276a..29ae47e842 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -96,7 +96,8 @@ abstract class Expression extends TreeNode[Expression] {
val primitive = ctx.freshName("primitive")
val ve = GeneratedExpressionCode("", isNull, primitive)
ve.code = genCode(ctx, ve)
- ve
+ // Add `this` in the comment.
+ ve.copy(s"/* $this */\n" + ve.code)
}
/**
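Together with the BoundReference change above, each block of generated Java is now prefixed by the pretty-printed catalyst expression it implements. A sketch of the shape (the generated body is elided):

// /* (input[0, IntegerType] + 1) */
// ...generated Java for the addition follows...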
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
new file mode 100644
index 0000000000..b924af4cc8
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
@@ -0,0 +1,292 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types._
+
+case class Average(child: Expression) extends AlgebraicAggregate {
+
+ override def children: Seq[Expression] = child :: Nil
+
+ override def nullable: Boolean = true
+
+ // Return data type.
+ override def dataType: DataType = resultType
+
+ // Expected input data type.
+ // TODO: Once we remove the old code path, we can use our analyzer to cast NullType
+ // to the default data type of the NumericType.
+ override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))
+
+ private val resultType = child.dataType match {
+ case DecimalType.Fixed(precision, scale) =>
+ DecimalType(precision + 4, scale + 4)
+ case DecimalType.Unlimited => DecimalType.Unlimited
+ case _ => DoubleType
+ }
+
+ private val sumDataType = child.dataType match {
+ case _ @ DecimalType() => DecimalType.Unlimited
+ case _ => DoubleType
+ }
+
+ private val currentSum = AttributeReference("currentSum", sumDataType)()
+ private val currentCount = AttributeReference("currentCount", LongType)()
+
+ override val bufferAttributes = currentSum :: currentCount :: Nil
+
+ override val initialValues = Seq(
+ /* currentSum = */ Cast(Literal(0), sumDataType),
+ /* currentCount = */ Literal(0L)
+ )
+
+ override val updateExpressions = Seq(
+ /* currentSum = */
+ Add(
+ currentSum,
+ Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)),
+ /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L)
+ )
+
+ override val mergeExpressions = Seq(
+ /* currentSum = */ currentSum.left + currentSum.right,
+ /* currentCount = */ currentCount.left + currentCount.right
+ )
+
+ // If all inputs are null, currentCount will be 0 and we will get null after the division.
+ override val evaluateExpression = Cast(currentSum, resultType) / Cast(currentCount, resultType)
+}
+
+case class Count(child: Expression) extends AlgebraicAggregate {
+ override def children: Seq[Expression] = child :: Nil
+
+ override def nullable: Boolean = false
+
+ // Return data type.
+ override def dataType: DataType = LongType
+
+ // Expected input data type.
+ override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+
+ private val currentCount = AttributeReference("currentCount", LongType)()
+
+ override val bufferAttributes = currentCount :: Nil
+
+ override val initialValues = Seq(
+ /* currentCount = */ Literal(0L)
+ )
+
+ override val updateExpressions = Seq(
+ /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L)
+ )
+
+ override val mergeExpressions = Seq(
+ /* currentCount = */ currentCount.left + currentCount.right
+ )
+
+ override val evaluateExpression = Cast(currentCount, LongType)
+}
+
+case class First(child: Expression) extends AlgebraicAggregate {
+
+ override def children: Seq[Expression] = child :: Nil
+
+ override def nullable: Boolean = true
+
+ // First is not a deterministic function.
+ override def deterministic: Boolean = false
+
+ // Return data type.
+ override def dataType: DataType = child.dataType
+
+ // Expected input data type.
+ override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+
+ private val first = AttributeReference("first", child.dataType)()
+
+ override val bufferAttributes = first :: Nil
+
+ override val initialValues = Seq(
+ /* first = */ Literal.create(null, child.dataType)
+ )
+
+ override val updateExpressions = Seq(
+ /* first = */ If(IsNull(first), child, first)
+ )
+
+ override val mergeExpressions = Seq(
+ /* first = */ If(IsNull(first.left), first.right, first.left)
+ )
+
+ override val evaluateExpression = first
+}
+
+case class Last(child: Expression) extends AlgebraicAggregate {
+
+ override def children: Seq[Expression] = child :: Nil
+
+ override def nullable: Boolean = true
+
+ // Last is not a deterministic function.
+ override def deterministic: Boolean = false
+
+ // Return data type.
+ override def dataType: DataType = child.dataType
+
+ // Expected input data type.
+ override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+
+ private val last = AttributeReference("last", child.dataType)()
+
+ override val bufferAttributes = last :: Nil
+
+ override val initialValues = Seq(
+ /* last = */ Literal.create(null, child.dataType)
+ )
+
+ override val updateExpressions = Seq(
+ /* last = */ If(IsNull(child), last, child)
+ )
+
+ override val mergeExpressions = Seq(
+ /* last = */ If(IsNull(last.right), last.left, last.right)
+ )
+
+ override val evaluateExpression = last
+}
+
+case class Max(child: Expression) extends AlgebraicAggregate {
+
+ override def children: Seq[Expression] = child :: Nil
+
+ override def nullable: Boolean = true
+
+ // Return data type.
+ override def dataType: DataType = child.dataType
+
+ // Expected input data type.
+ override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+
+ private val max = AttributeReference("max", child.dataType)()
+
+ override val bufferAttributes = max :: Nil
+
+ override val initialValues = Seq(
+ /* max = */ Literal.create(null, child.dataType)
+ )
+
+ override val updateExpressions = Seq(
+ /* max = */ If(IsNull(child), max, If(IsNull(max), child, Greatest(Seq(max, child))))
+ )
+
+ override val mergeExpressions = {
+ val greatest = Greatest(Seq(max.left, max.right))
+ Seq(
+ /* max = */ If(IsNull(max.right), max.left, If(IsNull(max.left), max.right, greatest))
+ )
+ }
+
+ override val evaluateExpression = max
+}
+
+case class Min(child: Expression) extends AlgebraicAggregate {
+
+ override def children: Seq[Expression] = child :: Nil
+
+ override def nullable: Boolean = true
+
+ // Return data type.
+ override def dataType: DataType = child.dataType
+
+ // Expected input data type.
+ override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+
+ private val min = AttributeReference("min", child.dataType)()
+
+ override val bufferAttributes = min :: Nil
+
+ override val initialValues = Seq(
+ /* min = */ Literal.create(null, child.dataType)
+ )
+
+ override val updateExpressions = Seq(
+ /* min = */ If(IsNull(child), min, If(IsNull(min), child, Least(Seq(min, child))))
+ )
+
+ override val mergeExpressions = {
+ val least = Least(Seq(min.left, min.right))
+ Seq(
+ /* min = */ If(IsNull(min.right), min.left, If(IsNull(min.left), min.right, least))
+ )
+ }
+
+ override val evaluateExpression = min
+}
+
+case class Sum(child: Expression) extends AlgebraicAggregate {
+
+ override def children: Seq[Expression] = child :: Nil
+
+ override def nullable: Boolean = true
+
+ // Return data type.
+ override def dataType: DataType = resultType
+
+ // Expected input data type.
+ override def inputTypes: Seq[AbstractDataType] =
+ Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType))
+
+ private val resultType = child.dataType match {
+ case DecimalType.Fixed(precision, scale) =>
+ DecimalType(precision + 4, scale + 4)
+ case DecimalType.Unlimited => DecimalType.Unlimited
+ case _ => child.dataType
+ }
+
+ private val sumDataType = child.dataType match {
+ case _ @ DecimalType() => DecimalType.Unlimited
+ case _ => child.dataType
+ }
+
+ private val currentSum = AttributeReference("currentSum", sumDataType)()
+
+ private val zero = Cast(Literal(0), sumDataType)
+
+ override val bufferAttributes = currentSum :: Nil
+
+ override val initialValues = Seq(
+ /* currentSum = */ Literal.create(null, sumDataType)
+ )
+
+ override val updateExpressions = Seq(
+ /* currentSum = */
+ Coalesce(Seq(Add(Coalesce(Seq(currentSum, zero)), Cast(child, sumDataType)), currentSum))
+ )
+
+ override val mergeExpressions = {
+ val add = Add(Coalesce(Seq(currentSum.left, zero)), Cast(currentSum.right, sumDataType))
+ Seq(
+ /* currentSum = */
+ Coalesce(Seq(add, currentSum.left))
+ )
+ }
+
+ override val evaluateExpression = Cast(currentSum, resultType)
+}
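Why Sum starts the buffer at null and double-wraps the update in Coalesce: the running sum must stay null until the first non-null input, so that a group containing only nulls yields null, as SQL SUM requires. A worked trace of updateExpressions over the inputs (null, 3, null, 4):

// start:      currentSum = null                                (initialValues)
// input null: Add(Coalesce(null, 0) = 0, null) = null; Coalesce(null, null) -> null
// input 3:    Add(0, 3) = 3;                           Coalesce(3, null)    -> 3
// input null: Add(Coalesce(3, 0) = 3, null) = null;    Coalesce(null, 3)    -> 3
// input 4:    Add(3, 4) = 7;                           Coalesce(7, 3)       -> 7
// A group that never sees a non-null input therefore evaluates to null.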
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
new file mode 100644
index 0000000000..577ede73cb
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types._
+
+/** The mode of an [[AggregateFunction2]]. */
+private[sql] sealed trait AggregateMode
+
+/**
+ * An [[AggregateFunction2]] with [[Partial]] mode is used for partial aggregation.
+ * This function updates the given aggregation buffer with the original input of this
+ * function. When it has processed all input rows, the aggregation buffer is returned.
+ */
+private[sql] case object Partial extends AggregateMode
+
+/**
+ * An [[AggregateFunction2]] with [[PartialMerge]] mode is used to merge aggregation buffers
+ * containing intermediate results for this function.
+ * This function updates the given aggregation buffer by merging multiple aggregation buffers.
+ * When it has processed all input rows, the aggregation buffer is returned.
+ */
+private[sql] case object PartialMerge extends AggregateMode
+
+/**
+ * An [[AggregateFunction2]] with [[Final]] mode is used to merge aggregation buffers
+ * containing intermediate results for this function and then generate the final result.
+ * This function updates the given aggregation buffer by merging multiple aggregation buffers.
+ * When it has processed all input rows, the final result of this function is returned.
+ */
+private[sql] case object Final extends AggregateMode
+
+/**
+ * An [[AggregateFunction2]] with [[Partial]] mode is used to evaluate this function directly
+ * from original input rows without any partial aggregation.
+ * This function updates the given aggregation buffer with the original input of this
+ * function. When it has processed all input rows, the final result of this function is returned.
+ */
+private[sql] case object Complete extends AggregateMode
+
+/**
+ * A placeholder expression used in code-gen; it does not change the corresponding value
+ * in the row.
+ */
+private[sql] case object NoOp extends Expression with Unevaluable {
+ override def nullable: Boolean = true
+ override def eval(input: InternalRow): Any = {
+ throw new TreeNodeException(
+ this, s"No function to evaluate expression. type: ${this.nodeName}")
+ }
+ override def dataType: DataType = NullType
+ override def children: Seq[Expression] = Nil
+}
+
+/**
+ * A container for an [[AggregateFunction2]] with its [[AggregateMode]] and a field
+ * (`isDistinct`) indicating if DISTINCT keyword is specified for this function.
+ * @param aggregateFunction
+ * @param mode
+ * @param isDistinct
+ */
+private[sql] case class AggregateExpression2(
+ aggregateFunction: AggregateFunction2,
+ mode: AggregateMode,
+ isDistinct: Boolean) extends AggregateExpression {
+
+ override def children: Seq[Expression] = aggregateFunction :: Nil
+ override def dataType: DataType = aggregateFunction.dataType
+ override def foldable: Boolean = false
+ override def nullable: Boolean = aggregateFunction.nullable
+
+ override def references: AttributeSet = {
+ val childReferences = mode match {
+ case Partial | Complete => aggregateFunction.references.toSeq
+ case PartialMerge | Final => aggregateFunction.bufferAttributes
+ }
+
+ AttributeSet(childReferences)
+ }
+
+ override def toString: String = s"(${aggregateFunction}2,mode=$mode,isDistinct=$isDistinct)"
+}
+
+abstract class AggregateFunction2
+ extends Expression with ImplicitCastInputTypes {
+
+ self: Product =>
+
+ /** An aggregate function is not foldable. */
+ override def foldable: Boolean = false
+
+ /**
+ * The offset of this function's buffer in the underlying buffer shared with other functions.
+ */
+ var bufferOffset: Int = 0
+
+ /** The schema of the aggregation buffer. */
+ def bufferSchema: StructType
+
+ /** Attributes of fields in bufferSchema. */
+ def bufferAttributes: Seq[AttributeReference]
+
+ /** Clones bufferAttributes. */
+ def cloneBufferAttributes: Seq[Attribute]
+
+ /**
+ * Initializes its aggregation buffer located in `buffer`.
+ * It will use bufferOffset to find the starting point of
+ * its buffer in the given `buffer` shared with other functions.
+ */
+ def initialize(buffer: MutableRow): Unit
+
+ /**
+ * Updates its aggregation buffer located in `buffer` based on the given `input`.
+ * It will use bufferOffset to find the starting point of its buffer in the given `buffer`
+ * shared with other functions.
+ */
+ def update(buffer: MutableRow, input: InternalRow): Unit
+
+ /**
+ * Updates its aggregation buffer located in `buffer1` by combining intermediate results
+ * in the current buffer and intermediate results from another buffer `buffer2`.
+ * It will use bufferOffset to find the starting point of its buffer in the given `buffer1`
+ * and `buffer2`.
+ */
+ def merge(buffer1: MutableRow, buffer2: InternalRow): Unit
+
+ override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
+ throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
+}
+
+/**
+ * A helper class for aggregate functions that can be implemented in terms of catalyst expressions.
+ */
+abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable {
+ self: Product =>
+
+ val initialValues: Seq[Expression]
+ val updateExpressions: Seq[Expression]
+ val mergeExpressions: Seq[Expression]
+ val evaluateExpression: Expression
+
+ override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance())
+
+ /**
+ * A helper class for representing an attribute used in merging two
+ * aggregation buffers. When merging two buffers, `bufferLeft` and `bufferRight`,
+ * we merge buffer values and then update bufferLeft. A [[RichAttribute]]
+ * of an [[AttributeReference]] `a` has two functions `left` and `right`,
+ * which represent `a` in `bufferLeft` and `bufferRight`, respectively.
+ * @param a
+ */
+ implicit class RichAttribute(a: AttributeReference) {
+ /** Represents this attribute at the mutable buffer side. */
+ def left: AttributeReference = a
+
+ /** Represents this attribute at the input buffer side (the data value is read-only). */
+ def right: AttributeReference = cloneBufferAttributes(bufferAttributes.indexOf(a))
+ }
+
+ /** An AlgebraicAggregate's bufferSchema is derived from bufferAttributes. */
+ override def bufferSchema: StructType = StructType.fromAttributes(bufferAttributes)
+
+ override def initialize(buffer: MutableRow): Unit = {
+ var i = 0
+ while (i < bufferAttributes.size) {
+ buffer(i + bufferOffset) = initialValues(i).eval()
+ i += 1
+ }
+ }
+
+ override def update(buffer: MutableRow, input: InternalRow): Unit = {
+ throw new UnsupportedOperationException(
+ "AlgebraicAggregate's update should not be called directly")
+ }
+
+ override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
+ throw new UnsupportedOperationException(
+ "AlgebraicAggregate's merge should not be called directly")
+ }
+
+ override def eval(buffer: InternalRow): Any = {
+ throw new UnsupportedOperationException(
+ "AlgebraicAggregate's eval should not be called directly")
+ }
+}
+
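For contrast with AlgebraicAggregate, a minimal sketch (not part of this patch) of a non-algebraic AggregateFunction2 that maintains its buffer imperatively. The class name ImperativeLongSum is hypothetical; note how every buffer access goes through bufferOffset, since the buffer row is shared with the query's other aggregate functions:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction2
import org.apache.spark.sql.types._

case class ImperativeLongSum(child: Expression) extends AggregateFunction2 {
  override def children: Seq[Expression] = child :: Nil
  override def nullable: Boolean = false
  override def dataType: DataType = LongType
  override def inputTypes: Seq[AbstractDataType] = Seq(LongType)

  private val sum = AttributeReference("sum", LongType)()
  override def bufferAttributes: Seq[AttributeReference] = sum :: Nil
  override def cloneBufferAttributes: Seq[Attribute] = bufferAttributes.map(_.newInstance())
  override def bufferSchema: StructType = StructType.fromAttributes(bufferAttributes)

  // Our single buffer slot lives at position bufferOffset in the shared row.
  override def initialize(buffer: MutableRow): Unit =
    buffer.setLong(bufferOffset, 0L)

  override def update(buffer: MutableRow, input: InternalRow): Unit = {
    val v = child.eval(input)
    if (v != null) {
      buffer.setLong(bufferOffset, buffer.getLong(bufferOffset) + v.asInstanceOf[Long])
    }
  }

  // buffer2 holds another task's partial result and is read-only here.
  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit =
    buffer1.setLong(bufferOffset, buffer1.getLong(bufferOffset) + buffer2.getLong(bufferOffset))

  override def eval(buffer: InternalRow): Any = buffer.getLong(bufferOffset)
}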
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index d705a12860..e07c920a41 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -27,7 +27,9 @@ import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashSet
-trait AggregateExpression extends Expression with Unevaluable {
+trait AggregateExpression extends Expression with Unevaluable
+
+trait AggregateExpression1 extends AggregateExpression {
/**
* Aggregate expressions should not be foldable.
@@ -38,7 +40,7 @@ trait AggregateExpression extends Expression with Unevaluable {
* Creates a new instance that can be used to compute this aggregate expression for a group
* of input rows/
*/
- def newInstance(): AggregateFunction
+ def newInstance(): AggregateFunction1
}
/**
@@ -54,10 +56,10 @@ case class SplitEvaluation(
partialEvaluations: Seq[NamedExpression])
/**
- * An [[AggregateExpression]] that can be partially computed without seeing all relevant tuples.
+ * An [[AggregateExpression1]] that can be partially computed without seeing all relevant tuples.
* These partial evaluations can then be combined to compute the actual answer.
*/
-trait PartialAggregate extends AggregateExpression {
+trait PartialAggregate1 extends AggregateExpression1 {
/**
* Returns a [[SplitEvaluation]] that computes this aggregation using partial aggregation.
@@ -67,13 +69,13 @@ trait PartialAggregate extends AggregateExpression {
/**
* A specific implementation of an aggregate function. Used to wrap a generic
- * [[AggregateExpression]] with an algorithm that will be used to compute one specific result.
+ * [[AggregateExpression1]] with an algorithm that will be used to compute one specific result.
*/
-abstract class AggregateFunction
- extends LeafExpression with AggregateExpression with Serializable {
+abstract class AggregateFunction1
+ extends LeafExpression with AggregateExpression1 with Serializable {
/** Base should return the generic aggregate expression that this function is computing */
- val base: AggregateExpression
+ val base: AggregateExpression1
override def nullable: Boolean = base.nullable
override def dataType: DataType = base.dataType
@@ -81,12 +83,12 @@ abstract class AggregateFunction
def update(input: InternalRow): Unit
// Do we really need this?
- override def newInstance(): AggregateFunction = {
+ override def newInstance(): AggregateFunction1 = {
makeCopy(productIterator.map { case a: AnyRef => a }.toArray)
}
}
-case class Min(child: Expression) extends UnaryExpression with PartialAggregate {
+case class Min(child: Expression) extends UnaryExpression with PartialAggregate1 {
override def nullable: Boolean = true
override def dataType: DataType = child.dataType
@@ -102,7 +104,7 @@ case class Min(child: Expression) extends UnaryExpression with PartialAggregate
TypeUtils.checkForOrderingExpr(child.dataType, "function min")
}
-case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
+case class MinFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
val currentMin: MutableLiteral = MutableLiteral(null, expr.dataType)
@@ -119,7 +121,7 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends Aggr
override def eval(input: InternalRow): Any = currentMin.value
}
-case class Max(child: Expression) extends UnaryExpression with PartialAggregate {
+case class Max(child: Expression) extends UnaryExpression with PartialAggregate1 {
override def nullable: Boolean = true
override def dataType: DataType = child.dataType
@@ -135,7 +137,7 @@ case class Max(child: Expression) extends UnaryExpression with PartialAggregate
TypeUtils.checkForOrderingExpr(child.dataType, "function max")
}
-case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
+case class MaxFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
val currentMax: MutableLiteral = MutableLiteral(null, expr.dataType)
@@ -152,7 +154,7 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr
override def eval(input: InternalRow): Any = currentMax.value
}
-case class Count(child: Expression) extends UnaryExpression with PartialAggregate {
+case class Count(child: Expression) extends UnaryExpression with PartialAggregate1 {
override def nullable: Boolean = false
override def dataType: LongType.type = LongType
@@ -165,7 +167,7 @@ case class Count(child: Expression) extends UnaryExpression with PartialAggregat
override def newInstance(): CountFunction = new CountFunction(child, this)
}
-case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
+case class CountFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
var count: Long = _
@@ -180,7 +182,7 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag
override def eval(input: InternalRow): Any = count
}
-case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate {
+case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate1 {
def this() = this(null)
override def children: Seq[Expression] = expressions
@@ -200,8 +202,8 @@ case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate
case class CountDistinctFunction(
@transient expr: Seq[Expression],
- @transient base: AggregateExpression)
- extends AggregateFunction {
+ @transient base: AggregateExpression1)
+ extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
@@ -220,7 +222,7 @@ case class CountDistinctFunction(
override def eval(input: InternalRow): Any = seen.size.toLong
}
-case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression {
+case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression1 {
def this() = this(null)
override def children: Seq[Expression] = expressions
@@ -233,8 +235,8 @@ case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpress
case class CollectHashSetFunction(
@transient expr: Seq[Expression],
- @transient base: AggregateExpression)
- extends AggregateFunction {
+ @transient base: AggregateExpression1)
+ extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
@@ -255,7 +257,7 @@ case class CollectHashSetFunction(
}
}
-case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression {
+case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression1 {
def this() = this(null)
override def children: Seq[Expression] = inputSet :: Nil
@@ -269,8 +271,8 @@ case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression
case class CombineSetsAndCountFunction(
@transient inputSet: Expression,
- @transient base: AggregateExpression)
- extends AggregateFunction {
+ @transient base: AggregateExpression1)
+ extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
@@ -305,7 +307,7 @@ private[sql] case object HyperLogLogUDT extends UserDefinedType[HyperLogLog] {
}
case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
- extends UnaryExpression with AggregateExpression {
+ extends UnaryExpression with AggregateExpression1 {
override def nullable: Boolean = false
override def dataType: DataType = HyperLogLogUDT
@@ -317,9 +319,9 @@ case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
case class ApproxCountDistinctPartitionFunction(
expr: Expression,
- base: AggregateExpression,
+ base: AggregateExpression1,
relativeSD: Double)
- extends AggregateFunction {
+ extends AggregateFunction1 {
def this() = this(null, null, 0) // Required for serialization.
private val hyperLogLog = new HyperLogLog(relativeSD)
@@ -335,7 +337,7 @@ case class ApproxCountDistinctPartitionFunction(
}
case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
- extends UnaryExpression with AggregateExpression {
+ extends UnaryExpression with AggregateExpression1 {
override def nullable: Boolean = false
override def dataType: LongType.type = LongType
@@ -347,9 +349,9 @@ case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
case class ApproxCountDistinctMergeFunction(
expr: Expression,
- base: AggregateExpression,
+ base: AggregateExpression1,
relativeSD: Double)
- extends AggregateFunction {
+ extends AggregateFunction1 {
def this() = this(null, null, 0) // Required for serialization.
private val hyperLogLog = new HyperLogLog(relativeSD)
@@ -363,7 +365,7 @@ case class ApproxCountDistinctMergeFunction(
}
case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
- extends UnaryExpression with PartialAggregate {
+ extends UnaryExpression with PartialAggregate1 {
override def nullable: Boolean = false
override def dataType: LongType.type = LongType
@@ -381,7 +383,7 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
override def newInstance(): CountDistinctFunction = new CountDistinctFunction(child :: Nil, this)
}
-case class Average(child: Expression) extends UnaryExpression with PartialAggregate {
+case class Average(child: Expression) extends UnaryExpression with PartialAggregate1 {
override def prettyName: String = "avg"
@@ -427,8 +429,8 @@ case class Average(child: Expression) extends UnaryExpression with PartialAggreg
TypeUtils.checkForNumericExpr(child.dataType, "function average")
}
-case class AverageFunction(expr: Expression, base: AggregateExpression)
- extends AggregateFunction {
+case class AverageFunction(expr: Expression, base: AggregateExpression1)
+ extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
@@ -474,7 +476,7 @@ case class AverageFunction(expr: Expression, base: AggregateExpression)
}
}
-case class Sum(child: Expression) extends UnaryExpression with PartialAggregate {
+case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1 {
override def nullable: Boolean = true
@@ -509,7 +511,7 @@ case class Sum(child: Expression) extends UnaryExpression with PartialAggregate
TypeUtils.checkForNumericExpr(child.dataType, "function sum")
}
-case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
+case class SumFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
private val calcType =
@@ -554,7 +556,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr
* <-- null <-- no data
* null <-- null <-- no data
*/
-case class CombineSum(child: Expression) extends AggregateExpression {
+case class CombineSum(child: Expression) extends AggregateExpression1 {
def this() = this(null)
override def children: Seq[Expression] = child :: Nil
@@ -564,8 +566,8 @@ case class CombineSum(child: Expression) extends AggregateExpression {
override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this)
}
-case class CombineSumFunction(expr: Expression, base: AggregateExpression)
- extends AggregateFunction {
+case class CombineSumFunction(expr: Expression, base: AggregateExpression1)
+ extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
@@ -601,7 +603,7 @@ case class CombineSumFunction(expr: Expression, base: AggregateExpression)
}
}
-case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate {
+case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate1 {
def this() = this(null)
override def nullable: Boolean = true
@@ -627,8 +629,8 @@ case class SumDistinct(child: Expression) extends UnaryExpression with PartialAg
TypeUtils.checkForNumericExpr(child.dataType, "function sumDistinct")
}
-case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
- extends AggregateFunction {
+case class SumDistinctFunction(expr: Expression, base: AggregateExpression1)
+ extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
@@ -653,7 +655,7 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
}
}
-case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression {
+case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression1 {
def this() = this(null, null)
override def children: Seq[Expression] = inputSet :: Nil
@@ -667,8 +669,8 @@ case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends Agg
case class CombineSetsAndSumFunction(
@transient inputSet: Expression,
- @transient base: AggregateExpression)
- extends AggregateFunction {
+ @transient base: AggregateExpression1)
+ extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
@@ -695,7 +697,7 @@ case class CombineSetsAndSumFunction(
}
}
-case class First(child: Expression) extends UnaryExpression with PartialAggregate {
+case class First(child: Expression) extends UnaryExpression with PartialAggregate1 {
override def nullable: Boolean = true
override def dataType: DataType = child.dataType
override def toString: String = s"FIRST($child)"
@@ -709,7 +711,7 @@ case class First(child: Expression) extends UnaryExpression with PartialAggregat
override def newInstance(): FirstFunction = new FirstFunction(child, this)
}
-case class FirstFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
+case class FirstFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
var result: Any = null
@@ -723,7 +725,7 @@ case class FirstFunction(expr: Expression, base: AggregateExpression) extends Ag
override def eval(input: InternalRow): Any = result
}
-case class Last(child: Expression) extends UnaryExpression with PartialAggregate {
+case class Last(child: Expression) extends UnaryExpression with PartialAggregate1 {
override def references: AttributeSet = child.references
override def nullable: Boolean = true
override def dataType: DataType = child.dataType
@@ -738,7 +740,7 @@ case class Last(child: Expression) extends UnaryExpression with PartialAggregate
override def newInstance(): LastFunction = new LastFunction(child, this)
}
-case class LastFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
+case class LastFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
var result: Any = null
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index 03b4b3c216..d838268f46 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
import scala.collection.mutable.ArrayBuffer
@@ -38,15 +39,17 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
protected def create(expressions: Seq[Expression]): (() => MutableProjection) = {
val ctx = newCodeGenContext()
- val projectionCode = expressions.zipWithIndex.map { case (e, i) =>
- val evaluationCode = e.gen(ctx)
- evaluationCode.code +
- s"""
- if(${evaluationCode.isNull})
- mutableRow.setNullAt($i);
- else
- ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)};
- """
+ val projectionCode = expressions.zipWithIndex.map {
+ case (NoOp, _) => ""
+ case (e, i) =>
+ val evaluationCode = e.gen(ctx)
+ evaluationCode.code +
+ s"""
+ if(${evaluationCode.isNull})
+ mutableRow.setNullAt($i);
+ else
+ ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)};
+ """
}
// collect projections into blocks as function has 64kb codesize limit in JVM
val projectionBlocks = new ArrayBuffer[String]()
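NoOp exists so that, when several functions share one aggregation buffer, a projection can update some columns while leaving the slots owned by other functions untouched; the match above simply emits no code for those positions. A sketch:

// Projecting Seq(NoOp, Literal(42L)) over a two-column mutableRow generates
// setter code only for column 1; column 0 keeps whatever value it had.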
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 179a348d5b..b8e3b0d53a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -129,10 +129,10 @@ object PartialAggregation {
case logical.Aggregate(groupingExpressions, aggregateExpressions, child) =>
// Collect all aggregate expressions.
val allAggregates =
- aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a})
+ aggregateExpressions.flatMap(_ collect { case a: AggregateExpression1 => a})
// Collect all aggregate expressions that can be computed partially.
val partialAggregates =
- aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p})
+ aggregateExpressions.flatMap(_ collect { case p: PartialAggregate1 => p})
// Only do partial aggregation if supported by all aggregate expressions.
if (allAggregates.size == partialAggregates.size) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 986c315b31..6aefa9f675 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashSet
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 78c780bdc5..1474b170ba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -402,6 +402,9 @@ private[spark] object SQLConf {
defaultValue = Some(true),
isPublic = false)
+ val USE_SQL_AGGREGATE2 = booleanConf("spark.sql.useAggregate2",
+ defaultValue = Some(true), doc = "<TODO>")
+
val USE_SQL_SERIALIZER2 = booleanConf(
"spark.sql.useSerializer2",
defaultValue = Some(true), isPublic = false)
@@ -473,6 +476,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED)
+ private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2)
+
private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2)
private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD)
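The new flag makes the fallback user-controllable; a usage sketch:

// Revert a session to the old aggregation code path.
sqlContext.setConf("spark.sql.useAggregate2", "false")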
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 8b4528b5d5..49bfe74b68 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -285,6 +285,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
@transient
val udf: UDFRegistration = new UDFRegistration(this)
+ @transient
+ val udaf: UDAFRegistration = new UDAFRegistration(this)
+
/**
* Returns true if the table is currently cached in-memory.
* @group cachemgmt
@@ -863,6 +866,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
DDLStrategy ::
TakeOrderedAndProject ::
HashAggregation ::
+ Aggregation ::
LeftSemiJoin ::
HashJoin ::
InMemoryScans ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala
new file mode 100644
index 0000000000..5b872f5e3e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.expressions.{Expression}
+import org.apache.spark.sql.expressions.aggregate.{ScalaUDAF, UserDefinedAggregateFunction}
+
+class UDAFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
+
+ private val functionRegistry = sqlContext.functionRegistry
+
+ def register(
+ name: String,
+ func: UserDefinedAggregateFunction): UserDefinedAggregateFunction = {
+ def builder(children: Seq[Expression]) = ScalaUDAF(children, func)
+ functionRegistry.registerFunction(name, builder)
+ func
+ }
+}
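Registration wires a UDAF into the ordinary function registry, so it becomes callable from SQL by name. A usage sketch built on the MyDoubleSum test UDAF added by this patch (the table and column names are hypothetical):

// MyDoubleSum lives in test.org.apache.spark.sql.hive.aggregate.
sqlContext.udaf.register("mydoublesum", new MyDoubleSum)
sqlContext.sql("SELECT mydoublesum(value) FROM data GROUP BY key")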
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
index 3cd60a2aa5..c2c945321d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
@@ -68,14 +68,14 @@ case class Aggregate(
* output.
*/
case class ComputedAggregate(
- unbound: AggregateExpression,
- aggregate: AggregateExpression,
+ unbound: AggregateExpression1,
+ aggregate: AggregateExpression1,
resultAttribute: AttributeReference)
/** A list of aggregates that need to be computed for each group. */
private[this] val computedAggregates = aggregateExpressions.flatMap { agg =>
agg.collect {
- case a: AggregateExpression =>
+ case a: AggregateExpression1 =>
ComputedAggregate(
a,
BindReferences.bindReference(a, child.output),
@@ -87,8 +87,8 @@ case class Aggregate(
private[this] val computedSchema = computedAggregates.map(_.resultAttribute)
/** Creates a new aggregate buffer for a group. */
- private[this] def newAggregateBuffer(): Array[AggregateFunction] = {
- val buffer = new Array[AggregateFunction](computedAggregates.length)
+ private[this] def newAggregateBuffer(): Array[AggregateFunction1] = {
+ val buffer = new Array[AggregateFunction1](computedAggregates.length)
var i = 0
while (i < computedAggregates.length) {
buffer(i) = computedAggregates(i).aggregate.newInstance()
@@ -146,7 +146,7 @@ case class Aggregate(
}
} else {
child.execute().mapPartitions { iter =>
- val hashTable = new HashMap[InternalRow, Array[AggregateFunction]]
+ val hashTable = new HashMap[InternalRow, Array[AggregateFunction1]]
val groupingProjection = new InterpretedMutableProjection(groupingExpressions, child.output)
var currentRow: InternalRow = null
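
The renamed `newAggregateBuffer` above captures the old code path's core pattern: one fresh, stateful `AggregateFunction1` instance per aggregate per group. A self-contained sketch of that pattern with simplified stand-in types:

```scala
// Simplified model: each group gets fresh function instances, analogous to
// calling aggregate.newInstance() per group. AggFunc1/SumFunc are stand-ins.
trait AggFunc1 { def update(value: Int): Unit; def eval: Int }

class SumFunc extends AggFunc1 {
  private var sum = 0
  def update(value: Int): Unit = sum += value
  def eval: Int = sum
}

// One factory per computed aggregate, invoked once for every group.
val factories: Seq[() => AggFunc1] = Seq(() => new SumFunc)
def newAggregateBuffer(): Array[AggFunc1] = factories.map(f => f()).toArray

val buffer = newAggregateBuffer()
Seq(1, 2, 3).foreach(buffer(0).update)
buffer(0).eval  // 6
```
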
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 2750053594..d31e265a29 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -247,8 +247,15 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
}
def addSortIfNecessary(child: SparkPlan): SparkPlan = {
- if (rowOrdering.nonEmpty && child.outputOrdering != rowOrdering) {
- sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child)
+
+ if (rowOrdering.nonEmpty) {
+ // If child.outputOrdering is [a, b] and rowOrdering is [a], we do not need to sort.
+ val minSize = Seq(rowOrdering.size, child.outputOrdering.size).min
+ if (minSize == 0 || rowOrdering.take(minSize) != child.outputOrdering.take(minSize)) {
+ sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child)
+ } else {
+ child
+ }
} else {
child
}
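
The new check skips the sort when the child's ordering already begins with the required ordering. The same prefix test in isolation, with plain values standing in for `SortOrder`:

```scala
// Mirrors addSortIfNecessary: sort only if the existing ordering does not
// start with the required one.
def needsSort[A](required: Seq[A], existing: Seq[A]): Boolean = {
  val minSize = math.min(required.size, existing.size)
  minSize == 0 || required.take(minSize) != existing.take(minSize)
}

needsSort(Seq("a"), Seq("a", "b"))  // false: [a, b] already satisfies [a]
needsSort(Seq("a"), Seq.empty)      // true: the child has no ordering
needsSort(Seq("a"), Seq("b", "a"))  // true: the prefixes differ
```
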
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index ecde9c5713..0e63f2fe29 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -69,7 +69,7 @@ case class GeneratedAggregate(
protected override def doExecute(): RDD[InternalRow] = {
val aggregatesToCompute = aggregateExpressions.flatMap { a =>
- a.collect { case agg: AggregateExpression => agg}
+ a.collect { case agg: AggregateExpression1 => agg }
}
// If you add any new function support, please add tests in org.apache.spark.sql.SQLQuerySuite
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 8cef7f200d..f54aa2027f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.{SQLContext, Strategy, execution}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan}
@@ -148,7 +149,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
if canBeCodeGened(
allAggregates(partialComputation) ++
allAggregates(rewrittenAggregateExpressions)) &&
- codegenEnabled =>
+ codegenEnabled &&
+ !canBeConvertedToNewAggregation(plan) =>
execution.GeneratedAggregate(
partial = false,
namedGroupingAttributes,
@@ -167,7 +169,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
rewrittenAggregateExpressions,
groupingExpressions,
partialComputation,
- child) =>
+ child) if !canBeConvertedToNewAggregation(plan) =>
execution.Aggregate(
partial = false,
namedGroupingAttributes,
@@ -181,7 +183,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case _ => Nil
}
- def canBeCodeGened(aggs: Seq[AggregateExpression]): Boolean = !aggs.exists {
+ def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = {
+ aggregate.Utils.tryConvert(
+ plan,
+ sqlContext.conf.useSqlAggregate2,
+ sqlContext.conf.codegenEnabled).isDefined
+ }
+
+ def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = !aggs.exists {
case _: CombineSum | _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => false
// The generated set implementation is pretty limited ATM.
case CollectHashSet(exprs) if exprs.size == 1 &&
@@ -189,10 +198,74 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case _ => true
}
- def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression] =
- exprs.flatMap(_.collect { case a: AggregateExpression => a })
+ def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression1] =
+ exprs.flatMap(_.collect { case a: AggregateExpression1 => a })
}
+ /**
+ * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface.
+ */
+ object Aggregation extends Strategy {
+ def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case p: logical.Aggregate =>
+ val converted =
+ aggregate.Utils.tryConvert(
+ p,
+ sqlContext.conf.useSqlAggregate2,
+ sqlContext.conf.codegenEnabled)
+ converted match {
+ case None => Nil // Cannot convert to new aggregation code path.
+ case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) =>
+ // Extract all aggregate expressions (de-duplicated) from the resultExpressions.
+ val aggregateExpressions = resultExpressions.flatMap { expr =>
+ expr.collect {
+ case agg: AggregateExpression2 => agg
+ }
+ }.toSet.toSeq
+ // For those distinct aggregate expressions, we create a map from the
+ // aggregate function to the corresponding attribute of the function.
+ val aggregateFunctionMap = aggregateExpressions.map { agg =>
+ val aggregateFunction = agg.aggregateFunction
+ (aggregateFunction, agg.isDistinct) ->
+ Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
+ }.toMap
+
+ val (functionsWithDistinct, functionsWithoutDistinct) =
+ aggregateExpressions.partition(_.isDistinct)
+ if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
+ // This is a sanity check. We should not reach here when we have multiple distinct
+ // column sets (aggregate.NewAggregation will not match).
+ sys.error(
+ "Multiple distinct column sets are not supported by the new aggregation" +
+ "code path.")
+ }
+
+ val aggregateOperator =
+ if (functionsWithDistinct.isEmpty) {
+ aggregate.Utils.planAggregateWithoutDistinct(
+ groupingExpressions,
+ aggregateExpressions,
+ aggregateFunctionMap,
+ resultExpressions,
+ planLater(child))
+ } else {
+ aggregate.Utils.planAggregateWithOneDistinct(
+ groupingExpressions,
+ functionsWithDistinct,
+ functionsWithoutDistinct,
+ aggregateFunctionMap,
+ resultExpressions,
+ planLater(child))
+ }
+
+ aggregateOperator
+ }
+
+ case _ => Nil
+ }
+ }
+
+
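
The sanity check in the strategy above enforces the single-distinct-column-set limitation. The same check in isolation, with a hypothetical `Agg` standing in for `AggregateExpression2`:

```scala
// Agg is a stand-in: an aggregate function's argument list plus its
// DISTINCT flag. Not a Spark class.
case class Agg(children: Seq[String], isDistinct: Boolean)

val aggs = Seq(
  Agg(Seq("a"), isDistinct = true),   // e.g. COUNT(DISTINCT a)
  Agg(Seq("a"), isDistinct = true),   // e.g. SUM(DISTINCT a): same column set
  Agg(Seq("b"), isDistinct = false))  // e.g. MAX(b)

val (functionsWithDistinct, _) = aggs.partition(_.isDistinct)

// At most one distinct column set can be planned by the new code path.
val supported = functionsWithDistinct.map(_.children).distinct.length <= 1  // true
```
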
object BroadcastNestedLoopJoin extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Join(left, right, joinType, condition) =>
@@ -336,8 +409,21 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.Filter(condition, planLater(child)) :: Nil
case e @ logical.Expand(_, _, _, child) =>
execution.Expand(e.projections, e.output, planLater(child)) :: Nil
- case logical.Aggregate(group, agg, child) =>
- execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
+ case a @ logical.Aggregate(group, agg, child) => {
+ val useNewAggregation =
+ aggregate.Utils.tryConvert(
+ a,
+ sqlContext.conf.useSqlAggregate2,
+ sqlContext.conf.codegenEnabled).isDefined
+ if (useNewAggregation) {
+ // If this logical.Aggregate can be planned to use new aggregation code path
+ // (i.e. it can be planned by the Strategy Aggregation), we will not use the old
+ // aggregation code path.
+ Nil
+ } else {
+ execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
+ }
+ }
case logical.Window(projectList, windowExpressions, spec, child) =>
execution.Window(projectList, windowExpressions, spec, planLater(child)) :: Nil
case logical.Sample(lb, ub, withReplacement, seed, child) =>
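
Both `HashAggregation` and the old `BasicOperators` rule now consult `aggregate.Utils.tryConvert` before planning, so exactly one code path claims each `logical.Aggregate`. The gating reduces to a conjunction; a sketch with hypothetical names:

```scala
// The new path is taken only when the flag is on, codegen is enabled, and
// the plan is convertible; otherwise the old operators plan the query.
def useNewAggregation(
    useSqlAggregate2: Boolean,
    codegenEnabled: Boolean,
    convertible: Boolean): Boolean =
  useSqlAggregate2 && codegenEnabled && convertible

useNewAggregation(useSqlAggregate2 = true, codegenEnabled = true, convertible = true)   // true
useNewAggregation(useSqlAggregate2 = true, codegenEnabled = false, convertible = true)  // false
```
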
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala
new file mode 100644
index 0000000000..0c9082897f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution}
+import org.apache.spark.sql.execution.{SparkPlan, UnaryNode}
+
+case class Aggregate2Sort(
+ requiredChildDistributionExpressions: Option[Seq[Expression]],
+ groupingExpressions: Seq[NamedExpression],
+ aggregateExpressions: Seq[AggregateExpression2],
+ aggregateAttributes: Seq[Attribute],
+ resultExpressions: Seq[NamedExpression],
+ child: SparkPlan)
+ extends UnaryNode {
+
+ override def canProcessUnsafeRows: Boolean = true
+
+ override def references: AttributeSet = {
+ val referencesInResults =
+ AttributeSet(resultExpressions.flatMap(_.references)) -- AttributeSet(aggregateAttributes)
+
+ AttributeSet(
+ groupingExpressions.flatMap(_.references) ++
+ aggregateExpressions.flatMap(_.references) ++
+ referencesInResults)
+ }
+
+ override def requiredChildDistribution: List[Distribution] = {
+ requiredChildDistributionExpressions match {
+ case Some(exprs) if exprs.length == 0 => AllTuples :: Nil
+ case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil
+ case None => UnspecifiedDistribution :: Nil
+ }
+ }
+
+ override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
+ // TODO: We should not sort the input rows if they are just in reversed order.
+ groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
+ }
+
+ override def outputOrdering: Seq[SortOrder] = {
+ // It is possible that the child.outputOrdering starts with the required
+ // ordering expressions (e.g. we require [a] as the sort expression and the
+ // child's outputOrdering is [a, b]). We can only guarantee the output rows
+ // are sorted by values of groupingExpressions.
+ groupingExpressions.map(SortOrder(_, Ascending))
+ }
+
+ override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+ protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
+ child.execute().mapPartitions { iter =>
+ if (aggregateExpressions.length == 0) {
+ new GroupingIterator(
+ groupingExpressions,
+ resultExpressions,
+ newMutableProjection,
+ child.output,
+ iter)
+ } else {
+ val aggregationIterator: SortAggregationIterator = {
+ aggregateExpressions.map(_.mode).distinct.toList match {
+ case Partial :: Nil =>
+ new PartialSortAggregationIterator(
+ groupingExpressions,
+ aggregateExpressions,
+ newMutableProjection,
+ child.output,
+ iter)
+ case PartialMerge :: Nil =>
+ new PartialMergeSortAggregationIterator(
+ groupingExpressions,
+ aggregateExpressions,
+ newMutableProjection,
+ child.output,
+ iter)
+ case Final :: Nil =>
+ new FinalSortAggregationIterator(
+ groupingExpressions,
+ aggregateExpressions,
+ aggregateAttributes,
+ resultExpressions,
+ newMutableProjection,
+ child.output,
+ iter)
+ case other =>
+ sys.error(
+ s"Could not evaluate ${aggregateExpressions} because we do not support evaluate " +
+ s"modes $other in this operator.")
+ }
+ }
+
+ aggregationIterator
+ }
+ }
+ }
+}
+
+case class FinalAndCompleteAggregate2Sort(
+ previousGroupingExpressions: Seq[NamedExpression],
+ groupingExpressions: Seq[NamedExpression],
+ finalAggregateExpressions: Seq[AggregateExpression2],
+ finalAggregateAttributes: Seq[Attribute],
+ completeAggregateExpressions: Seq[AggregateExpression2],
+ completeAggregateAttributes: Seq[Attribute],
+ resultExpressions: Seq[NamedExpression],
+ child: SparkPlan)
+ extends UnaryNode {
+ override def references: AttributeSet = {
+ val referencesInResults =
+ AttributeSet(resultExpressions.flatMap(_.references)) --
+ AttributeSet(finalAggregateExpressions) --
+ AttributeSet(completeAggregateExpressions)
+
+ AttributeSet(
+ groupingExpressions.flatMap(_.references) ++
+ finalAggregateExpressions.flatMap(_.references) ++
+ completeAggregateExpressions.flatMap(_.references) ++
+ referencesInResults)
+ }
+
+ override def requiredChildDistribution: List[Distribution] = {
+ if (groupingExpressions.isEmpty) {
+ AllTuples :: Nil
+ } else {
+ ClusteredDistribution(groupingExpressions) :: Nil
+ }
+ }
+
+ override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+ groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
+
+ override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+ protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
+ child.execute().mapPartitions { iter =>
+
+ new FinalAndCompleteSortAggregationIterator(
+ previousGroupingExpressions.length,
+ groupingExpressions,
+ finalAggregateExpressions,
+ finalAggregateAttributes,
+ completeAggregateExpressions,
+ completeAggregateAttributes,
+ resultExpressions,
+ newMutableProjection,
+ child.output,
+ iter)
+ }
+ }
+
+}
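
`Aggregate2Sort` picks its iterator from the single aggregation mode its expressions carry, failing on mixed or unsupported modes. The dispatch in miniature:

```scala
// Mode mirrors the aggregate modes used above; the Strings stand in for the
// iterator classes this operator instantiates.
sealed trait Mode
case object Partial extends Mode
case object PartialMerge extends Mode
case object Final extends Mode

def chooseIterator(modes: Seq[Mode]): String = modes.distinct.toList match {
  case Partial :: Nil      => "PartialSortAggregationIterator"
  case PartialMerge :: Nil => "PartialMergeSortAggregationIterator"
  case Final :: Nil        => "FinalSortAggregationIterator"
  case other               => sys.error(s"Unsupported evaluation modes $other")
}

chooseIterator(Seq(Partial, Partial))  // "PartialSortAggregationIterator"
```
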
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
new file mode 100644
index 0000000000..ce1cbdc9cb
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
@@ -0,0 +1,749 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.types.NullType
+
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * An iterator used to evaluate aggregate functions. It assumes that input rows
+ * are already grouped by values of `groupingExpressions`.
+ */
+private[sql] abstract class SortAggregationIterator(
+ groupingExpressions: Seq[NamedExpression],
+ aggregateExpressions: Seq[AggregateExpression2],
+ newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+ inputAttributes: Seq[Attribute],
+ inputIter: Iterator[InternalRow])
+ extends Iterator[InternalRow] {
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Static fields for this iterator
+ ///////////////////////////////////////////////////////////////////////////
+
+ protected val aggregateFunctions: Array[AggregateFunction2] = {
+ var bufferOffset = initialBufferOffset
+ val functions = new Array[AggregateFunction2](aggregateExpressions.length)
+ var i = 0
+ while (i < aggregateExpressions.length) {
+ val func = aggregateExpressions(i).aggregateFunction
+ val funcWithBoundReferences = aggregateExpressions(i).mode match {
+ case Partial | Complete if !func.isInstanceOf[AlgebraicAggregate] =>
+ // We need to create BoundReferences if the function is not an
+ // AlgebraicAggregate (it does not support code-gen) and the mode of
+ // this function is Partial or Complete because we will call eval of this
+ // function's children in the update method of this aggregate function.
+ // Those eval calls require BoundReferences to work.
+ BindReferences.bindReference(func, inputAttributes)
+ case _ => func
+ }
+ // Set bufferOffset for this function. It is important that setting bufferOffset
+ // happens after all potential bindReference operations because bindReference
+ // will create a new instance of the function.
+ funcWithBoundReferences.bufferOffset = bufferOffset
+ bufferOffset += funcWithBoundReferences.bufferSchema.length
+ functions(i) = funcWithBoundReferences
+ i += 1
+ }
+ functions
+ }
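
The loop above lays all function buffers out in one shared row by handing each function a consecutive `bufferOffset`. The bookkeeping in isolation, with a hypothetical `Slot` type:

```scala
// Slot is a stand-in for an aggregate function: a name plus its buffer length.
case class Slot(name: String, bufferLength: Int)

def assignOffsets(initialOffset: Int, slots: Seq[Slot]): Seq[(String, Int)] = {
  var offset = initialOffset
  slots.map { s =>
    val assigned = s.name -> offset
    offset += s.bufferLength
    assigned
  }
}

// With an initial offset of 2 (e.g. two grouping-key placeholders):
assignOffsets(2, Seq(Slot("avg", 2), Slot("count", 1)))
// Seq(("avg", 2), ("count", 4))
```
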
+
+ // All non-algebraic aggregate functions.
+ protected val nonAlgebraicAggregateFunctions: Array[AggregateFunction2] = {
+ aggregateFunctions.collect {
+ case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
+ }.toArray
+ }
+
+ // Positions of those non-algebraic aggregate functions in aggregateFunctions.
+ // For example, if aggregateFunctions contains func1, func2, func3, and func4,
+ // and func2 and func3 are non-algebraic aggregate functions, then
+ // nonAlgebraicAggregateFunctionPositions will be [1, 2].
+ protected val nonAlgebraicAggregateFunctionPositions: Array[Int] = {
+ val positions = new ArrayBuffer[Int]()
+ var i = 0
+ while (i < aggregateFunctions.length) {
+ aggregateFunctions(i) match {
+ case agg: AlgebraicAggregate =>
+ case _ => positions += i
+ }
+ i += 1
+ }
+ positions.toArray
+ }
+
+ // This projection is used to compute the values of the grouping expressions.
+ protected val groupGenerator =
+ newMutableProjection(groupingExpressions, inputAttributes)()
+
+ // The underlying buffer shared by all aggregate functions.
+ protected val buffer: MutableRow = {
+ // The number of elements in the underlying buffer of this operator.
+ // All aggregate functions share this underlying buffer and find their
+ // buffer values through bufferOffset.
+ var size = initialBufferOffset
+ var i = 0
+ while (i < aggregateFunctions.length) {
+ size += aggregateFunctions(i).bufferSchema.length
+ i += 1
+ }
+ new GenericMutableRow(size)
+ }
+
+ protected val joinedRow = new JoinedRow4
+
+ protected val placeholderExpressions = Seq.fill(initialBufferOffset)(NoOp)
+
+ // This projection is used to initialize buffer values for all AlgebraicAggregates.
+ protected val algebraicInitialProjection = {
+ val initExpressions = placeholderExpressions ++ aggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.initialValues
+ case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ }
+ newMutableProjection(initExpressions, Nil)().target(buffer)
+ }
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Mutable states
+ ///////////////////////////////////////////////////////////////////////////
+
+ // The grouping key of the current group.
+ protected var currentGroupingKey: InternalRow = _
+ // The grouping key of the next group.
+ protected var nextGroupingKey: InternalRow = _
+ // The first row of the next group.
+ protected var firstRowInNextGroup: InternalRow = _
+ // Indicates if we have a new group of rows to process.
+ protected var hasNewGroup: Boolean = true
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Private methods
+ ///////////////////////////////////////////////////////////////////////////
+
+ /** Initializes buffer values for all aggregate functions. */
+ protected def initializeBuffer(): Unit = {
+ algebraicInitialProjection(EmptyRow)
+ var i = 0
+ while (i < nonAlgebraicAggregateFunctions.length) {
+ nonAlgebraicAggregateFunctions(i).initialize(buffer)
+ i += 1
+ }
+ }
+
+ protected def initialize(): Unit = {
+ if (inputIter.hasNext) {
+ initializeBuffer()
+ val currentRow = inputIter.next().copy()
+ // groupGenerator is a mutable projection. Since we need to track nextGroupingKey,
+ // we make a copy of the grouping key here.
+ nextGroupingKey = groupGenerator(currentRow).copy()
+ firstRowInNextGroup = currentRow
+ } else {
+ // The input iterator is empty.
+ hasNewGroup = false
+ }
+ }
+
+ /** Processes rows in the current group. It stops when it finds a new group. */
+ private def processCurrentGroup(): Unit = {
+ currentGroupingKey = nextGroupingKey
+ // Now, we will start to find all rows belonging to this group.
+ // We create a variable to track if we see the next group.
+ var findNextPartition = false
+ // firstRowInNextGroup is the first row of this group. We first process it.
+ processRow(firstRowInNextGroup)
+ // The search will stop when we see the next group or there is no
+ // input row left in the iterator.
+ while (inputIter.hasNext && !findNextPartition) {
+ val currentRow = inputIter.next()
+ // Get the grouping key based on the grouping expressions.
+ // For the below compare method, we do not need to make a copy of groupingKey.
+ val groupingKey = groupGenerator(currentRow)
+ // Check if the current row belongs to the current group.
+ if (currentGroupingKey == groupingKey) {
+ processRow(currentRow)
+ } else {
+ // We find a new group.
+ findNextPartition = true
+ nextGroupingKey = groupingKey.copy()
+ firstRowInNextGroup = currentRow.copy()
+ }
+ }
+ // If we have not seen a new group, there are no rows left in the input
+ // iterator and the current group is the last group of the iterator.
+ if (!findNextPartition) {
+ hasNewGroup = false
+ }
+ }
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Public methods
+ ///////////////////////////////////////////////////////////////////////////
+
+ override final def hasNext: Boolean = hasNewGroup
+
+ override final def next(): InternalRow = {
+ if (hasNext) {
+ // Process the current group.
+ processCurrentGroup()
+ // Generate output row for the current group.
+ val outputRow = generateOutput()
+ // Initialize buffer values for the next group.
+ initializeBuffer()
+
+ outputRow
+ } else {
+ // no more results
+ throw new NoSuchElementException
+ }
+ }
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Methods that need to be implemented
+ ///////////////////////////////////////////////////////////////////////////
+
+ protected def initialBufferOffset: Int
+
+ protected def processRow(row: InternalRow): Unit
+
+ protected def generateOutput(): InternalRow
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Initialize this iterator
+ ///////////////////////////////////////////////////////////////////////////
+
+ initialize()
+}
+
+/**
+ * An iterator only used to group input rows according to values of `groupingExpressions`.
+ * It assumes that input rows are already grouped by values of `groupingExpressions`.
+ */
+class GroupingIterator(
+ groupingExpressions: Seq[NamedExpression],
+ resultExpressions: Seq[NamedExpression],
+ newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+ inputAttributes: Seq[Attribute],
+ inputIter: Iterator[InternalRow])
+ extends SortAggregationIterator(
+ groupingExpressions,
+ Nil,
+ newMutableProjection,
+ inputAttributes,
+ inputIter) {
+
+ private val resultProjection =
+ newMutableProjection(resultExpressions, groupingExpressions.map(_.toAttribute))()
+
+ override protected def initialBufferOffset: Int = 0
+
+ override protected def processRow(row: InternalRow): Unit = {
+ // Since we only do grouping, there is nothing to do here.
+ }
+
+ override protected def generateOutput(): InternalRow = {
+ resultProjection(currentGroupingKey)
+ }
+}
+
+/**
+ * An iterator used to do partial aggregations (for those aggregate functions with mode Partial).
+ * It assumes that input rows are already grouped by values of `groupingExpressions`.
+ * The format of its output rows is:
+ * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
+ */
+class PartialSortAggregationIterator(
+ groupingExpressions: Seq[NamedExpression],
+ aggregateExpressions: Seq[AggregateExpression2],
+ newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+ inputAttributes: Seq[Attribute],
+ inputIter: Iterator[InternalRow])
+ extends SortAggregationIterator(
+ groupingExpressions,
+ aggregateExpressions,
+ newMutableProjection,
+ inputAttributes,
+ inputIter) {
+
+ // This projection is used to update buffer values for all AlgebraicAggregates.
+ private val algebraicUpdateProjection = {
+ val bufferSchema = aggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.bufferAttributes
+ case agg: AggregateFunction2 => agg.bufferAttributes
+ }
+ val updateExpressions = aggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.updateExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ }
+ newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)().target(buffer)
+ }
+
+ override protected def initialBufferOffset: Int = 0
+
+ override protected def processRow(row: InternalRow): Unit = {
+ // Process all algebraic aggregate functions.
+ algebraicUpdateProjection(joinedRow(buffer, row))
+ // Process all non-algebraic aggregate functions.
+ var i = 0
+ while (i < nonAlgebraicAggregateFunctions.length) {
+ nonAlgebraicAggregateFunctions(i).update(buffer, row)
+ i += 1
+ }
+ }
+
+ override protected def generateOutput(): InternalRow = {
+ // We just output the grouping expressions and the underlying buffer.
+ joinedRow(currentGroupingKey, buffer).copy()
+ }
+}
+
+/**
+ * An iterator used to do partial merge aggregations (for those aggregate functions with mode
+ * PartialMerge). It assumes that input rows are already grouped by values of
+ * `groupingExpressions`.
+ * The format of its input rows is:
+ * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
+ *
+ * The format of its internal buffer is:
+ * |placeholder1|...|placeholderN|aggregationBuffer1|...|aggregationBufferN|
+ * Every placeholder is for a grouping expression.
+ * The actual buffers are stored after placeholderN.
+ * The reason we have placeholders here is to make the underlying buffer have the same
+ * length as an input row.
+ *
+ * The format of its output rows is:
+ * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
+ */
+class PartialMergeSortAggregationIterator(
+ groupingExpressions: Seq[NamedExpression],
+ aggregateExpressions: Seq[AggregateExpression2],
+ newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+ inputAttributes: Seq[Attribute],
+ inputIter: Iterator[InternalRow])
+ extends SortAggregationIterator(
+ groupingExpressions,
+ aggregateExpressions,
+ newMutableProjection,
+ inputAttributes,
+ inputIter) {
+
+ private val placeholderAttributes =
+ Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)())
+
+ // This projection is used to merge buffer values for all AlgebraicAggregates.
+ private val algebraicMergeProjection = {
+ val bufferSchemata =
+ placeholderAttributes ++ aggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.bufferAttributes
+ case agg: AggregateFunction2 => agg.bufferAttributes
+ } ++ placeholderAttributes ++ aggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.cloneBufferAttributes
+ case agg: AggregateFunction2 => agg.cloneBufferAttributes
+ }
+ val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.mergeExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ }
+
+ newMutableProjection(mergeExpressions, bufferSchemata)()
+ }
+
+ // This projection is used to extract aggregation buffers from the underlying buffer.
+ // We need it because the underlying buffer has placeholders at its beginning.
+ private val extractsBufferValues = {
+ val expressions = aggregateFunctions.flatMap {
+ case agg => agg.bufferAttributes
+ }
+
+ newMutableProjection(expressions, inputAttributes)()
+ }
+
+ override protected def initialBufferOffset: Int = groupingExpressions.length
+
+ override protected def processRow(row: InternalRow): Unit = {
+ // Process all algebraic aggregate functions.
+ algebraicMergeProjection.target(buffer)(joinedRow(buffer, row))
+ // Process all non-algebraic aggregate functions.
+ var i = 0
+ while (i < nonAlgebraicAggregateFunctions.length) {
+ nonAlgebraicAggregateFunctions(i).merge(buffer, row)
+ i += 1
+ }
+ }
+
+ override protected def generateOutput(): InternalRow = {
+ // We output grouping expressions and aggregation buffers.
+ joinedRow(currentGroupingKey, extractsBufferValues(buffer))
+ }
+}
+
+/**
+ * An iterator used to do final aggregations (for those aggregate functions with mode
+ * Final). It assumes that input rows are already grouped by values of
+ * `groupingExpressions`.
+ * The format of its input rows is:
+ * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
+ *
+ * The format of its internal buffer is:
+ * |placeholder1|...|placeholderN|aggregationBuffer1|...|aggregationBufferN|
+ * Every placeholder is for a grouping expression.
+ * The actual buffers are stored after placeholderN.
+ * The reason we have placeholders here is to make the underlying buffer have the same
+ * length as an input row.
+ *
+ * The format of its output rows is represented by the schema of `resultExpressions`.
+ */
+class FinalSortAggregationIterator(
+ groupingExpressions: Seq[NamedExpression],
+ aggregateExpressions: Seq[AggregateExpression2],
+ aggregateAttributes: Seq[Attribute],
+ resultExpressions: Seq[NamedExpression],
+ newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+ inputAttributes: Seq[Attribute],
+ inputIter: Iterator[InternalRow])
+ extends SortAggregationIterator(
+ groupingExpressions,
+ aggregateExpressions,
+ newMutableProjection,
+ inputAttributes,
+ inputIter) {
+
+ // The result of aggregate functions.
+ private val aggregateResult: MutableRow = new GenericMutableRow(aggregateAttributes.length)
+
+ // The projection used to generate the output rows of this operator.
+ // This is only used when we are generating final results of aggregate functions.
+ private val resultProjection =
+ newMutableProjection(
+ resultExpressions, groupingExpressions.map(_.toAttribute) ++ aggregateAttributes)()
+
+ private val offsetAttributes =
+ Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)())
+
+ // This projection is used to merge buffer values for all AlgebraicAggregates.
+ private val algebraicMergeProjection = {
+ val bufferSchemata =
+ offsetAttributes ++ aggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.bufferAttributes
+ case agg: AggregateFunction2 => agg.bufferAttributes
+ } ++ offsetAttributes ++ aggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.cloneBufferAttributes
+ case agg: AggregateFunction2 => agg.cloneBufferAttributes
+ }
+ val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.mergeExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ }
+
+ newMutableProjection(mergeExpressions, bufferSchemata)()
+ }
+
+ // This projection is used to evaluate all AlgebraicAggregates.
+ private val algebraicEvalProjection = {
+ val bufferSchemata =
+ offsetAttributes ++ aggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.bufferAttributes
+ case agg: AggregateFunction2 => agg.bufferAttributes
+ } ++ offsetAttributes ++ aggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.cloneBufferAttributes
+ case agg: AggregateFunction2 => agg.cloneBufferAttributes
+ }
+ val evalExpressions = aggregateFunctions.map {
+ case ae: AlgebraicAggregate => ae.evaluateExpression
+ case agg: AggregateFunction2 => NoOp
+ }
+
+ newMutableProjection(evalExpressions, bufferSchemata)()
+ }
+
+ override protected def initialBufferOffset: Int = groupingExpressions.length
+
+ override def initialize(): Unit = {
+ if (inputIter.hasNext) {
+ initializeBuffer()
+ val currentRow = inputIter.next().copy()
+ // groupGenerator is a mutable projection. Since we need to track nextGroupingKey,
+ // we make a copy of the grouping key here.
+ nextGroupingKey = groupGenerator(currentRow).copy()
+ firstRowInNextGroup = currentRow
+ } else {
+ if (groupingExpressions.isEmpty) {
+ // If there is no grouping expression, we need to generate a single row as the output.
+ initializeBuffer()
+ // Right now, the buffer only contains initial buffer values. Because merging
+ // two buffers with initial values still produces a row holding initial values,
+ // we can safely use a copy of the current buffer as the current row.
+ val currentRow = buffer.copy()
+ nextGroupingKey = groupGenerator(currentRow).copy()
+ firstRowInNextGroup = currentRow
+ } else {
+ // The input iterator is empty.
+ hasNewGroup = false
+ }
+ }
+ }
+
+ override protected def processRow(row: InternalRow): Unit = {
+ // Process all algebraic aggregate functions.
+ algebraicMergeProjection.target(buffer)(joinedRow(buffer, row))
+ // Process all non-algebraic aggregate functions.
+ var i = 0
+ while (i < nonAlgebraicAggregateFunctions.length) {
+ nonAlgebraicAggregateFunctions(i).merge(buffer, row)
+ i += 1
+ }
+ }
+
+ override protected def generateOutput(): InternalRow = {
+ // Generate results for all algebraic aggregate functions.
+ algebraicEvalProjection.target(aggregateResult)(buffer)
+ // Generate results for all non-algebraic aggregate functions.
+ var i = 0
+ while (i < nonAlgebraicAggregateFunctions.length) {
+ aggregateResult.update(
+ nonAlgebraicAggregateFunctionPositions(i),
+ nonAlgebraicAggregateFunctions(i).eval(buffer))
+ i += 1
+ }
+ resultProjection(joinedRow(currentGroupingKey, aggregateResult))
+ }
+}
+
+/**
+ * An iterator used to do both final aggregations (for those aggregate functions with mode
+ * Final) and complete aggregations (for those aggregate functions with mode Complete).
+ * It assumes that input rows are already grouped by values of `groupingExpressions`.
+ * The format of its input rows is:
+ * |groupingExpr1|...|groupingExprN|col1|...|colM|aggregationBuffer1|...|aggregationBufferN|
+ * col1 to colM are columns used by aggregate functions with Complete mode.
+ * aggregationBuffer1 to aggregationBufferN are buffers used by aggregate functions with
+ * Final mode.
+ *
+ * The format of its internal buffer is:
+ * |placeholder1|...|placeholder(N+M)|aggregationBuffer1|...|aggregationBuffer(N+M)|
+ * The first N placeholders represent slots for the grouping expressions.
+ * The next M placeholders represent slots for col1 to colM.
+ * For the aggregation buffers, the first N are used by the N aggregate functions with
+ * mode Final, and the last M are used by the M aggregate functions with mode
+ * Complete. The reason we have placeholders here is to make the underlying buffer
+ * have the same length as an input row.
+ *
+ * The format of its output rows is represented by the schema of `resultExpressions`.
+ */
+class FinalAndCompleteSortAggregationIterator(
+ override protected val initialBufferOffset: Int,
+ groupingExpressions: Seq[NamedExpression],
+ finalAggregateExpressions: Seq[AggregateExpression2],
+ finalAggregateAttributes: Seq[Attribute],
+ completeAggregateExpressions: Seq[AggregateExpression2],
+ completeAggregateAttributes: Seq[Attribute],
+ resultExpressions: Seq[NamedExpression],
+ newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+ inputAttributes: Seq[Attribute],
+ inputIter: Iterator[InternalRow])
+ extends SortAggregationIterator(
+ groupingExpressions,
+ // TODO: document the ordering
+ finalAggregateExpressions ++ completeAggregateExpressions,
+ newMutableProjection,
+ inputAttributes,
+ inputIter) {
+
+ // The result of aggregate functions.
+ private val aggregateResult: MutableRow =
+ new GenericMutableRow(completeAggregateAttributes.length + finalAggregateAttributes.length)
+
+ // The projection used to generate the output rows of this operator.
+ // This is only used when we are generating final results of aggregate functions.
+ private val resultProjection = {
+ val inputSchema =
+ groupingExpressions.map(_.toAttribute) ++
+ finalAggregateAttributes ++
+ completeAggregateAttributes
+ newMutableProjection(resultExpressions, inputSchema)()
+ }
+
+ private val offsetAttributes =
+ Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)())
+
+ // All aggregate functions with mode Final.
+ private val finalAggregateFunctions: Array[AggregateFunction2] = {
+ val functions = new Array[AggregateFunction2](finalAggregateExpressions.length)
+ var i = 0
+ while (i < finalAggregateExpressions.length) {
+ functions(i) = aggregateFunctions(i)
+ i += 1
+ }
+ functions
+ }
+
+ // All non-algebraic aggregate functions with mode Final.
+ private val finalNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = {
+ finalAggregateFunctions.collect {
+ case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
+ }.toArray
+ }
+
+ // All aggregate functions with mode Complete.
+ private val completeAggregateFunctions: Array[AggregateFunction2] = {
+ val functions = new Array[AggregateFunction2](completeAggregateExpressions.length)
+ var i = 0
+ while (i < completeAggregateExpressions.length) {
+ functions(i) = aggregateFunctions(finalAggregateFunctions.length + i)
+ i += 1
+ }
+ functions
+ }
+
+ // All non-algebraic aggregate functions with mode Complete.
+ private val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = {
+ completeAggregateFunctions.collect {
+ case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
+ }.toArray
+ }
+
+ // This projection is used to merge buffer values for all AlgebraicAggregates with mode
+ // Final.
+ private val finalAlgebraicMergeProjection = {
+ val numCompleteOffsetAttributes =
+ completeAggregateFunctions.map(_.bufferAttributes.length).sum
+ val completeOffsetAttributes =
+ Seq.fill(numCompleteOffsetAttributes)(AttributeReference("placeholder", NullType)())
+ val completeOffsetExpressions = Seq.fill(numCompleteOffsetAttributes)(NoOp)
+
+ val bufferSchemata =
+ offsetAttributes ++ finalAggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.bufferAttributes
+ case agg: AggregateFunction2 => agg.bufferAttributes
+ } ++ completeOffsetAttributes ++ offsetAttributes ++ finalAggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.cloneBufferAttributes
+ case agg: AggregateFunction2 => agg.cloneBufferAttributes
+ } ++ completeOffsetAttributes
+ val mergeExpressions =
+ placeholderExpressions ++ finalAggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.mergeExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ } ++ completeOffsetExpressions
+
+ newMutableProjection(mergeExpressions, bufferSchemata)()
+ }
+
+ // This projection is used to update buffer values for all AlgebraicAggregates with mode
+ // Complete.
+ private val completeAlgebraicUpdateProjection = {
+ val numFinalOffsetAttributes = finalAggregateFunctions.map(_.bufferAttributes.length).sum
+ val finalOffsetAttributes =
+ Seq.fill(numFinalOffsetAttributes)(AttributeReference("placeholder", NullType)())
+ val finalOffsetExpressions = Seq.fill(numFinalOffsetAttributes)(NoOp)
+
+ val bufferSchema =
+ offsetAttributes ++ finalOffsetAttributes ++ completeAggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.bufferAttributes
+ case agg: AggregateFunction2 => agg.bufferAttributes
+ }
+ val updateExpressions =
+ placeholderExpressions ++ finalOffsetExpressions ++ completeAggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.updateExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ }
+ newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)().target(buffer)
+ }
+
+ // This projection is used to evaluate all AlgebraicAggregates.
+ private val algebraicEvalProjection = {
+ val bufferSchemata =
+ offsetAttributes ++ aggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.bufferAttributes
+ case agg: AggregateFunction2 => agg.bufferAttributes
+ } ++ offsetAttributes ++ aggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.cloneBufferAttributes
+ case agg: AggregateFunction2 => agg.cloneBufferAttributes
+ }
+ val evalExpressions = aggregateFunctions.map {
+ case ae: AlgebraicAggregate => ae.evaluateExpression
+ case agg: AggregateFunction2 => NoOp
+ }
+
+ newMutableProjection(evalExpressions, bufferSchemata)()
+ }
+
+ override def initialize(): Unit = {
+ if (inputIter.hasNext) {
+ initializeBuffer()
+ val currentRow = inputIter.next().copy()
+ // groupGenerator is a mutable projection. Since we need to track nextGroupingKey,
+ // we make a copy of the grouping key here.
+ nextGroupingKey = groupGenerator(currentRow).copy()
+ firstRowInNextGroup = currentRow
+ } else {
+ if (groupingExpressions.isEmpty) {
+ // If there is no grouping expression, we need to generate a single row as the output.
+ initializeBuffer()
+ // Right now, the buffer only contains initial buffer values. Because merging
+ // two buffers with initial values still produces a row holding initial values,
+ // we can safely use a copy of the current buffer as the current row.
+ val currentRow = buffer.copy()
+ nextGroupingKey = groupGenerator(currentRow).copy()
+ firstRowInNextGroup = currentRow
+ } else {
+ // The input iterator is empty.
+ hasNewGroup = false
+ }
+ }
+ }
+
+ override protected def processRow(row: InternalRow): Unit = {
+ val input = joinedRow(buffer, row)
+ // For all aggregate functions with mode Complete, update buffers.
+ completeAlgebraicUpdateProjection(input)
+ var i = 0
+ while (i < completeNonAlgebraicAggregateFunctions.length) {
+ completeNonAlgebraicAggregateFunctions(i).update(buffer, row)
+ i += 1
+ }
+
+ // For all aggregate functions with mode Final, merge buffers.
+ finalAlgebraicMergeProjection.target(buffer)(input)
+ i = 0
+ while (i < finalNonAlgebraicAggregateFunctions.length) {
+ finalNonAlgebraicAggregateFunctions(i).merge(buffer, row)
+ i += 1
+ }
+ }
+
+ override protected def generateOutput(): InternalRow = {
+ // Generate results for all algebraic aggregate functions.
+ algebraicEvalProjection.target(aggregateResult)(buffer)
+ // Generate results for all non-algebraic aggregate functions.
+ var i = 0
+ while (i < nonAlgebraicAggregateFunctions.length) {
+ aggregateResult.update(
+ nonAlgebraicAggregateFunctionPositions(i),
+ nonAlgebraicAggregateFunctions(i).eval(buffer))
+ i += 1
+ }
+
+ resultProjection(joinedRow(currentGroupingKey, aggregateResult))
+ }
+}
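
Every iterator in this file builds on the same sorted-input grouping loop: remember the next group's key and first row, then consume rows until the key changes. The loop in isolation, summing values per key over plain pairs:

```scala
// Minimal version of SortAggregationIterator's grouping loop. The input is
// assumed sorted by key; each next() consumes exactly one group.
class GroupSums(input: Iterator[(String, Int)]) extends Iterator[(String, Int)] {
  private var nextKey: String = _
  private var firstValue: Int = _
  private var hasNewGroup = false

  private def initialize(): Unit =
    if (input.hasNext) {
      val (k, v) = input.next()
      nextKey = k
      firstValue = v
      hasNewGroup = true
    }

  override def hasNext: Boolean = hasNewGroup

  override def next(): (String, Int) = {
    if (!hasNext) throw new NoSuchElementException
    val key = nextKey
    var sum = firstValue
    var foundNextGroup = false
    while (input.hasNext && !foundNextGroup) {
      val (k, v) = input.next()
      if (k == key) sum += v
      else {
        // Found the first row of the next group: remember it for later.
        nextKey = k
        firstValue = v
        foundNextGroup = true
      }
    }
    if (!foundNextGroup) hasNewGroup = false // the input is exhausted
    (key, sum)
  }

  initialize()
}

// new GroupSums(Iterator("a" -> 1, "a" -> 2, "b" -> 5)).toList
// == List(("a", 3), ("b", 5))
```
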
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
new file mode 100644
index 0000000000..1cb27710e0
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -0,0 +1,364 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.types.{StructType, MapType, ArrayType}
+
+/**
+ * Utility functions used by the query planner to convert a plan to the new aggregation code path.
+ */
+object Utils {
+ // Right now, we do not support complex types in the grouping key schema.
+ private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = {
+ val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists {
+ case array: ArrayType => true
+ case map: MapType => true
+ case struct: StructType => true
+ case _ => false
+ }
+
+ !hasComplexTypes
+ }
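
The complex-type restriction above can be exercised directly against Catalyst's type objects:

```scala
// Same check as supportsGroupingKeySchema, over a bare list of data types.
import org.apache.spark.sql.types._

def supportsGroupingKey(types: Seq[DataType]): Boolean =
  !types.exists {
    case _: ArrayType | _: MapType | _: StructType => true
    case _ => false
  }

supportsGroupingKey(Seq(IntegerType, StringType))  // true
supportsGroupingKey(Seq(ArrayType(IntegerType)))   // false
```
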
+
+ private def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
+ case p: Aggregate if supportsGroupingKeySchema(p) =>
+ val converted = p.transformExpressionsDown {
+ case expressions.Average(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Average(child),
+ mode = aggregate.Complete,
+ isDistinct = false)
+
+ case expressions.Count(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Count(child),
+ mode = aggregate.Complete,
+ isDistinct = false)
+
+ // We do not support multiple COUNT DISTINCT columns for now.
+ case expressions.CountDistinct(children) if children.length == 1 =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Count(children.head),
+ mode = aggregate.Complete,
+ isDistinct = true)
+
+ case expressions.First(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.First(child),
+ mode = aggregate.Complete,
+ isDistinct = false)
+
+ case expressions.Last(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Last(child),
+ mode = aggregate.Complete,
+ isDistinct = false)
+
+ case expressions.Max(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Max(child),
+ mode = aggregate.Complete,
+ isDistinct = false)
+
+ case expressions.Min(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Min(child),
+ mode = aggregate.Complete,
+ isDistinct = false)
+
+ case expressions.Sum(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Sum(child),
+ mode = aggregate.Complete,
+ isDistinct = false)
+
+ case expressions.SumDistinct(child) =>
+ aggregate.AggregateExpression2(
+ aggregateFunction = aggregate.Sum(child),
+ mode = aggregate.Complete,
+ isDistinct = true)
+ }
+ // Check if there is any expressions.AggregateExpression1 left.
+ // If so, we cannot convert this plan.
+ val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr =>
+ // For every expression, check if it contains AggregateExpression1.
+ expr.find {
+ case agg: expressions.AggregateExpression1 => true
+ case other => false
+ }.isDefined
+ }
+
+ // Check if there are multiple distinct columns.
+ val aggregateExpressions = converted.aggregateExpressions.flatMap { expr =>
+ expr.collect {
+ case agg: AggregateExpression2 => agg
+ }
+ }.toSet.toSeq
+ val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct)
+ val hasMultipleDistinctColumnSets =
+ functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1
+
+ if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None
+
+ case other => None
+ }
+
+ private def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = {
+ // If the plan cannot be converted, we do a final check to see if the original
+ // logical.Aggregate contains both AggregateExpression1 and AggregateExpression2. If so,
+ // we need to throw an exception.
+ val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr =>
+ expr.collect {
+ case agg: AggregateExpression2 => agg.aggregateFunction
+ }
+ }.distinct
+ if (aggregateFunction2s.nonEmpty) {
+ // For functions implemented based on the new interface, prepare a list of function names.
+ val invalidFunctions = {
+ if (aggregateFunction2s.length > 1) {
+ s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " +
+ s"and ${aggregateFunction2s.head.nodeName} are"
+ } else {
+ s"${aggregateFunction2s.head.nodeName} is"
+ }
+ }
+ val errorMessage =
+ s"${invalidFunctions} implemented based on the new Aggregate Function " +
+ s"interface and it cannot be used with functions implemented based on " +
+ s"the old Aggregate Function interface."
+ throw new AnalysisException(errorMessage)
+ }
+ }
+
+ def tryConvert(
+ plan: LogicalPlan,
+ useNewAggregation: Boolean,
+ codeGenEnabled: Boolean): Option[Aggregate] = plan match {
+ case p: Aggregate if useNewAggregation && codeGenEnabled =>
+ val converted = tryConvert(p)
+ if (converted.isDefined) {
+ converted
+ } else {
+ checkInvalidAggregateFunction2(p)
+ None
+ }
+ case p: Aggregate =>
+ checkInvalidAggregateFunction2(p)
+ None
+ case other => None
+ }
+
+ def planAggregateWithoutDistinct(
+ groupingExpressions: Seq[Expression],
+ aggregateExpressions: Seq[AggregateExpression2],
+ aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute],
+ resultExpressions: Seq[NamedExpression],
+ child: SparkPlan): Seq[SparkPlan] = {
+ // 1. Create an Aggregate Operator for partial aggregations.
+ val namedGroupingExpressions = groupingExpressions.map {
+ case ne: NamedExpression => ne -> ne
+ // If the expression is not a NamedExpression, we add an alias.
+ // So, when we generate the result of the operator, the Aggregate Operator
+ // can directly get the Seq of attributes representing the grouping expressions.
+ case other =>
+ val withAlias = Alias(other, other.toString)()
+ other -> withAlias
+ }
+ val groupExpressionMap = namedGroupingExpressions.toMap
+ val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
+ val partialAggregateExpressions = aggregateExpressions.map {
+ case AggregateExpression2(aggregateFunction, mode, isDistinct) =>
+ AggregateExpression2(aggregateFunction, Partial, isDistinct)
+ }
+ val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
+ agg.aggregateFunction.bufferAttributes
+ }
+ val partialAggregate =
+ Aggregate2Sort(
+ None: Option[Seq[Expression]],
+ namedGroupingExpressions.map(_._2),
+ partialAggregateExpressions,
+ partialAggregateAttributes,
+ namedGroupingAttributes ++ partialAggregateAttributes,
+ child)
+
+ // 2. Create an Aggregate Operator for final aggregations.
+ val finalAggregateExpressions = aggregateExpressions.map {
+ case AggregateExpression2(aggregateFunction, mode, isDistinct) =>
+ AggregateExpression2(aggregateFunction, Final, isDistinct)
+ }
+ val finalAggregateAttributes =
+ finalAggregateExpressions.map {
+ expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
+ }
+ val rewrittenResultExpressions = resultExpressions.map { expr =>
+ expr.transformDown {
+ case agg: AggregateExpression2 =>
+ aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute
+ case expression =>
+ // We do not rely on the equality check here since attributes may
+ // differ cosmetically. Instead, we use semanticEquals.
+ groupExpressionMap.collectFirst {
+ case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+ }.getOrElse(expression)
+ }.asInstanceOf[NamedExpression]
+ }
+ val finalAggregate = Aggregate2Sort(
+ Some(namedGroupingAttributes),
+ namedGroupingAttributes,
+ finalAggregateExpressions,
+ finalAggregateAttributes,
+ rewrittenResultExpressions,
+ partialAggregate)
+
+ finalAggregate :: Nil
+ }
+
+ def planAggregateWithOneDistinct(
+ groupingExpressions: Seq[Expression],
+ functionsWithDistinct: Seq[AggregateExpression2],
+ functionsWithoutDistinct: Seq[AggregateExpression2],
+ aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute],
+ resultExpressions: Seq[NamedExpression],
+ child: SparkPlan): Seq[SparkPlan] = {
+
+ // 1. Create an Aggregate Operator for partial aggregations.
+ // The grouping expressions are the original groupingExpressions plus the
+ // distinct columns. For example, for avg(distinct value) ... group by key,
+ // the grouping expressions of this Aggregate Operator will be [key, value].
+ val namedGroupingExpressions = groupingExpressions.map {
+ case ne: NamedExpression => ne -> ne
+ // If the expression is not a NamedExpression, we add an alias.
+ // So, when we generate the result of the operator, the Aggregate Operator
+ // can directly get the Seq of attributes representing the grouping expressions.
+ case other =>
+ val withAlias = Alias(other, other.toString)()
+ other -> withAlias
+ }
+ val groupExpressionMap = namedGroupingExpressions.toMap
+ val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
+
+ // It is safe to call head here since functionsWithDistinct has at least one
+ // AggregateExpression2.
+ val distinctColumnExpressions =
+ functionsWithDistinct.head.aggregateFunction.children
+ val namedDistinctColumnExpressions = distinctColumnExpressions.map {
+ case ne: NamedExpression => ne -> ne
+ case other =>
+ val withAlias = Alias(other, other.toString)()
+ other -> withAlias
+ }
+ val distinctColumnExpressionMap = namedDistinctColumnExpressions.toMap
+ val distinctColumnAttributes = namedDistinctColumnExpressions.map(_._2.toAttribute)
+
+ val partialAggregateExpressions = functionsWithoutDistinct.map {
+ case AggregateExpression2(aggregateFunction, mode, _) =>
+ AggregateExpression2(aggregateFunction, Partial, false)
+ }
+ val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
+ agg.aggregateFunction.bufferAttributes
+ }
+ val partialAggregate =
+ Aggregate2Sort(
+ None: Option[Seq[Expression]],
+ (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2),
+ partialAggregateExpressions,
+ partialAggregateAttributes,
+ namedGroupingAttributes ++ distinctColumnAttributes ++ partialAggregateAttributes,
+ child)
+
+ // 2. Create an Aggregate Operator for partial merge aggregations.
+ val partialMergeAggregateExpressions = functionsWithoutDistinct.map {
+ case AggregateExpression2(aggregateFunction, mode, _) =>
+ AggregateExpression2(aggregateFunction, PartialMerge, false)
+ }
+ val partialMergeAggregateAttributes =
+ partialMergeAggregateExpressions.map {
+ expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
+ }
+ val partialMergeAggregate =
+ Aggregate2Sort(
+ Some(namedGroupingAttributes),
+ namedGroupingAttributes ++ distinctColumnAttributes,
+ partialMergeAggregateExpressions,
+ partialMergeAggregateAttributes,
+ namedGroupingAttributes ++ distinctColumnAttributes ++ partialMergeAggregateAttributes,
+ partialAggregate)
+
+ // 3. Create an Aggregate Operator for the final and complete aggregations.
+ val finalAggregateExpressions = functionsWithoutDistinct.map {
+ case AggregateExpression2(aggregateFunction, mode, _) =>
+ AggregateExpression2(aggregateFunction, Final, false)
+ }
+ val finalAggregateAttributes =
+ finalAggregateExpressions.map {
+ expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
+ }
+ val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map {
+ // The children of an AggregateFunction with the DISTINCT keyword have already
+ // been evaluated. Here, we need to replace the original children
+ // with AttributeReferences.
+ case agg @ AggregateExpression2(aggregateFunction, mode, isDistinct) =>
+ val rewrittenAggregateFunction = aggregateFunction.transformDown {
+ case expr if distinctColumnExpressionMap.contains(expr) =>
+ distinctColumnExpressionMap(expr).toAttribute
+ }.asInstanceOf[AggregateFunction2]
+ // We rewrite the aggregate function as a non-distinct aggregation because the
+ // previous two operators have already de-duplicated its input values.
+ val rewrittenAggregateExpression =
+ AggregateExpression2(rewrittenAggregateFunction, Complete, false)
+
+ val aggregateFunctionAttribute = aggregateFunctionMap(agg.aggregateFunction, isDistinct)
+ (rewrittenAggregateExpression -> aggregateFunctionAttribute)
+ }.unzip
+
+ val rewrittenResultExpressions = resultExpressions.map { expr =>
+ expr.transform {
+ case agg: AggregateExpression2 =>
+ aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute
+ case expression =>
+ // We do not rely on the equality check here since attributes may
+ // differ cosmetically. Instead, we use semanticEquals.
+ groupExpressionMap.collectFirst {
+ case (grpExpr, ne) if grpExpr semanticEquals expression => ne.toAttribute
+ }.getOrElse(expression)
+ }.asInstanceOf[NamedExpression]
+ }
+ val finalAndCompleteAggregate = FinalAndCompleteAggregate2Sort(
+ namedGroupingAttributes ++ distinctColumnAttributes,
+ namedGroupingAttributes,
+ finalAggregateExpressions,
+ finalAggregateAttributes,
+ completeAggregateExpressions,
+ completeAggregateAttributes,
+ rewrittenResultExpressions,
+ partialMergeAggregate)
+
+ finalAndCompleteAggregate :: Nil
+ }
+}
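For intuition, the resulting three-operator plan for a single-distinct-column query can be
inspected directly. A sketch (against the agg1 table created in the AggregationQuerySuite
added later in this patch; printed operator details simplified):

  val df = sqlContext.sql("SELECT key, avg(DISTINCT value) FROM agg1 GROUP BY key")
  df.queryExecution.executedPlan.foreach(println)
  // Expected, bottom-up: Aggregate2Sort (Partial, grouping on [key, value]),
  // Aggregate2Sort (PartialMerge), then FinalAndCompleteAggregate2Sort.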
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala
new file mode 100644
index 0000000000..6c49a906c8
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala
@@ -0,0 +1,280 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions.aggregate
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
+import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction2
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.Row
+
+/**
+ * The abstract class for implementing user-defined aggregate function.
+ */
+abstract class UserDefinedAggregateFunction extends Serializable {
+
+ /**
+ * A [[StructType]] representing the data types of the input arguments of this
+ * aggregate function.
+ * For example, if a [[UserDefinedAggregateFunction]] expects two input arguments
+ * with type of [[DoubleType]] and [[LongType]], the returned [[StructType]] will look like
+ *
+ * ```
+ * StructType(Seq(StructField("doubleInput", DoubleType), StructField("longInput", LongType)))
+ * ```
+ *
+ * The name of a field of this [[StructType]] is only used to identify the corresponding
+ * input argument. Users can choose names to identify the input arguments.
+ */
+ def inputSchema: StructType
+
+ /**
+ * A [[StructType]] representing the data types of the values in the aggregation buffer.
+ * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values
+ * (i.e. two intermediate values) with type of [[DoubleType]] and [[LongType]],
+ * the returned [[StructType]] will look like
+ *
+ * ```
+ * StructType(Seq(StructField("doubleInput", DoubleType), StructField("longInput", LongType)))
+ * ```
+ *
+ * The name of a field of this [[StructType]] is only used to identify the corresponding
+ * buffer value. Users can choose names to identify the buffer values.
+ */
+ def bufferSchema: StructType
+
+ /**
+ * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]].
+ */
+ def returnDataType: DataType
+
+ /** Indicates if this function is deterministic. */
+ def deterministic: Boolean
+
+ /**
+ * Initializes the given aggregation buffer. The initial values must form the
+ * identity of `merge`: merging two buffers that hold only initial values must
+ * yield a buffer that still holds initial values.
+ */
+ def initialize(buffer: MutableAggregationBuffer): Unit
+
+ /** Updates the given aggregation buffer `buffer` with new input data from `input`. */
+ def update(buffer: MutableAggregationBuffer, input: Row): Unit
+
+ /** Merges two aggregation buffers and stores the updated buffer values back in `buffer1`. */
+ def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit
+
+ /**
+ * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given
+ * aggregation buffer.
+ */
+ def evaluate(buffer: Row): Any
+}
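+
+// An illustrative sketch (not part of this patch) of the interface above: a
+// null-ignoring sum of doubles that returns 0.0 when it sees no non-null input.
+// See the MyDoubleSum and MyDoubleAvg test UDAFs later in this patch for fuller
+// examples.
+//
+//   class MySum extends UserDefinedAggregateFunction {
+//     def inputSchema: StructType = StructType(StructField("inputDouble", DoubleType) :: Nil)
+//     def bufferSchema: StructType = StructType(StructField("bufferSum", DoubleType) :: Nil)
+//     def returnDataType: DataType = DoubleType
+//     def deterministic: Boolean = true
+//     def initialize(buffer: MutableAggregationBuffer): Unit = buffer.update(0, 0.0)
+//     def update(buffer: MutableAggregationBuffer, input: Row): Unit =
+//       if (!input.isNullAt(0)) buffer.update(0, buffer.getDouble(0) + input.getDouble(0))
+//     def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
+//       buffer1.update(0, buffer1.getDouble(0) + buffer2.getDouble(0))
+//     def evaluate(buffer: Row): Any = buffer.getDouble(0)
+//   }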
+
+private[sql] abstract class AggregationBuffer(
+ toCatalystConverters: Array[Any => Any],
+ toScalaConverters: Array[Any => Any],
+ bufferOffset: Int)
+ extends Row {
+
+ override def length: Int = toCatalystConverters.length
+
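+ // Each logical buffer position i maps to slot bufferOffset + i of the underlying
+ // row, so the buffers of several aggregate functions can share a single row.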
+ protected val offsets: Array[Int] = {
+ val newOffsets = new Array[Int](length)
+ var i = 0
+ while (i < newOffsets.length) {
+ newOffsets(i) = bufferOffset + i
+ i += 1
+ }
+ newOffsets
+ }
+}
+
+/**
+ * A mutable [[Row]] representing an aggregation buffer.
+ */
+class MutableAggregationBuffer private[sql] (
+ toCatalystConverters: Array[Any => Any],
+ toScalaConverters: Array[Any => Any],
+ bufferOffset: Int,
+ var underlyingBuffer: MutableRow)
+ extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) {
+
+ override def get(i: Int): Any = {
+ if (i >= length || i < 0) {
+ throw new IllegalArgumentException(
+ s"Could not access ${i}th value in this buffer because it only has $length values.")
+ }
+ toScalaConverters(i)(underlyingBuffer(offsets(i)))
+ }
+
+ def update(i: Int, value: Any): Unit = {
+ if (i >= length || i < 0) {
+ throw new IllegalArgumentException(
+ s"Could not update ${i}th value in this buffer because it only has $length values.")
+ }
+ underlyingBuffer.update(offsets(i), toCatalystConverters(i)(value))
+ }
+
+ override def copy(): MutableAggregationBuffer = {
+ new MutableAggregationBuffer(
+ toCatalystConverters,
+ toScalaConverters,
+ bufferOffset,
+ underlyingBuffer)
+ }
+}
+
+/**
+ * A [[Row]] representing an immutable aggregation buffer.
+ */
+class InputAggregationBuffer private[sql] (
+ toCatalystConverters: Array[Any => Any],
+ toScalaConverters: Array[Any => Any],
+ bufferOffset: Int,
+ var underlyingInputBuffer: Row)
+ extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) {
+
+ override def get(i: Int): Any = {
+ if (i >= length || i < 0) {
+ throw new IllegalArgumentException(
+ s"Could not access ${i}th value in this buffer because it only has $length values.")
+ }
+ toScalaConverters(i)(underlyingInputBuffer(offsets(i)))
+ }
+
+ override def copy(): InputAggregationBuffer = {
+ new InputAggregationBuffer(
+ toCatalystConverters,
+ toScalaConverters,
+ bufferOffset,
+ underlyingInputBuffer)
+ }
+}
+
+/**
+ * The internal wrapper used to hook a [[UserDefinedAggregateFunction]] `udaf` into the
+ * internal aggregation code path.
+ * @param children the input expressions of this UDAF invocation
+ * @param udaf the user-defined aggregate function to wrap
+ */
+case class ScalaUDAF(
+ children: Seq[Expression],
+ udaf: UserDefinedAggregateFunction)
+ extends AggregateFunction2 with Logging {
+
+ require(
+ children.length == udaf.inputSchema.length,
+ s"$udaf only accepts ${udaf.inputSchema.length} arguments, " +
+ s"but ${children.length} are provided.")
+
+ override def nullable: Boolean = true
+
+ override def dataType: DataType = udaf.returnDataType
+
+ override def deterministic: Boolean = udaf.deterministic
+
+ override val inputTypes: Seq[DataType] = udaf.inputSchema.map(_.dataType)
+
+ override val bufferSchema: StructType = udaf.bufferSchema
+
+ override val bufferAttributes: Seq[AttributeReference] = bufferSchema.toAttributes
+
+ override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance())
+
+ val childrenSchema: StructType = {
+ val inputFields = children.zipWithIndex.map {
+ case (child, index) =>
+ StructField(s"input$index", child.dataType, child.nullable, Metadata.empty)
+ }
+ StructType(inputFields)
+ }
+
+ lazy val inputProjection = {
+ val inputAttributes = childrenSchema.toAttributes
+ log.debug(
+ s"Creating MutableProj: $children, inputSchema: $inputAttributes.")
+ try {
+ GenerateMutableProjection.generate(children, inputAttributes)()
+ } catch {
+ case e: Exception =>
+ log.error("Failed to generate mutable projection, fallback to interpreted", e)
+ new InterpretedMutableProjection(children, inputAttributes)
+ }
+ }
+
+ val inputToScalaConverters: Any => Any =
+ CatalystTypeConverters.createToScalaConverter(childrenSchema)
+
+ val bufferValuesToCatalystConverters: Array[Any => Any] = bufferSchema.fields.map { field =>
+ CatalystTypeConverters.createToCatalystConverter(field.dataType)
+ }
+
+ val bufferValuesToScalaConverters: Array[Any => Any] = bufferSchema.fields.map { field =>
+ CatalystTypeConverters.createToScalaConverter(field.dataType)
+ }
+
+ lazy val inputAggregateBuffer: InputAggregationBuffer =
+ new InputAggregationBuffer(
+ bufferValuesToCatalystConverters,
+ bufferValuesToScalaConverters,
+ bufferOffset,
+ null)
+
+ lazy val mutableAggregateBuffer: MutableAggregationBuffer =
+ new MutableAggregationBuffer(
+ bufferValuesToCatalystConverters,
+ bufferValuesToScalaConverters,
+ bufferOffset,
+ null)
+
+ override def initialize(buffer: MutableRow): Unit = {
+ mutableAggregateBuffer.underlyingBuffer = buffer
+
+ udaf.initialize(mutableAggregateBuffer)
+ }
+
+ override def update(buffer: MutableRow, input: InternalRow): Unit = {
+ mutableAggregateBuffer.underlyingBuffer = buffer
+
+ udaf.update(
+ mutableAggregateBuffer,
+ inputToScalaConverters(inputProjection(input)).asInstanceOf[Row])
+ }
+
+ override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
+ mutableAggregateBuffer.underlyingBuffer = buffer1
+ inputAggregateBuffer.underlyingInputBuffer = buffer2
+
+ udaf.merge(mutableAggregateBuffer, inputAggregateBuffer)
+ }
+
+ override def eval(buffer: InternalRow = null): Any = {
+ inputAggregateBuffer.underlyingInputBuffer = buffer
+
+ udaf.evaluate(inputAggregateBuffer)
+ }
+
+ override def toString: String = {
+ s"""${udaf.getClass.getSimpleName}(${children.mkString(",")})"""
+ }
+
+ override def nodeName: String = udaf.getClass.getSimpleName
+}
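+
+// Usage sketch: a UserDefinedAggregateFunction is wrapped in a ScalaUDAF once it
+// is registered and invoked from SQL (see the AggregationQuerySuite added below):
+//   sqlContext.udaf.register("mydoublesum", new MyDoubleSum)
+//   sqlContext.sql("SELECT key, mydoublesum(value) FROM agg1 GROUP BY key")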
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 28159cbd5a..bfeecbe8b2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2420,7 +2420,7 @@ object functions {
* @since 1.5.0
*/
def callUDF(udfName: String, cols: Column*): Column = {
- UnresolvedFunction(udfName, cols.map(_.expr))
+ UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false)
}
/**
@@ -2449,7 +2449,7 @@ object functions {
exprs(i) = cols(i).expr
i += 1
}
- UnresolvedFunction(udfName, exprs)
+ UnresolvedFunction(udfName, exprs, isDistinct = false)
}
}
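Both callUDF overloads keep their signatures; only the isDistinct flag is now passed
explicitly. Caller-side usage is unchanged, e.g. (a sketch, for some DataFrame df with a
value column and a registered "mydoublesum" UDAF):

  df.select(callUDF("mydoublesum", col("value")))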
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index beee10173f..ab8dce603c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -23,6 +23,7 @@ import java.sql.Timestamp
import org.apache.spark.sql.catalyst.DefaultParserDialect
import org.apache.spark.sql.catalyst.errors.DialectException
+import org.apache.spark.sql.execution.aggregate.Aggregate2Sort
import org.apache.spark.sql.execution.GeneratedAggregate
import org.apache.spark.sql.functions._
import org.apache.spark.sql.TestData._
@@ -204,6 +205,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
var hasGeneratedAgg = false
df.queryExecution.executedPlan.foreach {
case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true
+ case newAggregate: Aggregate2Sort => hasGeneratedAgg = true
case _ =>
}
if (!hasGeneratedAgg) {
@@ -285,7 +287,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
// Aggregate with Code generation handling all null values
testCodeGen(
"SELECT sum('a'), avg('a'), count(null) FROM testData",
- Row(0, null, 0) :: Nil)
+ Row(null, null, 0) :: Nil)
} finally {
sqlContext.dropTempTable("testData3x")
sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 3dd24130af..3d71deb13e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.TestSQLContext._
@@ -30,6 +31,20 @@ import org.apache.spark.sql.{Row, SQLConf, execution}
class PlannerSuite extends SparkFunSuite {
+ private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
+ val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption)
+ val planned =
+ plannedOption.getOrElse(
+ fail(s"Could query play aggregation query $query. Is it an aggregation query?"))
+ val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
+
+ // The old code path produces two aggregate operators; for distinct aggregations,
+ // the new code path produces three.
+ assert(
+ aggregations.size == 2 || aggregations.size == 3,
+ s"The plan of query $query does not have partial aggregations.")
+ }
+
test("unions are collapsed") {
val query = testData.unionAll(testData).unionAll(testData).logicalPlan
val planned = BasicOperators(query).head
@@ -42,23 +57,18 @@ class PlannerSuite extends SparkFunSuite {
test("count is partially aggregated") {
val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed
- val planned = HashAggregation(query).head
- val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
-
- assert(aggregations.size === 2)
+ testPartialAggregationPlan(query)
}
test("count distinct is partially aggregated") {
val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed
- val planned = HashAggregation(query)
- assert(planned.nonEmpty)
+ testPartialAggregationPlan(query)
}
test("mixed aggregates are partially aggregated") {
val query =
testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed
- val planned = HashAggregation(query)
- assert(planned.nonEmpty)
+ testPartialAggregationPlan(query)
}
test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
index 31a49a3683..24a758f531 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
@@ -833,6 +833,7 @@ abstract class HiveWindowFunctionQueryFileBaseSuite
"windowing_adjust_rowcontainer_sz"
)
+ // Only run the query tests in realWhiteList (do not try the other ignored query files).
override def testCases: Seq[(String, File)] = super.testCases.filter {
case (name, _) => realWhiteList.contains(name)
}
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
index f458567e5d..1fe4fe9629 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.hive.execution
+import java.io.File
+
import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.hive.test.TestHive
@@ -159,4 +161,9 @@ class SortMergeCompatibilitySuite extends HiveCompatibilitySuite {
"join_reorder4",
"join_star"
)
+
+ // Only run the query tests in realWhiteList (do not try the other ignored query files).
+ override def testCases: Seq[(String, File)] = super.testCases.filter {
+ case (name, _) => realWhiteList.contains(name)
+ }
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index cec7685bb6..4cdb83c511 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -451,6 +451,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
DataSinks,
Scripts,
HashAggregation,
+ Aggregation,
LeftSemiJoin,
HashJoin,
BasicOperators,
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index f5574509b0..8518e333e8 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -1464,9 +1464,12 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
/* UDFs - Must be last otherwise will preempt built in functions */
case Token("TOK_FUNCTION", Token(name, Nil) :: args) =>
- UnresolvedFunction(name, args.map(nodeToExpr))
+ UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = false)
+ // Aggregate function with DISTINCT keyword.
+ case Token("TOK_FUNCTIONDI", Token(name, Nil) :: args) =>
+ UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = true)
case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) =>
- UnresolvedFunction(name, UnresolvedStar(None) :: Nil)
+ UnresolvedFunction(name, UnresolvedStar(None) :: Nil, isDistinct = false)
/* Literals */
case Token("TOK_NULL", Nil) => Literal.create(null, NullType)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 4d23c7035c..3259b50acc 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -409,7 +409,7 @@ private[hive] case class HiveWindowFunction(
private[hive] case class HiveGenericUDAF(
funcWrapper: HiveFunctionWrapper,
- children: Seq[Expression]) extends AggregateExpression
+ children: Seq[Expression]) extends AggregateExpression1
with HiveInspectors {
type UDFType = AbstractGenericUDAFResolver
@@ -441,7 +441,7 @@ private[hive] case class HiveGenericUDAF(
/** It is used as a wrapper for the hive functions which uses UDAF interface */
private[hive] case class HiveUDAF(
funcWrapper: HiveFunctionWrapper,
- children: Seq[Expression]) extends AggregateExpression
+ children: Seq[Expression]) extends AggregateExpression1
with HiveInspectors {
type UDFType = UDAF
@@ -550,9 +550,9 @@ private[hive] case class HiveGenericUDTF(
private[hive] case class HiveUDAFFunction(
funcWrapper: HiveFunctionWrapper,
exprs: Seq[Expression],
- base: AggregateExpression,
+ base: AggregateExpression1,
isUDAFBridgeRequired: Boolean = false)
- extends AggregateFunction
+ extends AggregateFunction1
with HiveInspectors {
def this() = this(null, null, null)
diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java
new file mode 100644
index 0000000000..5c9d0e97a9
--- /dev/null
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.hive.aggregate;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.expressions.aggregate.MutableAggregationBuffer;
+import org.apache.spark.sql.expressions.aggregate.UserDefinedAggregateFunction;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+public class MyDoubleAvg extends UserDefinedAggregateFunction {
+
+ private StructType _inputDataType;
+
+ private StructType _bufferSchema;
+
+ private DataType _returnDataType;
+
+ public MyDoubleAvg() {
+ List<StructField> inputfields = new ArrayList<StructField>();
+ inputfields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
+ _inputDataType = DataTypes.createStructType(inputfields);
+
+ List<StructField> bufferFields = new ArrayList<StructField>();
+ bufferFields.add(DataTypes.createStructField("bufferSum", DataTypes.DoubleType, true));
+ bufferFields.add(DataTypes.createStructField("bufferCount", DataTypes.LongType, true));
+ _bufferSchema = DataTypes.createStructType(bufferFields);
+
+ _returnDataType = DataTypes.DoubleType;
+ }
+
+ @Override public StructType inputSchema() {
+ return _inputDataType;
+ }
+
+ @Override public StructType bufferSchema() {
+ return _bufferSchema;
+ }
+
+ @Override public DataType returnDataType() {
+ return _returnDataType;
+ }
+
+ @Override public boolean deterministic() {
+ return true;
+ }
+
+ @Override public void initialize(MutableAggregationBuffer buffer) {
+ buffer.update(0, null);
+ buffer.update(1, 0L);
+ }
+
+ @Override public void update(MutableAggregationBuffer buffer, Row input) {
+ if (!input.isNullAt(0)) {
+ if (buffer.isNullAt(0)) {
+ buffer.update(0, input.getDouble(0));
+ buffer.update(1, 1L);
+ } else {
+ Double newValue = input.getDouble(0) + buffer.getDouble(0);
+ buffer.update(0, newValue);
+ buffer.update(1, buffer.getLong(1) + 1L);
+ }
+ }
+ }
+
+ @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
+ if (!buffer2.isNullAt(0)) {
+ if (buffer1.isNullAt(0)) {
+ buffer1.update(0, buffer2.getDouble(0));
+ buffer1.update(1, buffer2.getLong(1));
+ } else {
+ Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
+ buffer1.update(0, newValue);
+ buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1));
+ }
+ }
+ }
+
+ @Override public Object evaluate(Row buffer) {
+ if (buffer.isNullAt(0)) {
+ return null;
+ } else {
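+ // Intentionally returns the average plus 100.0 so that test results can be
+ // told apart from the built-in avg.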
+ return buffer.getDouble(0) / buffer.getLong(1) + 100.0;
+ }
+ }
+}
+
diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java
new file mode 100644
index 0000000000..1d4587a27c
--- /dev/null
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.hive.aggregate;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.spark.sql.expressions.aggregate.MutableAggregationBuffer;
+import org.apache.spark.sql.expressions.aggregate.UserDefinedAggregateFunction;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.Row;
+
+public class MyDoubleSum extends UserDefinedAggregateFunction {
+
+ private StructType _inputDataType;
+
+ private StructType _bufferSchema;
+
+ private DataType _returnDataType;
+
+ public MyDoubleSum() {
+ List<StructField> inputfields = new ArrayList<StructField>();
+ inputfields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
+ _inputDataType = DataTypes.createStructType(inputfields);
+
+ List<StructField> bufferFields = new ArrayList<StructField>();
+ bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true));
+ _bufferSchema = DataTypes.createStructType(bufferFields);
+
+ _returnDataType = DataTypes.DoubleType;
+ }
+
+ @Override public StructType inputSchema() {
+ return _inputDataType;
+ }
+
+ @Override public StructType bufferSchema() {
+ return _bufferSchema;
+ }
+
+ @Override public DataType returnDataType() {
+ return _returnDataType;
+ }
+
+ @Override public boolean deterministic() {
+ return true;
+ }
+
+ @Override public void initialize(MutableAggregationBuffer buffer) {
+ buffer.update(0, null);
+ }
+
+ @Override public void update(MutableAggregationBuffer buffer, Row input) {
+ if (!input.isNullAt(0)) {
+ if (buffer.isNullAt(0)) {
+ buffer.update(0, input.getDouble(0));
+ } else {
+ Double newValue = input.getDouble(0) + buffer.getDouble(0);
+ buffer.update(0, newValue);
+ }
+ }
+ }
+
+ @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
+ if (!buffer2.isNullAt(0)) {
+ if (buffer1.isNullAt(0)) {
+ buffer1.update(0, buffer2.getDouble(0));
+ } else {
+ Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
+ buffer1.update(0, newValue);
+ }
+ }
+ }
+
+ @Override public Object evaluate(Row buffer) {
+ if (buffer.isNullAt(0)) {
+ return null;
+ } else {
+ return buffer.getDouble(0);
+ }
+ }
+}
diff --git a/sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada b/sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada
new file mode 100644
index 0000000000..573541ac97
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47 b/sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47
new file mode 100644
index 0000000000..44b2a42cc2
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47
@@ -0,0 +1 @@
+unhex(str) - Converts hexadecimal argument to binary
diff --git a/sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4 b/sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4
new file mode 100644
index 0000000000..97af3b812a
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4
@@ -0,0 +1,14 @@
+unhex(str) - Converts hexadecimal argument to binary
+Performs the inverse operation of HEX(str). That is, it interprets
+each pair of hexadecimal digits in the argument as a number and
+converts it to the byte representation of the number. The
+resulting characters are returned as a binary string.
+
+Example:
+> SELECT DECODE(UNHEX('4D7953514C'), 'UTF-8') from src limit 1;
+'MySQL'
+
+The characters in the argument string must be legal hexadecimal
+digits: '0' .. '9', 'A' .. 'F', 'a' .. 'f'. If UNHEX() encounters
+any nonhexadecimal digits in the argument, it returns NULL. Also,
+if there are an odd number of characters a leading 0 is appended.
diff --git a/sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e b/sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e
new file mode 100644
index 0000000000..b4a6f2b692
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e
@@ -0,0 +1 @@
+MySQL 1267 a -4
diff --git a/sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3 b/sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3
new file mode 100644
index 0000000000..3a67adaf0a
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3
@@ -0,0 +1 @@
+NULL NULL NULL
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
new file mode 100644
index 0000000000..0375eb79ad
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -0,0 +1,507 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import org.apache.spark.sql.execution.aggregate.Aggregate2Sort
+import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.scalatest.BeforeAndAfterAll
+import test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum}
+
+class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll {
+
+ override val sqlContext = TestHive
+ import sqlContext.implicits._
+
+ var originalUseAggregate2: Boolean = _
+
+ override def beforeAll(): Unit = {
+ originalUseAggregate2 = sqlContext.conf.useSqlAggregate2
+ sqlContext.sql("set spark.sql.useAggregate2=true")
+ val data1 = Seq[(Integer, Integer)](
+ (1, 10),
+ (null, -60),
+ (1, 20),
+ (1, 30),
+ (2, 0),
+ (null, -10),
+ (2, -1),
+ (2, null),
+ (2, null),
+ (null, 100),
+ (3, null),
+ (null, null),
+ (3, null)).toDF("key", "value")
+ data1.write.saveAsTable("agg1")
+
+ val data2 = Seq[(Integer, Integer, Integer)](
+ (1, 10, -10),
+ (null, -60, 60),
+ (1, 30, -30),
+ (1, 30, 30),
+ (2, 1, 1),
+ (null, -10, 10),
+ (2, -1, null),
+ (2, 1, 1),
+ (2, null, 1),
+ (null, 100, -10),
+ (3, null, 3),
+ (null, null, null),
+ (3, null, null)).toDF("key", "value1", "value2")
+ data2.write.saveAsTable("agg2")
+
+ val emptyDF = sqlContext.createDataFrame(
+ sqlContext.sparkContext.emptyRDD[Row],
+ StructType(StructField("key", StringType) :: StructField("value", IntegerType) :: Nil))
+ emptyDF.registerTempTable("emptyTable")
+
+ // Register UDAFs
+ sqlContext.udaf.register("mydoublesum", new MyDoubleSum)
+ sqlContext.udaf.register("mydoubleavg", new MyDoubleAvg)
+ }
+
+ override def afterAll(): Unit = {
+ sqlContext.sql("DROP TABLE IF EXISTS agg1")
+ sqlContext.sql("DROP TABLE IF EXISTS agg2")
+ sqlContext.dropTempTable("emptyTable")
+ sqlContext.sql(s"set spark.sql.useAggregate2=$originalUseAggregate2")
+ }
+
+ test("empty table") {
+ // If there is no GROUP BY clause and the table is empty, we will generate a single row.
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | AVG(value),
+ | COUNT(*),
+ | COUNT(key),
+ | COUNT(value),
+ | FIRST(key),
+ | LAST(value),
+ | MAX(key),
+ | MIN(value),
+ | SUM(key)
+ |FROM emptyTable
+ """.stripMargin),
+ Row(null, 0, 0, 0, null, null, null, null, null) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | AVG(value),
+ | COUNT(*),
+ | COUNT(key),
+ | COUNT(value),
+ | FIRST(key),
+ | LAST(value),
+ | MAX(key),
+ | MIN(value),
+ | SUM(key),
+ | COUNT(DISTINCT value)
+ |FROM emptyTable
+ """.stripMargin),
+ Row(null, 0, 0, 0, null, null, null, null, null, 0) :: Nil)
+
+ // If there is a GROUP BY clause and the table is empty, there is no output.
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | AVG(value),
+ | COUNT(*),
+ | COUNT(value),
+ | FIRST(value),
+ | LAST(value),
+ | MAX(value),
+ | MIN(value),
+ | SUM(value),
+ | COUNT(DISTINCT value)
+ |FROM emptyTable
+ |GROUP BY key
+ """.stripMargin),
+ Nil)
+ }
+
+ test("only do grouping") {
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT key
+ |FROM agg1
+ |GROUP BY key
+ """.stripMargin),
+ Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT DISTINCT value1, key
+ |FROM agg2
+ """.stripMargin),
+ Row(10, 1) ::
+ Row(-60, null) ::
+ Row(30, 1) ::
+ Row(1, 2) ::
+ Row(-10, null) ::
+ Row(-1, 2) ::
+ Row(null, 2) ::
+ Row(100, null) ::
+ Row(null, 3) ::
+ Row(null, null) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT value1, key
+ |FROM agg2
+ |GROUP BY key, value1
+ """.stripMargin),
+ Row(10, 1) ::
+ Row(-60, null) ::
+ Row(30, 1) ::
+ Row(1, 2) ::
+ Row(-10, null) ::
+ Row(-1, 2) ::
+ Row(null, 2) ::
+ Row(100, null) ::
+ Row(null, 3) ::
+ Row(null, null) :: Nil)
+ }
+
+ test("case in-sensitive resolution") {
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT avg(value), kEY - 100
+ |FROM agg1
+ |GROUP BY Key - 100
+ """.stripMargin),
+ Row(20.0, -99) :: Row(-0.5, -98) :: Row(null, -97) :: Row(10.0, null) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT sum(distinct value1), kEY - 100, count(distinct value1)
+ |FROM agg2
+ |GROUP BY Key - 100
+ """.stripMargin),
+ Row(40, -99, 2) :: Row(0, -98, 2) :: Row(null, -97, 0) :: Row(30, null, 3) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT valUe * key - 100
+ |FROM agg1
+ |GROUP BY vAlue * keY - 100
+ """.stripMargin),
+ Row(-90) ::
+ Row(-80) ::
+ Row(-70) ::
+ Row(-100) ::
+ Row(-102) ::
+ Row(null) :: Nil)
+ }
+
+ test("test average no key in output") {
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT avg(value)
+ |FROM agg1
+ |GROUP BY key
+ """.stripMargin),
+ Row(-0.5) :: Row(20.0) :: Row(null) :: Row(10.0) :: Nil)
+ }
+
+ test("test average") {
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT key, avg(value)
+ |FROM agg1
+ |GROUP BY key
+ """.stripMargin),
+ Row(1, 20.0) :: Row(2, -0.5) :: Row(3, null) :: Row(null, 10.0) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT avg(value), key
+ |FROM agg1
+ |GROUP BY key
+ """.stripMargin),
+ Row(20.0, 1) :: Row(-0.5, 2) :: Row(null, 3) :: Row(10.0, null) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT avg(value) + 1.5, key + 10
+ |FROM agg1
+ |GROUP BY key + 10
+ """.stripMargin),
+ Row(21.5, 11) :: Row(1.0, 12) :: Row(null, 13) :: Row(11.5, null) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT avg(value) FROM agg1
+ """.stripMargin),
+ Row(11.125) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT avg(null)
+ """.stripMargin),
+ Row(null) :: Nil)
+ }
+
+ test("udaf") {
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | key,
+ | mydoublesum(value + 1.5 * key),
+ | mydoubleavg(value),
+ | avg(value - key),
+ | mydoublesum(value - 1.5 * key),
+ | avg(value)
+ |FROM agg1
+ |GROUP BY key
+ """.stripMargin),
+ Row(1, 64.5, 120.0, 19.0, 55.5, 20.0) ::
+ Row(2, 5.0, 99.5, -2.5, -7.0, -0.5) ::
+ Row(3, null, null, null, null, null) ::
+ Row(null, null, 110.0, null, null, 10.0) :: Nil)
+ }
+
+ test("non-AlgebraicAggregate aggreguate function") {
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT mydoublesum(value), key
+ |FROM agg1
+ |GROUP BY key
+ """.stripMargin),
+ Row(60.0, 1) :: Row(-1.0, 2) :: Row(null, 3) :: Row(30.0, null) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT mydoublesum(value) FROM agg1
+ """.stripMargin),
+ Row(89.0) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT mydoublesum(null)
+ """.stripMargin),
+ Row(null) :: Nil)
+ }
+
+ test("non-AlgebraicAggregate and AlgebraicAggregate aggreguate function") {
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT mydoublesum(value), key, avg(value)
+ |FROM agg1
+ |GROUP BY key
+ """.stripMargin),
+ Row(60.0, 1, 20.0) ::
+ Row(-1.0, 2, -0.5) ::
+ Row(null, 3, null) ::
+ Row(30.0, null, 10.0) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | mydoublesum(value + 1.5 * key),
+ | avg(value - key),
+ | key,
+ | mydoublesum(value - 1.5 * key),
+ | avg(value)
+ |FROM agg1
+ |GROUP BY key
+ """.stripMargin),
+ Row(64.5, 19.0, 1, 55.5, 20.0) ::
+ Row(5.0, -2.5, 2, -7.0, -0.5) ::
+ Row(null, null, 3, null, null) ::
+ Row(null, null, null, null, 10.0) :: Nil)
+ }
+
+ test("single distinct column set") {
+ // DISTINCT is not meaningful with Max and Min, so we just ignore the DISTINCT keyword.
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | min(distinct value1),
+ | sum(distinct value1),
+ | avg(value1),
+ | avg(value2),
+ | max(distinct value1)
+ |FROM agg2
+ """.stripMargin),
+ Row(-60, 70.0, 101.0/9.0, 5.6, 100.0))
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | mydoubleavg(distinct value1),
+ | avg(value1),
+ | avg(value2),
+ | key,
+ | mydoubleavg(value1 - 1),
+ | mydoubleavg(distinct value1) * 0.1,
+ | avg(value1 + value2)
+ |FROM agg2
+ |GROUP BY key
+ """.stripMargin),
+ Row(120.0, 70.0/3.0, -10.0/3.0, 1, 67.0/3.0 + 100.0, 12.0, 20.0) ::
+ Row(100.0, 1.0/3.0, 1.0, 2, -2.0/3.0 + 100.0, 10.0, 2.0) ::
+ Row(null, null, 3.0, 3, null, null, null) ::
+ Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | key,
+ | mydoubleavg(distinct value1),
+ | mydoublesum(value2),
+ | mydoublesum(distinct value1),
+ | mydoubleavg(distinct value1),
+ | mydoubleavg(value1)
+ |FROM agg2
+ |GROUP BY key
+ """.stripMargin),
+ Row(1, 120.0, -10.0, 40.0, 120.0, 70.0/3.0 + 100.0) ::
+ Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) ::
+ Row(3, null, 3.0, null, null, null) ::
+ Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil)
+ }
+
+ test("test count") {
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | count(value2),
+ | value1,
+ | count(*),
+ | count(1),
+ | key
+ |FROM agg2
+ |GROUP BY key, value1
+ """.stripMargin),
+ Row(1, 10, 1, 1, 1) ::
+ Row(1, -60, 1, 1, null) ::
+ Row(2, 30, 2, 2, 1) ::
+ Row(2, 1, 2, 2, 2) ::
+ Row(1, -10, 1, 1, null) ::
+ Row(0, -1, 1, 1, 2) ::
+ Row(1, null, 1, 1, 2) ::
+ Row(1, 100, 1, 1, null) ::
+ Row(1, null, 2, 2, 3) ::
+ Row(0, null, 1, 1, null) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | count(value2),
+ | value1,
+ | count(*),
+ | count(1),
+ | key,
+ | count(DISTINCT abs(value2))
+ |FROM agg2
+ |GROUP BY key, value1
+ """.stripMargin),
+ Row(1, 10, 1, 1, 1, 1) ::
+ Row(1, -60, 1, 1, null, 1) ::
+ Row(2, 30, 2, 2, 1, 1) ::
+ Row(2, 1, 2, 2, 2, 1) ::
+ Row(1, -10, 1, 1, null, 1) ::
+ Row(0, -1, 1, 1, 2, 0) ::
+ Row(1, null, 1, 1, 2, 1) ::
+ Row(1, 100, 1, 1, null, 1) ::
+ Row(1, null, 2, 2, 3, 1) ::
+ Row(0, null, 1, 1, null, 0) :: Nil)
+ }
+
+ test("error handling") {
+ sqlContext.sql(s"set spark.sql.useAggregate2=false")
+ var errorMessage = intercept[AnalysisException] {
+ sqlContext.sql(
+ """
+ |SELECT
+ | key,
+ | sum(value + 1.5 * key),
+ | mydoublesum(value),
+ | mydoubleavg(value)
+ |FROM agg1
+ |GROUP BY key
+ """.stripMargin).collect()
+ }.getMessage
+ assert(errorMessage.contains("implemented based on the new Aggregate Function interface"))
+
+ // TODO: once we support Hive UDAF in the new interface,
+ // we can remove the following two tests.
+ sqlContext.sql(s"set spark.sql.useAggregate2=true")
+ errorMessage = intercept[AnalysisException] {
+ sqlContext.sql(
+ """
+ |SELECT
+ | key,
+ | mydoublesum(value + 1.5 * key),
+ | stddev_samp(value)
+ |FROM agg1
+ |GROUP BY key
+ """.stripMargin).collect()
+ }.getMessage
+ assert(errorMessage.contains("implemented based on the new Aggregate Function interface"))
+
+ // This query will fall back to the old aggregation code path.
+ val newAggregateOperators = sqlContext.sql(
+ """
+ |SELECT
+ | key,
+ | sum(value + 1.5 * key),
+ | stddev_samp(value)
+ |FROM agg1
+ |GROUP BY key
+ """.stripMargin).queryExecution.executedPlan.collect {
+ case agg: Aggregate2Sort => agg
+ }
+ val message =
+ "We should fallback to the old aggregation code path if there is any aggregate function " +
+ "that cannot be converted to the new interface."
+ assert(newAggregateOperators.isEmpty, message)
+
+ sqlContext.sql(s"set spark.sql.useAggregate2=true")
+ }
+}