author     Yin Huai <yhuai@databricks.com>  2015-11-10 11:06:29 -0800
committer  Michael Armbrust <michael@databricks.com>  2015-11-10 11:06:29 -0800
commit     e0701c75601c43f69ed27fc7c252321703db51f2 (patch)
tree       52d85dfefce3da304fef585c895667f305cd8238
parent     6e5fc37883ed81c3ee2338145a48de3036d19399 (diff)
[SPARK-9830][SQL] Remove AggregateExpression1 and Aggregate Operator used to evaluate AggregateExpression1s
https://issues.apache.org/jira/browse/SPARK-9830

This PR contains the following main changes:

* Removing `AggregateExpression1`.
* Removing the `Aggregate` operator, which is used to evaluate `AggregateExpression1`s.
* Removing the planner rule used to plan `Aggregate`.
* Linking `MultipleDistinctRewriter` to the analyzer.
* Renaming `AggregateExpression2` to `AggregateExpression` and `AggregateFunction2` to `AggregateFunction`.
* Updating the places where we create aggregate expressions. The way to create an aggregate expression is now `AggregateExpression(aggregateFunction, mode, isDistinct)`.
* Changing the `val`s in `DeclarativeAggregate`s that touch the children of the function to `lazy val`s (when we create an aggregate expression through the DataFrame API, the children of an aggregate function can be unresolved).

Author: Yin Huai <yhuai@databricks.com>

Closes #9556 from yhuai/removeAgg1.
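For illustration, a minimal sketch of the construction pattern described above (it assumes the Catalyst classes touched by this patch are on the classpath, and mirrors how the SQL parser now builds `count(1)` and a distinct `sum`; the attribute name is made up):

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count, Sum}
import org.apache.spark.sql.types.LongType

// count(1): wrap the aggregate function in AggregateExpression(function, mode, isDistinct).
val countAll = AggregateExpression(Count(Literal(1)), mode = Complete, isDistinct = false)

// sum(DISTINCT col): the same constructor, with the distinct flag set.
val col = AttributeReference("col", LongType)()
val sumDistinct = AggregateExpression(Sum(col), mode = Complete, isDistinct = true)
```

Because children like `col` can still be unresolved when such expressions are built through the DataFrame API, the buffer and result `val`s in the `DeclarativeAggregate` implementations changed below are turned into `lazy val`s.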
-rw-r--r--  R/pkg/R/functions.R | 2
-rw-r--r--  python/pyspark/sql/dataframe.py | 2
-rw-r--r--  python/pyspark/sql/functions.py | 2
-rw-r--r--  python/pyspark/sql/tests.py | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala | 10
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala | 14
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 26
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 46
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala (renamed from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala) | 235
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 20
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala | 22
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala | 31
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala | 13
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala | 15
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala | 28
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala | 14
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala | 17
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala | 12
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala | 17
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala | 17
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala | 31
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala | 29
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala | 7
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala | 57
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala | 1073
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 23
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala | 74
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala | 12
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala | 4
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala | 23
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala | 2
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala | 1
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala | 6
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala | 4
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala | 14
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 13
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala | 45
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala | 20
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala | 205
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala | 1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala | 238
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala | 28
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala | 36
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala | 20
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala | 5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala | 82
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 53
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 69
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala | 15
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala | 30
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala | 1
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 8
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala | 188
64 files changed, 743 insertions(+), 2260 deletions(-)
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index d7fd279279..0b28087029 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -1339,7 +1339,7 @@ setMethod("pmod", signature(y = "Column"),
#' @export
setMethod("approxCountDistinct",
signature(x = "Column"),
- function(x, rsd = 0.95) {
+ function(x, rsd = 0.05) {
jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc, rsd)
column(jc)
})
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index b97c94dad8..0dd75ba7ca 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -866,7 +866,7 @@ class DataFrame(object):
This is a variant of :func:`select` that accepts SQL expressions.
>>> df.selectExpr("age * 2", "abs(age)").collect()
- [Row((age * 2)=4, 'abs(age)=2), Row((age * 2)=10, 'abs(age)=5)]
+ [Row((age * 2)=4, abs(age)=2), Row((age * 2)=10, abs(age)=5)]
"""
if len(expr) == 1 and isinstance(expr[0], list):
expr = expr[0]
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 962f676d40..6e1cbde423 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -382,7 +382,7 @@ def expr(str):
"""Parses the expression string into the column that it represents
>>> df.select(expr("length(name)")).collect()
- [Row('length(name)=5), Row('length(name)=3)]
+ [Row(length(name)=5), Row(length(name)=3)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.expr(str))
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index e224574bcb..9f5f7cfdf7 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1017,7 +1017,7 @@ class SQLTests(ReusedPySparkTestCase):
row = Row(a="length string", b=75)
df = self.sqlCtx.createDataFrame([row])
result = df.select(functions.expr("length(a)")).collect()[0].asDict()
- self.assertEqual(13, result["'length(a)"])
+ self.assertEqual(13, result["length(a)"])
def test_replace(self):
schema = StructType([
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
index 3f351b07b3..7c2b8a9407 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst
private[spark] trait CatalystConf {
def caseSensitiveAnalysis: Boolean
+
+ protected[spark] def specializeSingleDistinctAggPlanning: Boolean
}
/**
@@ -29,7 +31,13 @@ object EmptyConf extends CatalystConf {
override def caseSensitiveAnalysis: Boolean = {
throw new UnsupportedOperationException
}
+
+ protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = {
+ throw new UnsupportedOperationException
+ }
}
/** A CatalystConf that can be used for local testing. */
-case class SimpleCatalystConf(caseSensitiveAnalysis: Boolean) extends CatalystConf
+case class SimpleCatalystConf(caseSensitiveAnalysis: Boolean) extends CatalystConf {
+ protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = true
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index cd717c09f8..2a132d8b82 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -22,6 +22,7 @@ import scala.language.implicitConversions
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.DataTypeParser
@@ -272,7 +273,7 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser {
protected lazy val function: Parser[Expression] =
( ident <~ ("(" ~ "*" ~ ")") ^^ { case udfName =>
if (lexical.normalizeKeyword(udfName) == "count") {
- Count(Literal(1))
+ AggregateExpression(Count(Literal(1)), mode = Complete, isDistinct = false)
} else {
throw new AnalysisException(s"invalid expression $udfName(*)")
}
@@ -281,14 +282,14 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser {
{ case udfName ~ exprs => UnresolvedFunction(udfName, exprs, isDistinct = false) }
| ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^ { case udfName ~ exprs =>
lexical.normalizeKeyword(udfName) match {
- case "sum" => SumDistinct(exprs.head)
- case "count" => CountDistinct(exprs)
+ case "count" =>
+ aggregate.Count(exprs).toAggregateExpression(isDistinct = true)
case _ => UnresolvedFunction(udfName, exprs, isDistinct = true)
}
}
| APPROXIMATE ~> ident ~ ("(" ~ DISTINCT ~> expression <~ ")") ^^ { case udfName ~ exp =>
if (lexical.normalizeKeyword(udfName) == "count") {
- ApproxCountDistinct(exp)
+ AggregateExpression(new HyperLogLogPlusPlus(exp), mode = Complete, isDistinct = false)
} else {
throw new AnalysisException(s"invalid function approximate $udfName")
}
@@ -296,7 +297,10 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser {
| APPROXIMATE ~> "(" ~> unsignedFloat ~ ")" ~ ident ~ "(" ~ DISTINCT ~ expression <~ ")" ^^
{ case s ~ _ ~ udfName ~ _ ~ _ ~ exp =>
if (lexical.normalizeKeyword(udfName) == "count") {
- ApproxCountDistinct(exp, s.toDouble)
+ AggregateExpression(
+ HyperLogLogPlusPlus(exp, s.toDouble, 0, 0),
+ mode = Complete,
+ isDistinct = false)
} else {
throw new AnalysisException(s"invalid function approximate($s) $udfName")
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 899ee67352..b1e14390b7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.analysis
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2}
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
@@ -79,6 +79,7 @@ class Analyzer(
ExtractWindowExpressions ::
GlobalAggregates ::
ResolveAggregateFunctions ::
+ DistinctAggregationRewriter(conf) ::
HiveTypeCoercion.typeCoercionRules ++
extendedResolutionRules : _*),
Batch("Nondeterministic", Once,
@@ -525,21 +526,14 @@ class Analyzer(
case u @ UnresolvedFunction(name, children, isDistinct) =>
withPosition(u) {
registry.lookupFunction(name, children) match {
- // We get an aggregate function built based on AggregateFunction2 interface.
- // So, we wrap it in AggregateExpression2.
- case agg2: AggregateFunction2 => AggregateExpression2(agg2, Complete, isDistinct)
- // Currently, our old aggregate function interface supports SUM(DISTINCT ...)
- // and COUTN(DISTINCT ...).
- case sumDistinct: SumDistinct => sumDistinct
- case countDistinct: CountDistinct => countDistinct
- // DISTINCT is not meaningful with Max and Min.
- case max: Max if isDistinct => max
- case min: Min if isDistinct => min
- // For other aggregate functions, DISTINCT keyword is not supported for now.
- // Once we converted to the new code path, we will allow using DISTINCT keyword.
- case other: AggregateExpression1 if isDistinct =>
- failAnalysis(s"$name does not support DISTINCT keyword.")
- // If it does not have DISTINCT keyword, we will return it as is.
+ // DISTINCT is not meaningful for a Max or a Min.
+ case max: Max if isDistinct =>
+ AggregateExpression(max, Complete, isDistinct = false)
+ case min: Min if isDistinct =>
+ AggregateExpression(min, Complete, isDistinct = false)
+ // We get an aggregate function, we need to wrap it in an AggregateExpression.
+ case agg2: AggregateFunction => AggregateExpression(agg2, Complete, isDistinct)
+ // This function is not an aggregate function, just return the resolved one.
case other => other
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 98d6637c06..8322e9930c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, AggregateExpression}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._
@@ -108,7 +109,19 @@ trait CheckAnalysis {
case Aggregate(groupingExprs, aggregateExprs, child) =>
def checkValidAggregateExpression(expr: Expression): Unit = expr match {
- case _: AggregateExpression => // OK
+ case aggExpr: AggregateExpression =>
+ // TODO: Is it possible that the child of a agg function is another
+ // agg function?
+ aggExpr.aggregateFunction.children.foreach {
+ // This is just a sanity check, our analysis rule PullOutNondeterministic should
+ // already pull out those nondeterministic expressions and evaluate them in
+ // a Project node.
+ case child if !child.deterministic =>
+ failAnalysis(
+ s"nondeterministic expression ${expr.prettyString} should not " +
+ s"appear in the arguments of an aggregate function.")
+ case child => // OK
+ }
case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) =>
failAnalysis(
s"expression '${e.prettyString}' is neither present in the group by, " +
@@ -120,14 +133,26 @@ trait CheckAnalysis {
case e => e.children.foreach(checkValidAggregateExpression)
}
- def checkValidGroupingExprs(expr: Expression): Unit = expr.dataType match {
- case BinaryType =>
- failAnalysis(s"binary type expression ${expr.prettyString} cannot be used " +
- "in grouping expression")
- case m: MapType =>
- failAnalysis(s"map type expression ${expr.prettyString} cannot be used " +
- "in grouping expression")
- case _ => // OK
+ def checkValidGroupingExprs(expr: Expression): Unit = {
+ expr.dataType match {
+ case BinaryType =>
+ failAnalysis(s"binary type expression ${expr.prettyString} cannot be used " +
+ "in grouping expression")
+ case a: ArrayType =>
+ failAnalysis(s"array type expression ${expr.prettyString} cannot be used " +
+ "in grouping expression")
+ case m: MapType =>
+ failAnalysis(s"map type expression ${expr.prettyString} cannot be used " +
+ "in grouping expression")
+ case _ => // OK
+ }
+ if (!expr.deterministic) {
+ // This is just a sanity check, our analysis rule PullOutNondeterministic should
+ // already pull out those nondeterministic expressions and evaluate them in
+ // a Project node.
+ failAnalysis(s"nondeterministic expression ${expr.prettyString} should not " +
+ s"appear in grouping expression.")
+ }
}
aggregateExprs.foreach(checkValidAggregateExpression)
@@ -179,7 +204,8 @@ trait CheckAnalysis {
s"unresolved operator ${operator.simpleString}")
case o if o.expressions.exists(!_.deterministic) &&
- !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] =>
+ !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] & !o.isInstanceOf[Aggregate] =>
+ // The rule above is used to check Aggregate operator.
failAnalysis(
s"""nondeterministic expressions are only allowed in Project or Filter, found:
| ${o.expressions.map(_.prettyString).mkString(",")}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
index 9b22ce2619..397eff0568 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
@@ -15,215 +15,17 @@
* limitations under the License.
*/
-package org.apache.spark.sql.catalyst.expressions.aggregate
+package org.apache.spark.sql.catalyst.analysis
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst._
+import org.apache.spark.sql.catalyst.CatalystConf
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{Expand, Aggregate, LogicalPlan}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.IntegerType
/**
- * Utility functions used by the query planner to convert our plan to new aggregation code path.
- */
-object Utils {
-
- // Check if the DataType given cannot be part of a group by clause.
- private def isUnGroupable(dt: DataType): Boolean = dt match {
- case _: ArrayType | _: MapType => true
- case s: StructType => s.fields.exists(f => isUnGroupable(f.dataType))
- case _ => false
- }
-
- // Right now, we do not support complex types in the grouping key schema.
- private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean =
- !aggregate.groupingExpressions.exists(e => isUnGroupable(e.dataType))
-
- private def doConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
- case p: Aggregate if supportsGroupingKeySchema(p) =>
-
- val converted = MultipleDistinctRewriter.rewrite(p.transformExpressionsDown {
- case expressions.Average(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Average(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.Count(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Count(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.CountDistinct(children) =>
- val child = if (children.size > 1) {
- DropAnyNull(CreateStruct(children))
- } else {
- children.head
- }
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Count(child),
- mode = aggregate.Complete,
- isDistinct = true)
-
- case expressions.First(child, ignoreNulls) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.First(child, ignoreNulls),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.Kurtosis(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Kurtosis(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.Last(child, ignoreNulls) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Last(child, ignoreNulls),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.Max(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Max(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.Min(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Min(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.Skewness(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Skewness(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.StddevPop(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.StddevPop(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.StddevSamp(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.StddevSamp(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.Sum(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Sum(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.SumDistinct(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Sum(child),
- mode = aggregate.Complete,
- isDistinct = true)
-
- case expressions.Corr(left, right) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Corr(left, right),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.ApproxCountDistinct(child, rsd) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.HyperLogLogPlusPlus(child, rsd),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.VariancePop(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.VariancePop(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.VarianceSamp(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.VarianceSamp(child),
- mode = aggregate.Complete,
- isDistinct = false)
- })
-
- // Check if there is any expressions.AggregateExpression1 left.
- // If so, we cannot convert this plan.
- val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr =>
- // For every expressions, check if it contains AggregateExpression1.
- expr.find {
- case agg: expressions.AggregateExpression1 => true
- case other => false
- }.isDefined
- }
-
- // Check if there are multiple distinct columns.
- // TODO remove this.
- val aggregateExpressions = converted.aggregateExpressions.flatMap { expr =>
- expr.collect {
- case agg: AggregateExpression2 => agg
- }
- }.toSet.toSeq
- val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct)
- val hasMultipleDistinctColumnSets =
- if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
- true
- } else {
- false
- }
-
- if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None
-
- case other => None
- }
-
- def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = {
- // If the plan cannot be converted, we will do a final round check to see if the original
- // logical.Aggregate contains both AggregateExpression1 and AggregateExpression2. If so,
- // we need to throw an exception.
- val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr =>
- expr.collect {
- case agg: AggregateExpression2 => agg.aggregateFunction
- }
- }.distinct
- if (aggregateFunction2s.nonEmpty) {
- // For functions implemented based on the new interface, prepare a list of function names.
- val invalidFunctions = {
- if (aggregateFunction2s.length > 1) {
- s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " +
- s"and ${aggregateFunction2s.head.nodeName} are"
- } else {
- s"${aggregateFunction2s.head.nodeName} is"
- }
- }
- val errorMessage =
- s"${invalidFunctions} implemented based on the new Aggregate Function " +
- s"interface and it cannot be used with functions implemented based on " +
- s"the old Aggregate Function interface."
- throw new AnalysisException(errorMessage)
- }
- }
-
- def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
- case p: Aggregate =>
- val converted = doConvert(p)
- if (converted.isDefined) {
- converted
- } else {
- checkInvalidAggregateFunction2(p)
- None
- }
- case other => None
- }
-}
-
-/**
- * This rule rewrites an aggregate query with multiple distinct clauses into an expanded double
+ * This rule rewrites an aggregate query with distinct aggregations into an expanded double
* aggregation in which the regular aggregation expressions and every distinct clause is aggregated
* in a separate group. The results are then combined in a second aggregate.
*
@@ -298,9 +100,11 @@ object Utils {
* we could improve this in the current rule by applying more advanced expression cannocalization
* techniques.
*/
-object MultipleDistinctRewriter extends Rule[LogicalPlan] {
+case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case p if !p.resolved => p
+ // We need to wait until this Aggregate operator is resolved.
case a: Aggregate => rewrite(a)
case p => p
}
@@ -310,7 +114,7 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
// Collect all aggregate expressions.
val aggExpressions = a.aggregateExpressions.flatMap { e =>
e.collect {
- case ae: AggregateExpression2 => ae
+ case ae: AggregateExpression => ae
}
}
@@ -319,8 +123,15 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
.filter(_.isDistinct)
.groupBy(_.aggregateFunction.children.toSet)
- // Only continue to rewrite if there is more than one distinct group.
- if (distinctAggGroups.size > 1) {
+ val shouldRewrite = if (conf.specializeSingleDistinctAggPlanning) {
+ // When the flag is set to specialize single distinct agg planning,
+ // we will rely on our Aggregation strategy to handle queries with a single
+ // distinct column and this aggregate operator does have grouping expressions.
+ distinctAggGroups.size > 1 || (distinctAggGroups.size == 1 && a.groupingExpressions.isEmpty)
+ } else {
+ distinctAggGroups.size >= 1
+ }
+ if (shouldRewrite) {
// Create the attributes for the grouping id and the group by clause.
val gid = new AttributeReference("gid", IntegerType, false)()
val groupByMap = a.groupingExpressions.collect {
@@ -332,11 +143,11 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
// Functions used to modify aggregate functions and their inputs.
def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e))
def patchAggregateFunctionChildren(
- af: AggregateFunction2)(
- attrs: Expression => Expression): AggregateFunction2 = {
+ af: AggregateFunction)(
+ attrs: Expression => Expression): AggregateFunction = {
af.withNewChildren(af.children.map {
case afc => attrs(afc)
- }).asInstanceOf[AggregateFunction2]
+ }).asInstanceOf[AggregateFunction]
}
// Setup unique distinct aggregate children.
@@ -381,7 +192,7 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
val operator = Alias(e.copy(aggregateFunction = af), e.prettyString)()
// Select the result of the first aggregate in the last aggregate.
- val result = AggregateExpression2(
+ val result = AggregateExpression(
aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), Literal(true)),
mode = Complete,
isDistinct = false)
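To make the rule's documentation above concrete, a hypothetical query shape that triggers the rewrite is one whose aggregates are DISTINCT over different column sets (table and column names below are made up, and a Spark 1.6-era `SQLContext` named `sqlContext` with a registered table `data` is assumed):

```scala
// Two distinct column sets (cat1 and cat2) plus a regular SUM: after analysis,
// DistinctAggregationRewriter expands the input with one projection per distinct
// group plus a group id, aggregates per group, and then combines the partial
// results in a second Aggregate, as described in the scaladoc of the rule.
val rewrittenShape = sqlContext.sql(
  """SELECT key, COUNT(DISTINCT cat1), COUNT(DISTINCT cat2), SUM(value)
    |FROM data
    |GROUP BY key
  """.stripMargin)
```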
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index d4334d1628..dfa749d1af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -24,6 +24,7 @@ import scala.util.{Failure, Success, Try}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.util.StringKeyHashMap
@@ -177,6 +178,7 @@ object FunctionRegistry {
expression[ToRadians]("radians"),
// aggregate functions
+ expression[HyperLogLogPlusPlus]("approx_count_distinct"),
expression[Average]("avg"),
expression[Corr]("corr"),
expression[Count]("count"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 84e2b1366f..bf2bff0243 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
import javax.annotation.Nullable
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types._
@@ -295,14 +296,17 @@ object HiveTypeCoercion {
i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
- case SumDistinct(e @ StringType()) => Sum(Cast(e, DoubleType))
case Average(e @ StringType()) => Average(Cast(e, DoubleType))
case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType))
case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType))
- case VariancePop(e @ StringType()) => VariancePop(Cast(e, DoubleType))
- case VarianceSamp(e @ StringType()) => VarianceSamp(Cast(e, DoubleType))
- case Skewness(e @ StringType()) => Skewness(Cast(e, DoubleType))
- case Kurtosis(e @ StringType()) => Kurtosis(Cast(e, DoubleType))
+ case VariancePop(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
+ VariancePop(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
+ case VarianceSamp(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
+ VarianceSamp(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
+ case Skewness(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
+ Skewness(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
+ case Kurtosis(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
+ Kurtosis(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
}
}
@@ -562,12 +566,6 @@ object HiveTypeCoercion {
case Sum(e @ IntegralType()) if e.dataType != LongType => Sum(Cast(e, LongType))
case Sum(e @ FractionalType()) if e.dataType != DoubleType => Sum(Cast(e, DoubleType))
- case s @ SumDistinct(e @ DecimalType()) => s // Decimal is already the biggest.
- case SumDistinct(e @ IntegralType()) if e.dataType != LongType =>
- SumDistinct(Cast(e, LongType))
- case SumDistinct(e @ FractionalType()) if e.dataType != DoubleType =>
- SumDistinct(Cast(e, DoubleType))
-
case s @ Average(e @ DecimalType()) => s // Decimal is already the biggest.
case Average(e @ IntegralType()) if e.dataType != LongType =>
Average(Cast(e, LongType))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index eae17c86dd..6485bdfb30 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -141,6 +141,10 @@ case class UnresolvedFunction(
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override lazy val resolved = false
+ override def prettyString: String = {
+ s"${name}(${children.map(_.prettyString).mkString(",")})"
+ }
+
override def toString: String = s"'$name(${children.mkString(",")})"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index d8df66430a..af594c25c5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -23,6 +23,7 @@ import scala.language.implicitConversions
import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedExtractValue, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
import org.apache.spark.sql.types._
@@ -144,17 +145,18 @@ package object dsl {
}
}
- def sum(e: Expression): Expression = Sum(e)
- def sumDistinct(e: Expression): Expression = SumDistinct(e)
- def count(e: Expression): Expression = Count(e)
- def countDistinct(e: Expression*): Expression = CountDistinct(e)
+ def sum(e: Expression): Expression = Sum(e).toAggregateExpression()
+ def sumDistinct(e: Expression): Expression = Sum(e).toAggregateExpression(isDistinct = true)
+ def count(e: Expression): Expression = Count(e).toAggregateExpression()
+ def countDistinct(e: Expression*): Expression =
+ Count(e).toAggregateExpression(isDistinct = true)
def approxCountDistinct(e: Expression, rsd: Double = 0.05): Expression =
- ApproxCountDistinct(e, rsd)
- def avg(e: Expression): Expression = Average(e)
- def first(e: Expression): Expression = First(e)
- def last(e: Expression): Expression = Last(e)
- def min(e: Expression): Expression = Min(e)
- def max(e: Expression): Expression = Max(e)
+ HyperLogLogPlusPlus(e, rsd).toAggregateExpression()
+ def avg(e: Expression): Expression = Average(e).toAggregateExpression()
+ def first(e: Expression): Expression = new First(e).toAggregateExpression()
+ def last(e: Expression): Expression = new Last(e).toAggregateExpression()
+ def min(e: Expression): Expression = Min(e).toAggregateExpression()
+ def max(e: Expression): Expression = Max(e).toAggregateExpression()
def upper(e: Expression): Expression = Upper(e)
def lower(e: Expression): Expression = Lower(e)
def sqrt(e: Expression): Expression = Sqrt(e)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index c8c20ada5f..7f9e503470 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -17,8 +17,10 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
case class Average(child: Expression) extends DeclarativeAggregate {
@@ -32,36 +34,33 @@ case class Average(child: Expression) extends DeclarativeAggregate {
// Return data type.
override def dataType: DataType = resultType
- // Expected input data type.
- // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the
- // new version at planning time (after analysis phase). For now, NullType is added at here
- // to make it resolved when we have cases like `select avg(null)`.
- // We can use our analyzer to cast NullType to the default data type of the NumericType once
- // we remove the old aggregate functions. Then, we will not need NullType at here.
- override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))
+ override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType))
- private val resultType = child.dataType match {
+ override def checkInputDataTypes(): TypeCheckResult =
+ TypeUtils.checkForNumericExpr(child.dataType, "function average")
+
+ private lazy val resultType = child.dataType match {
case DecimalType.Fixed(p, s) =>
DecimalType.bounded(p + 4, s + 4)
case _ => DoubleType
}
- private val sumDataType = child.dataType match {
+ private lazy val sumDataType = child.dataType match {
case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s)
case _ => DoubleType
}
- private val sum = AttributeReference("sum", sumDataType)()
- private val count = AttributeReference("count", LongType)()
+ private lazy val sum = AttributeReference("sum", sumDataType)()
+ private lazy val count = AttributeReference("count", LongType)()
- override val aggBufferAttributes = sum :: count :: Nil
+ override lazy val aggBufferAttributes = sum :: count :: Nil
- override val initialValues = Seq(
+ override lazy val initialValues = Seq(
/* sum = */ Cast(Literal(0), sumDataType),
/* count = */ Literal(0L)
)
- override val updateExpressions = Seq(
+ override lazy val updateExpressions = Seq(
/* sum = */
Add(
sum,
@@ -69,13 +68,13 @@ case class Average(child: Expression) extends DeclarativeAggregate {
/* count = */ If(IsNull(child), count, count + 1L)
)
- override val mergeExpressions = Seq(
+ override lazy val mergeExpressions = Seq(
/* sum = */ sum.left + sum.right,
/* count = */ count.left + count.right
)
// If all input are nulls, count will be 0 and we will get null after the division.
- override val evaluateExpression = child.dataType match {
+ override lazy val evaluateExpression = child.dataType match {
case DecimalType.Fixed(p, s) =>
// increase the precision and scale to prevent precision loss
val dt = DecimalType.bounded(p + 14, s + 4)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
index ef08b025ff..984ce7f24d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
@@ -18,7 +18,9 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
/**
@@ -55,13 +57,10 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w
override def dataType: DataType = DoubleType
- // Expected input data type.
- // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the
- // new version at planning time (after analysis phase). For now, NullType is added at here
- // to make it resolved when we have cases like `select avg(null)`.
- // We can use our analyzer to cast NullType to the default data type of the NumericType once
- // we remove the old aggregate functions. Then, we will not need NullType at here.
- override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))
+ override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType))
+
+ override def checkInputDataTypes(): TypeCheckResult =
+ TypeUtils.checkForNumericExpr(child.dataType, s"function $prettyName")
override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
index 832338378f..00d7436b71 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
@@ -18,7 +18,9 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
/**
@@ -35,6 +37,9 @@ case class Corr(
inputAggBufferOffset: Int = 0)
extends ImperativeAggregate {
+ def this(left: Expression, right: Expression) =
+ this(left, right, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
+
override def children: Seq[Expression] = Seq(left, right)
override def nullable: Boolean = false
@@ -43,6 +48,16 @@ case class Corr(
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (left.dataType.isInstanceOf[DoubleType] && right.dataType.isInstanceOf[DoubleType]) {
+ TypeCheckResult.TypeCheckSuccess
+ } else {
+ TypeCheckResult.TypeCheckFailure(
+ s"corr requires that both arguments are double type, " +
+ s"not (${left.dataType}, ${right.dataType}).")
+ }
+ }
+
override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
override def inputAggBufferAttributes: Seq[AttributeReference] = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
index ec0c8b483a..09a1da9200 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
@@ -32,23 +32,39 @@ case class Count(child: Expression) extends DeclarativeAggregate {
// Expected input data type.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
- private val count = AttributeReference("count", LongType)()
+ private lazy val count = AttributeReference("count", LongType)()
- override val aggBufferAttributes = count :: Nil
+ override lazy val aggBufferAttributes = count :: Nil
- override val initialValues = Seq(
+ override lazy val initialValues = Seq(
/* count = */ Literal(0L)
)
- override val updateExpressions = Seq(
+ override lazy val updateExpressions = Seq(
/* count = */ If(IsNull(child), count, count + 1L)
)
- override val mergeExpressions = Seq(
+ override lazy val mergeExpressions = Seq(
/* count = */ count.left + count.right
)
- override val evaluateExpression = Cast(count, LongType)
+ override lazy val evaluateExpression = Cast(count, LongType)
override def defaultResult: Option[Literal] = Option(Literal(0L))
}
+
+object Count {
+ def apply(children: Seq[Expression]): Count = {
+ // This is used to deal with COUNT DISTINCT. When we have multiple
+ // children (COUNT(DISTINCT col1, col2, ...)), we wrap them in a STRUCT (i.e. a Row).
+ // Also, the semantic of COUNT(DISTINCT col1, col2, ...) is that if there is any
+ // null in the arguments, we will not count that row. So, we use DropAnyNull at here
+ // to return a null when any field of the created STRUCT is null.
+ val child = if (children.size > 1) {
+ DropAnyNull(CreateStruct(children))
+ } else {
+ children.head
+ }
+ Count(child)
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
index 9028143015..35f57426fe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
@@ -51,18 +51,18 @@ case class First(child: Expression, ignoreNullsExpr: Expression) extends Declara
// Expected input data type.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
- private val first = AttributeReference("first", child.dataType)()
+ private lazy val first = AttributeReference("first", child.dataType)()
- private val valueSet = AttributeReference("valueSet", BooleanType)()
+ private lazy val valueSet = AttributeReference("valueSet", BooleanType)()
- override val aggBufferAttributes: Seq[AttributeReference] = first :: valueSet :: Nil
+ override lazy val aggBufferAttributes: Seq[AttributeReference] = first :: valueSet :: Nil
- override val initialValues: Seq[Literal] = Seq(
+ override lazy val initialValues: Seq[Literal] = Seq(
/* first = */ Literal.create(null, child.dataType),
/* valueSet = */ Literal.create(false, BooleanType)
)
- override val updateExpressions: Seq[Expression] = {
+ override lazy val updateExpressions: Seq[Expression] = {
if (ignoreNulls) {
Seq(
/* first = */ If(Or(valueSet, IsNull(child)), first, child),
@@ -76,7 +76,7 @@ case class First(child: Expression, ignoreNullsExpr: Expression) extends Declara
}
}
- override val mergeExpressions: Seq[Expression] = {
+ override lazy val mergeExpressions: Seq[Expression] = {
// For first, we can just check if valueSet.left is set to true. If it is set
// to true, we use first.right. If not, we use first.right (even if valueSet.right is
// false, we are safe to do so because first.right will be null in this case).
@@ -86,7 +86,7 @@ case class First(child: Expression, ignoreNullsExpr: Expression) extends Declara
)
}
- override val evaluateExpression: AttributeReference = first
+ override lazy val evaluateExpression: AttributeReference = first
override def toString: String = s"first($child)${if (ignoreNulls) " ignore nulls"}"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
index 8d341ee630..8a95c541f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
@@ -22,6 +22,7 @@ import java.util
import com.clearspring.analytics.hash.MurmurHash
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
@@ -55,6 +56,22 @@ case class HyperLogLogPlusPlus(
extends ImperativeAggregate {
import HyperLogLogPlusPlus._
+ def this(child: Expression) = {
+ this(child = child, relativeSD = 0.05, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
+ }
+
+ def this(child: Expression, relativeSD: Expression) = {
+ this(
+ child = child,
+ relativeSD = relativeSD match {
+ case Literal(d: Double, DoubleType) => d
+ case _ =>
+ throw new AnalysisException("The second argument should be a double literal.")
+ },
+ mutableAggBufferOffset = 0,
+ inputAggBufferOffset = 0)
+ }
+
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
index 6da39e7143..bae78d9849 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
@@ -24,6 +24,8 @@ case class Kurtosis(child: Expression,
inputAggBufferOffset: Int = 0)
extends CentralMomentAgg(child) {
+ def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
+
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
index 8636bfe8d0..be7e12d7a2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
@@ -51,15 +51,15 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat
// Expected input data type.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
- private val last = AttributeReference("last", child.dataType)()
+ private lazy val last = AttributeReference("last", child.dataType)()
- override val aggBufferAttributes: Seq[AttributeReference] = last :: Nil
+ override lazy val aggBufferAttributes: Seq[AttributeReference] = last :: Nil
- override val initialValues: Seq[Literal] = Seq(
+ override lazy val initialValues: Seq[Literal] = Seq(
/* last = */ Literal.create(null, child.dataType)
)
- override val updateExpressions: Seq[Expression] = {
+ override lazy val updateExpressions: Seq[Expression] = {
if (ignoreNulls) {
Seq(
/* last = */ If(IsNull(child), last, child)
@@ -71,7 +71,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat
}
}
- override val mergeExpressions: Seq[Expression] = {
+ override lazy val mergeExpressions: Seq[Expression] = {
if (ignoreNulls) {
Seq(
/* last = */ If(IsNull(last.right), last.left, last.right)
@@ -83,7 +83,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat
}
}
- override val evaluateExpression: AttributeReference = last
+ override lazy val evaluateExpression: AttributeReference = last
override def toString: String = s"last($child)${if (ignoreNulls) " ignore nulls"}"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala
index b9d75ad452..61cae44cd0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala
@@ -17,7 +17,9 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
case class Max(child: Expression) extends DeclarativeAggregate {
@@ -32,24 +34,27 @@ case class Max(child: Expression) extends DeclarativeAggregate {
// Expected input data type.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
- private val max = AttributeReference("max", child.dataType)()
+ override def checkInputDataTypes(): TypeCheckResult =
+ TypeUtils.checkForOrderingExpr(child.dataType, "function max")
- override val aggBufferAttributes: Seq[AttributeReference] = max :: Nil
+ private lazy val max = AttributeReference("max", child.dataType)()
- override val initialValues: Seq[Literal] = Seq(
+ override lazy val aggBufferAttributes: Seq[AttributeReference] = max :: Nil
+
+ override lazy val initialValues: Seq[Literal] = Seq(
/* max = */ Literal.create(null, child.dataType)
)
- override val updateExpressions: Seq[Expression] = Seq(
+ override lazy val updateExpressions: Seq[Expression] = Seq(
/* max = */ If(IsNull(child), max, If(IsNull(max), child, Greatest(Seq(max, child))))
)
- override val mergeExpressions: Seq[Expression] = {
+ override lazy val mergeExpressions: Seq[Expression] = {
val greatest = Greatest(Seq(max.left, max.right))
Seq(
/* max = */ If(IsNull(max.right), max.left, If(IsNull(max.left), max.right, greatest))
)
}
- override val evaluateExpression: AttributeReference = max
+ override lazy val evaluateExpression: AttributeReference = max
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala
index 5ed9cd348d..242456d9e2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala
@@ -17,7 +17,9 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
@@ -33,24 +35,27 @@ case class Min(child: Expression) extends DeclarativeAggregate {
// Expected input data type.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
- private val min = AttributeReference("min", child.dataType)()
+ override def checkInputDataTypes(): TypeCheckResult =
+ TypeUtils.checkForOrderingExpr(child.dataType, "function min")
- override val aggBufferAttributes: Seq[AttributeReference] = min :: Nil
+ private lazy val min = AttributeReference("min", child.dataType)()
- override val initialValues: Seq[Expression] = Seq(
+ override lazy val aggBufferAttributes: Seq[AttributeReference] = min :: Nil
+
+ override lazy val initialValues: Seq[Expression] = Seq(
/* min = */ Literal.create(null, child.dataType)
)
- override val updateExpressions: Seq[Expression] = Seq(
+ override lazy val updateExpressions: Seq[Expression] = Seq(
/* min = */ If(IsNull(child), min, If(IsNull(min), child, Least(Seq(min, child))))
)
- override val mergeExpressions: Seq[Expression] = {
+ override lazy val mergeExpressions: Seq[Expression] = {
val least = Least(Seq(min.left, min.right))
Seq(
/* min = */ If(IsNull(min.right), min.left, If(IsNull(min.left), min.right, least))
)
}
- override val evaluateExpression: AttributeReference = min
+ override lazy val evaluateExpression: AttributeReference = min
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
index 0def7ddfd9..c593074fa2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
@@ -24,6 +24,8 @@ case class Skewness(child: Expression,
inputAggBufferOffset: Int = 0)
extends CentralMomentAgg(child) {
+ def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
+
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
index 3f47ffe13c..5b9eb7ae02 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
@@ -17,8 +17,10 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
@@ -48,29 +50,26 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate {
override def dataType: DataType = resultType
- // Expected input data type.
- // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the
- // new version at planning time (after analysis phase). For now, NullType is added at here
- // to make it resolved when we have cases like `select stddev(null)`.
- // We can use our analyzer to cast NullType to the default data type of the NumericType once
- // we remove the old aggregate functions. Then, we will not need NullType at here.
- override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))
+ override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType))
- private val resultType = DoubleType
+ override def checkInputDataTypes(): TypeCheckResult =
+ TypeUtils.checkForNumericExpr(child.dataType, "function stddev")
- private val count = AttributeReference("count", resultType)()
- private val avg = AttributeReference("avg", resultType)()
- private val mk = AttributeReference("mk", resultType)()
+ private lazy val resultType = DoubleType
- override val aggBufferAttributes = count :: avg :: mk :: Nil
+ private lazy val count = AttributeReference("count", resultType)()
+ private lazy val avg = AttributeReference("avg", resultType)()
+ private lazy val mk = AttributeReference("mk", resultType)()
- override val initialValues: Seq[Expression] = Seq(
+ override lazy val aggBufferAttributes = count :: avg :: mk :: Nil
+
+ override lazy val initialValues: Seq[Expression] = Seq(
/* count = */ Cast(Literal(0), resultType),
/* avg = */ Cast(Literal(0), resultType),
/* mk = */ Cast(Literal(0), resultType)
)
- override val updateExpressions: Seq[Expression] = {
+ override lazy val updateExpressions: Seq[Expression] = {
val value = Cast(child, resultType)
val newCount = count + Cast(Literal(1), resultType)
@@ -89,7 +88,7 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate {
)
}
- override val mergeExpressions: Seq[Expression] = {
+ override lazy val mergeExpressions: Seq[Expression] = {
// count merge
val newCount = count.left + count.right
@@ -114,7 +113,7 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate {
)
}
- override val evaluateExpression: Expression = {
+ override lazy val evaluateExpression: Expression = {
// when count == 0, return null
// when count == 1, return 0
// when count >1
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index 7f8adbc56a..c005ec9657 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -17,7 +17,9 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
case class Sum(child: Expression) extends DeclarativeAggregate {
@@ -29,16 +31,13 @@ case class Sum(child: Expression) extends DeclarativeAggregate {
// Return data type.
override def dataType: DataType = resultType
- // Expected input data type.
- // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the
- // new version at planning time (after analysis phase). For now, NullType is added at here
- // to make it resolved when we have cases like `select sum(null)`.
- // We can use our analyzer to cast NullType to the default data type of the NumericType once
- // we remove the old aggregate functions. Then, we will not need NullType at here.
override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType))
- private val resultType = child.dataType match {
+ override def checkInputDataTypes(): TypeCheckResult =
+ TypeUtils.checkForNumericExpr(child.dataType, "function sum")
+
+ private lazy val resultType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
DecimalType.bounded(precision + 10, scale)
// TODO: Remove this line once we remove the NullType from inputTypes.
@@ -46,24 +45,24 @@ case class Sum(child: Expression) extends DeclarativeAggregate {
case _ => child.dataType
}
- private val sumDataType = resultType
+ private lazy val sumDataType = resultType
- private val sum = AttributeReference("sum", sumDataType)()
+ private lazy val sum = AttributeReference("sum", sumDataType)()
- private val zero = Cast(Literal(0), sumDataType)
+ private lazy val zero = Cast(Literal(0), sumDataType)
- override val aggBufferAttributes = sum :: Nil
+ override lazy val aggBufferAttributes = sum :: Nil
- override val initialValues: Seq[Expression] = Seq(
+ override lazy val initialValues: Seq[Expression] = Seq(
/* sum = */ Literal.create(null, sumDataType)
)
- override val updateExpressions: Seq[Expression] = Seq(
+ override lazy val updateExpressions: Seq[Expression] = Seq(
/* sum = */
Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum))
)
- override val mergeExpressions: Seq[Expression] = {
+ override lazy val mergeExpressions: Seq[Expression] = {
val add = Add(Coalesce(Seq(sum.left, zero)), Cast(sum.right, sumDataType))
Seq(
/* sum = */
@@ -71,5 +70,5 @@ case class Sum(child: Expression) extends DeclarativeAggregate {
)
}
- override val evaluateExpression: Expression = Cast(sum, resultType)
+ override lazy val evaluateExpression: Expression = Cast(sum, resultType)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala
index ec63534e52..ede2da2805 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala
@@ -24,6 +24,8 @@ case class VarianceSamp(child: Expression,
inputAggBufferOffset: Int = 0)
extends CentralMomentAgg(child) {
+ def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
+
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
@@ -42,11 +44,14 @@ case class VarianceSamp(child: Expression,
}
}
-case class VariancePop(child: Expression,
+case class VariancePop(
+ child: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends CentralMomentAgg(child) {
+ def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
+
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
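The new single-argument constructors on Skewness, VarianceSamp, and VariancePop let callers build these imperative aggregates from the child expression alone, leaving both buffer offsets at their defaults. A small sketch of what the auxiliary constructor buys (priceCol is a placeholder for an already-resolved numeric column expression):

  val viaAuxiliary = new VariancePop(priceCol)
  // ...which is equivalent to spelling out the primary constructor's defaults:
  val viaPrimary = VariancePop(priceCol, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)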
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 5c5b3d1ccd..3b441de34a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -17,23 +17,24 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
-/** The mode of an [[AggregateFunction2]]. */
+/** The mode of an [[AggregateFunction]]. */
private[sql] sealed trait AggregateMode
/**
- * An [[AggregateFunction2]] with [[Partial]] mode is used for partial aggregation.
+ * An [[AggregateFunction]] with [[Partial]] mode is used for partial aggregation.
* This function updates the given aggregation buffer with the original input of this
* function. When it has processed all input rows, the aggregation buffer is returned.
*/
private[sql] case object Partial extends AggregateMode
/**
- * An [[AggregateFunction2]] with [[PartialMerge]] mode is used to merge aggregation buffers
+ * An [[AggregateFunction]] with [[PartialMerge]] mode is used to merge aggregation buffers
* containing intermediate results for this function.
* This function updates the given aggregation buffer by merging multiple aggregation buffers.
* When it has processed all input rows, the aggregation buffer is returned.
@@ -41,7 +42,7 @@ private[sql] case object Partial extends AggregateMode
private[sql] case object PartialMerge extends AggregateMode
/**
- * An [[AggregateFunction2]] with [[Final]] mode is used to merge aggregation buffers
+ * An [[AggregateFunction]] with [[Final]] mode is used to merge aggregation buffers
* containing intermediate results for this function and then generate final result.
* This function updates the given aggregation buffer by merging multiple aggregation buffers.
* When it has processed all input rows, the final result of this function is returned.
@@ -49,7 +50,7 @@ private[sql] case object PartialMerge extends AggregateMode
private[sql] case object Final extends AggregateMode
/**
- * An [[AggregateFunction2]] with [[Complete]] mode is used to evaluate this function directly
+ * An [[AggregateFunction]] with [[Complete]] mode is used to evaluate this function directly
* from original input rows without any partial aggregation.
* This function updates the given aggregation buffer with the original input of this
* function. When it has processed all input rows, the final result of this function is returned.
@@ -67,13 +68,15 @@ private[sql] case object NoOp extends Expression with Unevaluable {
}
/**
- * A container for an [[AggregateFunction2]] with its [[AggregateMode]] and a field
+ * A container for an [[AggregateFunction]] with its [[AggregateMode]] and a field
* (`isDistinct`) indicating if DISTINCT keyword is specified for this function.
*/
-private[sql] case class AggregateExpression2(
- aggregateFunction: AggregateFunction2,
+private[sql] case class AggregateExpression(
+ aggregateFunction: AggregateFunction,
mode: AggregateMode,
- isDistinct: Boolean) extends AggregateExpression {
+ isDistinct: Boolean)
+ extends Expression
+ with Unevaluable {
override def children: Seq[Expression] = aggregateFunction :: Nil
override def dataType: DataType = aggregateFunction.dataType
@@ -89,6 +92,8 @@ private[sql] case class AggregateExpression2(
AttributeSet(childReferences)
}
+ override def prettyString: String = aggregateFunction.prettyString
+
override def toString: String = s"(${aggregateFunction},mode=$mode,isDistinct=$isDistinct)"
}
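The renamed AggregateExpression remains the unit the planner schedules in different modes. A purely illustrative pairing of modes for a distributed count, written out by hand only to show how the pieces fit together (the actual split into partial and final stages is produced by the planner, not constructed like this):

  val count = Count(Literal(1))
  // Map side: update aggregation buffers from the original input rows.
  val partialCount = AggregateExpression(count, Partial, isDistinct = false)
  // Reduce side: merge the buffers produced by the partial stage and emit the final value.
  val finalCount = AggregateExpression(count, Final, isDistinct = false)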
@@ -106,10 +111,10 @@ private[sql] case class AggregateExpression2(
* combined aggregation buffer which concatenates the aggregation buffers of the individual
* aggregate functions.
*
- * Code which accepts [[AggregateFunction2]] instances should be prepared to handle both types of
+ * Code which accepts [[AggregateFunction]] instances should be prepared to handle both types of
* aggregate functions.
*/
-sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInputTypes {
+sealed abstract class AggregateFunction extends Expression with ImplicitCastInputTypes {
/** An aggregate function is not foldable. */
final override def foldable: Boolean = false
@@ -141,6 +146,27 @@ sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInp
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
+
+ /**
+ * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] because
+ * [[AggregateExpression]] is the container of an [[AggregateFunction]], aggregation mode,
+ * and the flag indicating whether this aggregation is a distinct aggregation.
+ * An [[AggregateFunction]] should not be used without being wrapped in
+ * an [[AggregateExpression]].
+ */
+ def toAggregateExpression(): AggregateExpression = toAggregateExpression(isDistinct = false)
+
+ /**
+ * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] and sets the isDistinct
+ * field of the [[AggregateExpression]] to the given value, because
+ * [[AggregateExpression]] is the container of an [[AggregateFunction]], aggregation mode,
+ * and the flag indicating whether this aggregation is a distinct aggregation.
+ * An [[AggregateFunction]] should not be used without being wrapped in
+ * an [[AggregateExpression]].
+ */
+ def toAggregateExpression(isDistinct: Boolean): AggregateExpression = {
+ AggregateExpression(aggregateFunction = this, mode = Complete, isDistinct = isDistinct)
+ }
}
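A short usage sketch of the wrapper methods above (the column expressions are placeholders): the no-argument toAggregateExpression() yields a Complete-mode, non-distinct AggregateExpression, and the overload flips the distinct flag.

  val totalSales    = Sum(salesCol).toAggregateExpression()                    // mode = Complete, isDistinct = false
  val distinctUsers = Count(userCol).toAggregateExpression(isDistinct = true)  // mode = Complete, isDistinct = true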
/**
@@ -161,7 +187,7 @@ sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInp
* `inputAggBufferOffset`, but not on the correctness of the attribute ids in `aggBufferAttributes`
* and `inputAggBufferAttributes`.
*/
-abstract class ImperativeAggregate extends AggregateFunction2 {
+abstract class ImperativeAggregate extends AggregateFunction {
/**
* The offset of this function's first buffer value in the underlying shared mutable aggregation
@@ -258,9 +284,14 @@ abstract class ImperativeAggregate extends AggregateFunction2 {
* `bufferAttributes`, defining attributes for the fields of the mutable aggregation buffer. You
* can then use these attributes when defining `updateExpressions`, `mergeExpressions`, and
* `evaluateExpressions`.
+ *
+ * Please note that children of an aggregate function can be unresolved (it will happen when
+ * we create the function through the DataFrame API). So, if there are any fields in
+ * the implementing class that need to access fields of its children, please make
+ * those fields `lazy val`s.
*/
abstract class DeclarativeAggregate
- extends AggregateFunction2
+ extends AggregateFunction
with Serializable
with Unevaluable {
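The note above is why the Max, Min, Stddev, and Sum definitions earlier in this patch switched their buffer-related members to lazy vals. A hedged sketch of the scenario (the column name is made up): constructing a declarative aggregate over an unresolved column must not force child.dataType.

  import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute

  // Built by the DataFrame API before analysis: the child is still unresolved.
  val maxPrice = Max(UnresolvedAttribute("price"))
  // Safe to construct and wrap, because aggBufferAttributes and friends are lazy vals;
  // forcing maxPrice.aggBufferAttributes here would fail, since "price" has no dataType yet.
  val wrapped = maxPrice.toAggregateExpression()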
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
deleted file mode 100644
index 3dcf7915d7..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ /dev/null
@@ -1,1073 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import com.clearspring.analytics.stream.cardinality.HyperLogLog
-
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
-import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData, TypeUtils}
-import org.apache.spark.sql.types._
-import org.apache.spark.util.collection.OpenHashSet
-
-
-trait AggregateExpression extends Expression with Unevaluable
-
-trait AggregateExpression1 extends AggregateExpression {
-
- /**
- * Aggregate expressions should not be foldable.
- */
- override def foldable: Boolean = false
-
- /**
- * Creates a new instance that can be used to compute this aggregate expression for a group
- * of input rows/
- */
- def newInstance(): AggregateFunction1
-}
-
-/**
- * Represents an aggregation that has been rewritten to be performed in two steps.
- *
- * @param finalEvaluation an aggregate expression that evaluates to same final result as the
- * original aggregation.
- * @param partialEvaluations A sequence of [[NamedExpression]]s that can be computed on partial
- * data sets and are required to compute the `finalEvaluation`.
- */
-case class SplitEvaluation(
- finalEvaluation: Expression,
- partialEvaluations: Seq[NamedExpression])
-
-/**
- * An [[AggregateExpression1]] that can be partially computed without seeing all relevant tuples.
- * These partial evaluations can then be combined to compute the actual answer.
- */
-trait PartialAggregate1 extends AggregateExpression1 {
-
- /**
- * Returns a [[SplitEvaluation]] that computes this aggregation using partial aggregation.
- */
- def asPartial: SplitEvaluation
-}
-
-/**
- * A specific implementation of an aggregate function. Used to wrap a generic
- * [[AggregateExpression1]] with an algorithm that will be used to compute one specific result.
- */
-abstract class AggregateFunction1 extends LeafExpression with Serializable {
-
- /** Base should return the generic aggregate expression that this function is computing */
- val base: AggregateExpression1
-
- override def nullable: Boolean = base.nullable
- override def dataType: DataType = base.dataType
-
- def update(input: InternalRow): Unit
-
- override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- throw new UnsupportedOperationException(
- "AggregateFunction1 should not be used for generated aggregates")
- }
-}
-
-case class Min(child: Expression) extends UnaryExpression with PartialAggregate1 {
-
- override def nullable: Boolean = true
- override def dataType: DataType = child.dataType
-
- override def asPartial: SplitEvaluation = {
- val partialMin = Alias(Min(child), "PartialMin")()
- SplitEvaluation(Min(partialMin.toAttribute), partialMin :: Nil)
- }
-
- override def newInstance(): MinFunction = new MinFunction(child, this)
-
- override def checkInputDataTypes(): TypeCheckResult =
- TypeUtils.checkForOrderingExpr(child.dataType, "function min")
-}
-
-case class MinFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
- def this() = this(null, null) // Required for serialization.
-
- val currentMin: MutableLiteral = MutableLiteral(null, expr.dataType)
- val cmp = GreaterThan(currentMin, expr)
-
- override def update(input: InternalRow): Unit = {
- if (currentMin.value == null) {
- currentMin.value = expr.eval(input)
- } else if (cmp.eval(input) == true) {
- currentMin.value = expr.eval(input)
- }
- }
-
- override def eval(input: InternalRow): Any = currentMin.value
-}
-
-case class Max(child: Expression) extends UnaryExpression with PartialAggregate1 {
-
- override def nullable: Boolean = true
- override def dataType: DataType = child.dataType
-
- override def asPartial: SplitEvaluation = {
- val partialMax = Alias(Max(child), "PartialMax")()
- SplitEvaluation(Max(partialMax.toAttribute), partialMax :: Nil)
- }
-
- override def newInstance(): MaxFunction = new MaxFunction(child, this)
-
- override def checkInputDataTypes(): TypeCheckResult =
- TypeUtils.checkForOrderingExpr(child.dataType, "function max")
-}
-
-case class MaxFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
- def this() = this(null, null) // Required for serialization.
-
- val currentMax: MutableLiteral = MutableLiteral(null, expr.dataType)
- val cmp = LessThan(currentMax, expr)
-
- override def update(input: InternalRow): Unit = {
- if (currentMax.value == null) {
- currentMax.value = expr.eval(input)
- } else if (cmp.eval(input) == true) {
- currentMax.value = expr.eval(input)
- }
- }
-
- override def eval(input: InternalRow): Any = currentMax.value
-}
-
-case class Count(child: Expression) extends UnaryExpression with PartialAggregate1 {
-
- override def nullable: Boolean = false
- override def dataType: LongType.type = LongType
-
- override def asPartial: SplitEvaluation = {
- val partialCount = Alias(Count(child), "PartialCount")()
- SplitEvaluation(Coalesce(Seq(Sum(partialCount.toAttribute), Literal(0L))), partialCount :: Nil)
- }
-
- override def newInstance(): CountFunction = new CountFunction(child, this)
-}
-
-case class CountFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
- def this() = this(null, null) // Required for serialization.
-
- var count: Long = _
-
- override def update(input: InternalRow): Unit = {
- val evaluatedExpr = expr.eval(input)
- if (evaluatedExpr != null) {
- count += 1L
- }
- }
-
- override def eval(input: InternalRow): Any = count
-}
-
-case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate1 {
- def this() = this(null)
-
- override def children: Seq[Expression] = expressions
-
- override def nullable: Boolean = false
- override def dataType: DataType = LongType
- override def toString: String = s"COUNT(DISTINCT ${expressions.mkString(",")})"
- override def newInstance(): CountDistinctFunction = new CountDistinctFunction(expressions, this)
-
- override def asPartial: SplitEvaluation = {
- val partialSet = Alias(CollectHashSet(expressions), "partialSets")()
- SplitEvaluation(
- CombineSetsAndCount(partialSet.toAttribute),
- partialSet :: Nil)
- }
-}
-
-case class CountDistinctFunction(
- @transient expr: Seq[Expression],
- @transient base: AggregateExpression1)
- extends AggregateFunction1 {
-
- def this() = this(null, null) // Required for serialization.
-
- val seen = new OpenHashSet[Any]()
-
- @transient
- val distinctValue = new InterpretedProjection(expr)
-
- override def update(input: InternalRow): Unit = {
- val evaluatedExpr = distinctValue(input)
- if (!evaluatedExpr.anyNull) {
- seen.add(evaluatedExpr)
- }
- }
-
- override def eval(input: InternalRow): Any = seen.size.toLong
-}
-
-case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression1 {
- def this() = this(null)
-
- override def children: Seq[Expression] = expressions
- override def nullable: Boolean = false
- override def dataType: OpenHashSetUDT = new OpenHashSetUDT(expressions.head.dataType)
- override def toString: String = s"AddToHashSet(${expressions.mkString(",")})"
- override def newInstance(): CollectHashSetFunction =
- new CollectHashSetFunction(expressions, this)
-}
-
-case class CollectHashSetFunction(
- @transient expr: Seq[Expression],
- @transient base: AggregateExpression1)
- extends AggregateFunction1 {
-
- def this() = this(null, null) // Required for serialization.
-
- val seen = new OpenHashSet[Any]()
-
- @transient
- val distinctValue = new InterpretedProjection(expr)
-
- override def update(input: InternalRow): Unit = {
- val evaluatedExpr = distinctValue(input)
- if (!evaluatedExpr.anyNull) {
- seen.add(evaluatedExpr)
- }
- }
-
- override def eval(input: InternalRow): Any = {
- seen
- }
-}
-
-case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression1 {
- def this() = this(null)
-
- override def children: Seq[Expression] = inputSet :: Nil
- override def nullable: Boolean = false
- override def dataType: DataType = LongType
- override def toString: String = s"CombineAndCount($inputSet)"
- override def newInstance(): CombineSetsAndCountFunction = {
- new CombineSetsAndCountFunction(inputSet, this)
- }
-}
-
-case class CombineSetsAndCountFunction(
- @transient inputSet: Expression,
- @transient base: AggregateExpression1)
- extends AggregateFunction1 {
-
- def this() = this(null, null) // Required for serialization.
-
- val seen = new OpenHashSet[Any]()
-
- override def update(input: InternalRow): Unit = {
- val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]]
- val inputIterator = inputSetEval.iterator
- while (inputIterator.hasNext) {
- seen.add(inputIterator.next)
- }
- }
-
- override def eval(input: InternalRow): Any = seen.size.toLong
-}
-
-/** The data type of ApproxCountDistinctPartition since its output is a HyperLogLog object. */
-private[sql] case object HyperLogLogUDT extends UserDefinedType[HyperLogLog] {
-
- override def sqlType: DataType = BinaryType
-
- /** Since we are using HyperLogLog internally, usually it will not be called. */
- override def serialize(obj: Any): Array[Byte] =
- obj.asInstanceOf[HyperLogLog].getBytes
-
-
- /** Since we are using HyperLogLog internally, usually it will not be called. */
- override def deserialize(datum: Any): HyperLogLog =
- HyperLogLog.Builder.build(datum.asInstanceOf[Array[Byte]])
-
- override def userClass: Class[HyperLogLog] = classOf[HyperLogLog]
-}
-
-case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
- extends UnaryExpression with AggregateExpression1 {
-
- override def nullable: Boolean = false
- override def dataType: DataType = HyperLogLogUDT
- override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)"
- override def newInstance(): ApproxCountDistinctPartitionFunction = {
- new ApproxCountDistinctPartitionFunction(child, this, relativeSD)
- }
-}
-
-case class ApproxCountDistinctPartitionFunction(
- expr: Expression,
- base: AggregateExpression1,
- relativeSD: Double)
- extends AggregateFunction1 {
- def this() = this(null, null, 0) // Required for serialization.
-
- private val hyperLogLog = new HyperLogLog(relativeSD)
-
- override def update(input: InternalRow): Unit = {
- val evaluatedExpr = expr.eval(input)
- if (evaluatedExpr != null) {
- hyperLogLog.offer(evaluatedExpr)
- }
- }
-
- override def eval(input: InternalRow): Any = hyperLogLog
-}
-
-case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
- extends UnaryExpression with AggregateExpression1 {
-
- override def nullable: Boolean = false
- override def dataType: LongType.type = LongType
- override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)"
- override def newInstance(): ApproxCountDistinctMergeFunction = {
- new ApproxCountDistinctMergeFunction(child, this, relativeSD)
- }
-}
-
-case class ApproxCountDistinctMergeFunction(
- expr: Expression,
- base: AggregateExpression1,
- relativeSD: Double)
- extends AggregateFunction1 {
- def this() = this(null, null, 0) // Required for serialization.
-
- private val hyperLogLog = new HyperLogLog(relativeSD)
-
- override def update(input: InternalRow): Unit = {
- val evaluatedExpr = expr.eval(input)
- hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog])
- }
-
- override def eval(input: InternalRow): Any = hyperLogLog.cardinality()
-}
-
-case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
- extends UnaryExpression with PartialAggregate1 {
-
- override def nullable: Boolean = false
- override def dataType: LongType.type = LongType
- override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)"
-
- override def asPartial: SplitEvaluation = {
- val partialCount =
- Alias(ApproxCountDistinctPartition(child, relativeSD), "PartialApproxCountDistinct")()
-
- SplitEvaluation(
- ApproxCountDistinctMerge(partialCount.toAttribute, relativeSD),
- partialCount :: Nil)
- }
-
- override def newInstance(): CountDistinctFunction = new CountDistinctFunction(child :: Nil, this)
-}
-
-case class Average(child: Expression) extends UnaryExpression with PartialAggregate1 {
-
- override def prettyName: String = "avg"
-
- override def nullable: Boolean = true
-
- override def dataType: DataType = child.dataType match {
- case DecimalType.Fixed(precision, scale) =>
- // Add 4 digits after decimal point, like Hive
- DecimalType.bounded(precision + 4, scale + 4)
- case _ =>
- DoubleType
- }
-
- override def asPartial: SplitEvaluation = {
- child.dataType match {
- case DecimalType.Fixed(precision, scale) =>
- val partialSum = Alias(Sum(child), "PartialSum")()
- val partialCount = Alias(Count(child), "PartialCount")()
-
- // partialSum already increase the precision by 10
- val castedSum = Cast(Sum(partialSum.toAttribute), partialSum.dataType)
- val castedCount = Cast(Sum(partialCount.toAttribute), partialSum.dataType)
- SplitEvaluation(
- Cast(Divide(castedSum, castedCount), dataType),
- partialCount :: partialSum :: Nil)
-
- case _ =>
- val partialSum = Alias(Sum(child), "PartialSum")()
- val partialCount = Alias(Count(child), "PartialCount")()
-
- val castedSum = Cast(Sum(partialSum.toAttribute), dataType)
- val castedCount = Cast(Sum(partialCount.toAttribute), dataType)
- SplitEvaluation(
- Divide(castedSum, castedCount),
- partialCount :: partialSum :: Nil)
- }
- }
-
- override def newInstance(): AverageFunction = new AverageFunction(child, this)
-
- override def checkInputDataTypes(): TypeCheckResult =
- TypeUtils.checkForNumericExpr(child.dataType, "function average")
-}
-
-case class AverageFunction(expr: Expression, base: AggregateExpression1)
- extends AggregateFunction1 {
-
- def this() = this(null, null) // Required for serialization.
-
- private val calcType =
- expr.dataType match {
- case DecimalType.Fixed(precision, scale) =>
- DecimalType.bounded(precision + 10, scale)
- case _ =>
- expr.dataType
- }
-
- private val zero = Cast(Literal(0), calcType)
-
- private var count: Long = _
- private val sum = MutableLiteral(zero.eval(null), calcType)
-
- private def addFunction(value: Any) = Add(sum,
- Cast(Literal.create(value, expr.dataType), calcType))
-
- override def eval(input: InternalRow): Any = {
- if (count == 0L) {
- null
- } else {
- expr.dataType match {
- case DecimalType.Fixed(precision, scale) =>
- val dt = DecimalType.bounded(precision + 14, scale + 4)
- Cast(Divide(Cast(sum, dt), Cast(Literal(count), dt)), dataType).eval(null)
- case _ =>
- Divide(
- Cast(sum, dataType),
- Cast(Literal(count), dataType)).eval(null)
- }
- }
- }
-
- override def update(input: InternalRow): Unit = {
- val evaluatedExpr = expr.eval(input)
- if (evaluatedExpr != null) {
- count += 1
- sum.update(addFunction(evaluatedExpr), input)
- }
- }
-}
-
-case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1 {
-
- override def nullable: Boolean = true
-
- override def dataType: DataType = child.dataType match {
- case DecimalType.Fixed(precision, scale) =>
- // Add 10 digits left of decimal point, like Hive
- DecimalType.bounded(precision + 10, scale)
- case _ =>
- child.dataType
- }
-
- override def asPartial: SplitEvaluation = {
- child.dataType match {
- case DecimalType.Fixed(_, _) =>
- val partialSum = Alias(Sum(child), "PartialSum")()
- SplitEvaluation(
- Cast(Sum(partialSum.toAttribute), dataType),
- partialSum :: Nil)
-
- case _ =>
- val partialSum = Alias(Sum(child), "PartialSum")()
- SplitEvaluation(
- Sum(partialSum.toAttribute),
- partialSum :: Nil)
- }
- }
-
- override def newInstance(): SumFunction = new SumFunction(child, this)
-
- override def checkInputDataTypes(): TypeCheckResult =
- TypeUtils.checkForNumericExpr(child.dataType, "function sum")
-}
-
-case class SumFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
- def this() = this(null, null) // Required for serialization.
-
- private val calcType =
- expr.dataType match {
- case DecimalType.Fixed(precision, scale) =>
- DecimalType.bounded(precision + 10, scale)
- case _ =>
- expr.dataType
- }
-
- private val zero = Cast(Literal(0), calcType)
-
- private val sum = MutableLiteral(null, calcType)
-
- private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum))
-
- override def update(input: InternalRow): Unit = {
- sum.update(addFunction, input)
- }
-
- override def eval(input: InternalRow): Any = {
- expr.dataType match {
- case DecimalType.Fixed(_, _) =>
- Cast(sum, dataType).eval(null)
- case _ => sum.eval(null)
- }
- }
-}
-
-case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate1 {
-
- def this() = this(null)
- override def nullable: Boolean = true
- override def dataType: DataType = child.dataType match {
- case DecimalType.Fixed(precision, scale) =>
- // Add 10 digits left of decimal point, like Hive
- DecimalType.bounded(precision + 10, scale)
- case _ =>
- child.dataType
- }
- override def toString: String = s"sum(distinct $child)"
- override def newInstance(): SumDistinctFunction = new SumDistinctFunction(child, this)
-
- override def asPartial: SplitEvaluation = {
- val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")()
- SplitEvaluation(
- CombineSetsAndSum(partialSet.toAttribute, this),
- partialSet :: Nil)
- }
-
- override def checkInputDataTypes(): TypeCheckResult =
- TypeUtils.checkForNumericExpr(child.dataType, "function sumDistinct")
-}
-
-case class SumDistinctFunction(expr: Expression, base: AggregateExpression1)
- extends AggregateFunction1 {
-
- def this() = this(null, null) // Required for serialization.
-
- private val seen = new scala.collection.mutable.HashSet[Any]()
-
- override def update(input: InternalRow): Unit = {
- val evaluatedExpr = expr.eval(input)
- if (evaluatedExpr != null) {
- seen += evaluatedExpr
- }
- }
-
- override def eval(input: InternalRow): Any = {
- if (seen.size == 0) {
- null
- } else {
- Cast(Literal(
- seen.reduceLeft(
- dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)),
- dataType).eval(null)
- }
- }
-}
-
-case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression1 {
- def this() = this(null, null)
-
- override def children: Seq[Expression] = inputSet :: Nil
- override def nullable: Boolean = true
- override def dataType: DataType = base.dataType
- override def toString: String = s"CombineAndSum($inputSet)"
- override def newInstance(): CombineSetsAndSumFunction = {
- new CombineSetsAndSumFunction(inputSet, this)
- }
-}
-
-case class CombineSetsAndSumFunction(
- @transient inputSet: Expression,
- @transient base: AggregateExpression1)
- extends AggregateFunction1 {
-
- def this() = this(null, null) // Required for serialization.
-
- val seen = new OpenHashSet[Any]()
-
- override def update(input: InternalRow): Unit = {
- val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]]
- val inputIterator = inputSetEval.iterator
- while (inputIterator.hasNext) {
- seen.add(inputIterator.next())
- }
- }
-
- override def eval(input: InternalRow): Any = {
- val casted = seen.asInstanceOf[OpenHashSet[InternalRow]]
- if (casted.size == 0) {
- null
- } else {
- Cast(Literal(
- casted.iterator.map(f => f.get(0, null)).reduceLeft(
- base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)),
- base.dataType).eval(null)
- }
- }
-}
-
-case class First(
- child: Expression,
- ignoreNullsExpr: Expression)
- extends UnaryExpression with PartialAggregate1 {
-
- def this(child: Expression) = this(child, Literal.create(false, BooleanType))
-
- private val ignoreNulls: Boolean = ignoreNullsExpr match {
- case Literal(b: Boolean, BooleanType) => b
- case _ =>
- throw new AnalysisException("The second argument of First should be a boolean literal.")
- }
-
- override def nullable: Boolean = true
- override def dataType: DataType = child.dataType
- override def toString: String = s"first(${child}${if (ignoreNulls) " ignore nulls"})"
-
- override def asPartial: SplitEvaluation = {
- val partialFirst = Alias(First(child, ignoreNulls), "PartialFirst")()
- SplitEvaluation(
- First(partialFirst.toAttribute, ignoreNulls),
- partialFirst :: Nil)
- }
- override def newInstance(): FirstFunction = new FirstFunction(child, ignoreNulls, this)
-}
-
-object First {
- def apply(child: Expression): First = First(child, ignoreNulls = false)
-
- def apply(child: Expression, ignoreNulls: Boolean): First =
- First(child, Literal.create(ignoreNulls, BooleanType))
-}
-
-case class FirstFunction(
- expr: Expression,
- ignoreNulls: Boolean,
- base: AggregateExpression1)
- extends AggregateFunction1 {
-
- def this() = this(null, null.asInstanceOf[Boolean], null) // Required for serialization.
-
- private[this] var result: Any = null
-
- private[this] var valueSet: Boolean = false
-
- override def update(input: InternalRow): Unit = {
- if (!valueSet) {
- val value = expr.eval(input)
- // When we have not set the result, we will set the result if we respect nulls
- // (i.e. ignoreNulls is false), or we ignore nulls and the evaluated value is not null.
- if (!ignoreNulls || (ignoreNulls && value != null)) {
- result = value
- valueSet = true
- }
- }
- }
-
- override def eval(input: InternalRow): Any = result
-}
-
-case class Last(
- child: Expression,
- ignoreNullsExpr: Expression)
- extends UnaryExpression with PartialAggregate1 {
-
- def this(child: Expression) = this(child, Literal.create(false, BooleanType))
-
- private val ignoreNulls: Boolean = ignoreNullsExpr match {
- case Literal(b: Boolean, BooleanType) => b
- case _ =>
- throw new AnalysisException("The second argument of First should be a boolean literal.")
- }
-
- override def references: AttributeSet = child.references
- override def nullable: Boolean = true
- override def dataType: DataType = child.dataType
- override def toString: String = s"last($child)${if (ignoreNulls) " ignore nulls"}"
-
- override def asPartial: SplitEvaluation = {
- val partialLast = Alias(Last(child, ignoreNulls), "PartialLast")()
- SplitEvaluation(
- Last(partialLast.toAttribute, ignoreNulls),
- partialLast :: Nil)
- }
- override def newInstance(): LastFunction = new LastFunction(child, ignoreNulls, this)
-}
-
-object Last {
- def apply(child: Expression): Last = Last(child, ignoreNulls = false)
-
- def apply(child: Expression, ignoreNulls: Boolean): Last =
- Last(child, Literal.create(ignoreNulls, BooleanType))
-}
-
-case class LastFunction(
- expr: Expression,
- ignoreNulls: Boolean,
- base: AggregateExpression1)
- extends AggregateFunction1 {
-
- def this() = this(null, null.asInstanceOf[Boolean], null) // Required for serialization.
-
- var result: Any = null
-
- override def update(input: InternalRow): Unit = {
- val value = expr.eval(input)
- if (!ignoreNulls || (ignoreNulls && value != null)) {
- result = value
- }
- }
-
- override def eval(input: InternalRow): Any = {
- result
- }
-}
-
-/**
- * Calculate Pearson Correlation Coefficient for the given columns.
- * Only support AggregateExpression2.
- *
- */
-case class Corr(left: Expression, right: Expression)
- extends BinaryExpression with AggregateExpression1 with ImplicitCastInputTypes {
- override def nullable: Boolean = false
- override def dataType: DoubleType.type = DoubleType
- override def toString: String = s"corr($left, $right)"
- override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
- override def newInstance(): AggregateFunction1 = {
- throw new UnsupportedOperationException(
- "Corr only supports the new AggregateExpression2 and can only be used " +
- "when spark.sql.useAggregate2 = true")
- }
-}
-
-// Compute standard deviation based on online algorithm specified here:
-// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
-abstract class StddevAgg1(child: Expression) extends UnaryExpression with PartialAggregate1 {
- override def nullable: Boolean = true
- override def dataType: DataType = DoubleType
-
- def isSample: Boolean
-
- override def asPartial: SplitEvaluation = {
- val partialStd = Alias(ComputePartialStd(child), "PartialStddev")()
- SplitEvaluation(MergePartialStd(partialStd.toAttribute, isSample), partialStd :: Nil)
- }
-
- override def newInstance(): StddevFunction = new StddevFunction(child, this, isSample)
-
- override def checkInputDataTypes(): TypeCheckResult =
- TypeUtils.checkForNumericExpr(child.dataType, "function stddev")
-
-}
-
-// Compute the population standard deviation of a column
-case class StddevPop(child: Expression) extends StddevAgg1(child) {
-
- override def toString: String = s"stddev_pop($child)"
- override def isSample: Boolean = false
-}
-
-// Compute the sample standard deviation of a column
-case class StddevSamp(child: Expression) extends StddevAgg1(child) {
-
- override def toString: String = s"stddev_samp($child)"
- override def isSample: Boolean = true
-}
-
-case class ComputePartialStd(child: Expression) extends UnaryExpression with AggregateExpression1 {
- def this() = this(null)
-
- override def children: Seq[Expression] = child :: Nil
- override def nullable: Boolean = false
- override def dataType: DataType = ArrayType(DoubleType)
- override def toString: String = s"computePartialStddev($child)"
- override def newInstance(): ComputePartialStdFunction =
- new ComputePartialStdFunction(child, this)
-}
-
-case class ComputePartialStdFunction (
- expr: Expression,
- base: AggregateExpression1
- ) extends AggregateFunction1 {
-
- def this() = this(null, null) // Required for serialization
-
- private val computeType = DoubleType
- private val zero = Cast(Literal(0), computeType)
- private var partialCount: Long = 0L
-
- // the mean of data processed so far
- private val partialAvg: MutableLiteral = MutableLiteral(zero.eval(null), computeType)
-
- // update average based on this formula:
- // avg = avg + (value - avg)/count
- private def avgAddFunction (value: Literal): Expression = {
- val delta = Subtract(Cast(value, computeType), partialAvg)
- Add(partialAvg, Divide(delta, Cast(Literal(partialCount), computeType)))
- }
-
- // the sum of squares of difference from mean
- private val partialMk: MutableLiteral = MutableLiteral(zero.eval(null), computeType)
-
- // update sum of square of difference from mean based on following formula:
- // Mk = Mk + (value - preAvg) * (value - updatedAvg)
- private def mkAddFunction(value: Literal, prePartialAvg: MutableLiteral): Expression = {
- val delta1 = Subtract(Cast(value, computeType), prePartialAvg)
- val delta2 = Subtract(Cast(value, computeType), partialAvg)
- Add(partialMk, Multiply(delta1, delta2))
- }
-
- override def update(input: InternalRow): Unit = {
- val evaluatedExpr = expr.eval(input)
- if (evaluatedExpr != null) {
- val exprValue = Literal.create(evaluatedExpr, expr.dataType)
- val prePartialAvg = partialAvg.copy()
- partialCount += 1
- partialAvg.update(avgAddFunction(exprValue), input)
- partialMk.update(mkAddFunction(exprValue, prePartialAvg), input)
- }
- }
-
- override def eval(input: InternalRow): Any = {
- new GenericArrayData(Array(Cast(Literal(partialCount), computeType).eval(null),
- partialAvg.eval(null),
- partialMk.eval(null)))
- }
-}
-
-case class MergePartialStd(
- child: Expression,
- isSample: Boolean
-) extends UnaryExpression with AggregateExpression1 {
- def this() = this(null, false) // required for serialization
-
- override def children: Seq[Expression] = child:: Nil
- override def nullable: Boolean = false
- override def dataType: DataType = DoubleType
- override def toString: String = s"MergePartialStd($child)"
- override def newInstance(): MergePartialStdFunction = {
- new MergePartialStdFunction(child, this, isSample)
- }
-}
-
-case class MergePartialStdFunction(
- expr: Expression,
- base: AggregateExpression1,
- isSample: Boolean
-) extends AggregateFunction1 {
- def this() = this (null, null, false) // Required for serialization
-
- private val computeType = DoubleType
- private val zero = Cast(Literal(0), computeType)
- private val combineCount = MutableLiteral(zero.eval(null), computeType)
- private val combineAvg = MutableLiteral(zero.eval(null), computeType)
- private val combineMk = MutableLiteral(zero.eval(null), computeType)
-
- private def avgUpdateFunction(preCount: Expression,
- partialCount: Expression,
- partialAvg: Expression): Expression = {
- Divide(Add(Multiply(combineAvg, preCount),
- Multiply(partialAvg, partialCount)),
- Add(preCount, partialCount))
- }
-
- override def update(input: InternalRow): Unit = {
- val evaluatedExpr = expr.eval(input).asInstanceOf[ArrayData]
-
- if (evaluatedExpr != null) {
- val exprValue = evaluatedExpr.toArray(computeType)
- val (partialCount, partialAvg, partialMk) =
- (Literal.create(exprValue(0), computeType),
- Literal.create(exprValue(1), computeType),
- Literal.create(exprValue(2), computeType))
-
- if (Cast(partialCount, LongType).eval(null).asInstanceOf[Long] > 0) {
- val preCount = combineCount.copy()
- combineCount.update(Add(combineCount, partialCount), input)
-
- val preAvg = combineAvg.copy()
- val avgDelta = Subtract(partialAvg, preAvg)
- val mkDelta = Multiply(Multiply(avgDelta, avgDelta),
- Divide(Multiply(preCount, partialCount),
- combineCount))
-
- // update average based on following formula
- // (combineAvg * preCount + partialAvg * partialCount) / (preCount + partialCount)
- combineAvg.update(avgUpdateFunction(preCount, partialCount, partialAvg), input)
-
- // update sum of square differences from mean based on following formula
- // (combineMk + partialMk + (avgDelta * avgDelta) * (preCount * partialCount/combineCount)
- combineMk.update(Add(combineMk, Add(partialMk, mkDelta)), input)
- }
- }
- }
-
- override def eval(input: InternalRow): Any = {
- val count: Long = Cast(combineCount, LongType).eval(null).asInstanceOf[Long]
-
- if (count == 0) null
- else if (count < 2) zero.eval(null)
- else {
- // when total count > 2
- // stddev_samp = sqrt (combineMk/(combineCount -1))
- // stddev_pop = sqrt (combineMk/combineCount)
- val varCol = {
- if (isSample) {
- Divide(combineMk, Cast(Literal(count - 1), computeType))
- }
- else {
- Divide(combineMk, Cast(Literal(count), computeType))
- }
- }
- Sqrt(varCol).eval(null)
- }
- }
-}
-
-case class StddevFunction(
- expr: Expression,
- base: AggregateExpression1,
- isSample: Boolean
-) extends AggregateFunction1 {
-
- def this() = this(null, null, false) // Required for serialization
-
- private val computeType = DoubleType
- private var curCount: Long = 0L
- private val zero = Cast(Literal(0), computeType)
- private val curAvg = MutableLiteral(zero.eval(null), computeType)
- private val curMk = MutableLiteral(zero.eval(null), computeType)
-
- private def curAvgAddFunction(value: Literal): Expression = {
- val delta = Subtract(Cast(value, computeType), curAvg)
- Add(curAvg, Divide(delta, Cast(Literal(curCount), computeType)))
- }
- private def curMkAddFunction(value: Literal, preAvg: MutableLiteral): Expression = {
- val delta1 = Subtract(Cast(value, computeType), preAvg)
- val delta2 = Subtract(Cast(value, computeType), curAvg)
- Add(curMk, Multiply(delta1, delta2))
- }
-
- override def update(input: InternalRow): Unit = {
- val evaluatedExpr = expr.eval(input)
- if (evaluatedExpr != null) {
- val preAvg: MutableLiteral = curAvg.copy()
- val exprValue = Literal.create(evaluatedExpr, expr.dataType)
- curCount += 1L
- curAvg.update(curAvgAddFunction(exprValue), input)
- curMk.update(curMkAddFunction(exprValue, preAvg), input)
- }
- }
-
- override def eval(input: InternalRow): Any = {
- if (curCount == 0) null
- else if (curCount < 2) zero.eval(null)
- else {
- // when total count > 2,
- // stddev_samp = sqrt(curMk/(curCount - 1))
- // stddev_pop = sqrt(curMk/curCount)
- val varCol = {
- if (isSample) {
- Divide(curMk, Cast(Literal(curCount - 1), computeType))
- }
- else {
- Divide(curMk, Cast(Literal(curCount), computeType))
- }
- }
- Sqrt(varCol).eval(null)
- }
- }
-}
-
-// placeholder
-case class Kurtosis(child: Expression) extends UnaryExpression with AggregateExpression1 {
-
- override def newInstance(): AggregateFunction1 = {
- throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " +
- "please set spark.sql.useAggregate2 = true")
- }
-
- override def nullable: Boolean = false
-
- override def dataType: DoubleType.type = DoubleType
-
- override def foldable: Boolean = false
-
- override def prettyName: String = "kurtosis"
-}
-
-// placeholder
-case class Skewness(child: Expression) extends UnaryExpression with AggregateExpression1 {
-
- override def newInstance(): AggregateFunction1 = {
- throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " +
- "please set spark.sql.useAggregate2 = true")
- }
-
- override def nullable: Boolean = false
-
- override def dataType: DoubleType.type = DoubleType
-
- override def foldable: Boolean = false
-
- override def prettyName: String = "skewness"
-}
-
-// placeholder
-case class VariancePop(child: Expression) extends UnaryExpression with AggregateExpression1 {
-
- override def newInstance(): AggregateFunction1 = {
- throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " +
- "please set spark.sql.useAggregate2 = true")
- }
-
- override def nullable: Boolean = false
-
- override def dataType: DoubleType.type = DoubleType
-
- override def foldable: Boolean = false
-
- override def prettyName: String = "var_pop"
-}
-
-// placeholder
-case class VarianceSamp(child: Expression) extends UnaryExpression with AggregateExpression1 {
-
- override def newInstance(): AggregateFunction1 = {
- throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " +
- "please set spark.sql.useAggregate2 = true")
- }
-
- override def nullable: Boolean = false
-
- override def dataType: DoubleType.type = DoubleType
-
- override def foldable: Boolean = false
-
- override def prettyName: String = "var_samp"
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index d222dfa33a..f4dba67f13 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
import scala.collection.immutable.HashSet
import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries}
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.FullOuter
import org.apache.spark.sql.catalyst.plans.LeftOuter
@@ -201,8 +202,8 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
object ColumnPruning extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case a @ Aggregate(_, _, e @ Expand(_, _, child))
- if (child.outputSet -- AttributeSet(e.output) -- a.references).nonEmpty =>
- a.copy(child = e.copy(child = prunedChild(child, AttributeSet(e.output) ++ a.references)))
+ if (child.outputSet -- e.references -- a.references).nonEmpty =>
+ a.copy(child = e.copy(child = prunedChild(child, e.references ++ a.references)))
// Eliminate attributes that are not needed to calculate the specified aggregates.
case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
@@ -363,7 +364,8 @@ object LikeSimplification extends Rule[LogicalPlan] {
object NullPropagation extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
- case e @ Count(Literal(null, _)) => Cast(Literal(0L), e.dataType)
+ case e @ AggregateExpression(Count(Literal(null, _)), _, _) =>
+ Cast(Literal(0L), e.dataType)
case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
case e @ GetArrayItem(Literal(null, _), _) => Literal.create(null, e.dataType)
@@ -375,7 +377,9 @@ object NullPropagation extends Rule[LogicalPlan] {
Literal.create(null, e.dataType)
case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
- case e @ Count(expr) if !expr.nullable => Count(Literal(1))
+ case e @ AggregateExpression(Count(expr), mode, false) if !expr.nullable =>
+ // This rule should only be triggered when the isDistinct field is false.
+ AggregateExpression(Count(Literal(1)), mode, isDistinct = false)
// For Coalesce, remove null literals.
case e @ Coalesce(children) =>
@@ -857,12 +861,15 @@ object DecimalAggregates extends Rule[LogicalPlan] {
private val MAX_DOUBLE_DIGITS = 15
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
- MakeDecimal(Sum(UnscaledValue(e)), prec + 10, scale)
+ case AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), mode, isDistinct)
+ if prec + 10 <= MAX_LONG_DIGITS =>
+ MakeDecimal(AggregateExpression(Sum(UnscaledValue(e)), mode, isDistinct), prec + 10, scale)
- case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS =>
+ case AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), mode, isDistinct)
+ if prec + 4 <= MAX_DOUBLE_DIGITS =>
+ val newAggExpr = AggregateExpression(Average(UnscaledValue(e)), mode, isDistinct)
Cast(
- Divide(Average(UnscaledValue(e)), Literal.create(math.pow(10.0, scale), DoubleType)),
+ Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
DecimalType(prec + 4, scale + 4))
}
}
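The DecimalAggregates rewrites now have to pattern-match through the AggregateExpression wrapper and rebuild it around the rewritten function. As a rough illustration of the Sum case for a hypothetical decimal(5, 2) column, where 5 + 10 = 15 still fits in a long:

  // before: AggregateExpression(Sum(price), Complete, isDistinct = false)            // price: decimal(5, 2)
  // after:  MakeDecimal(
  //           AggregateExpression(Sum(UnscaledValue(price)), Complete, isDistinct = false),
  //           15, 2)                                                                  // precision 5 + 10, scale 2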
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 3b975b904a..6f4f11406d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -85,80 +85,6 @@ object PhysicalOperation extends PredicateHelper {
}
/**
- * Matches a logical aggregation that can be performed on distributed data in two steps. The first
- * operates on the data in each partition performing partial aggregation for each group. The second
- * occurs after the shuffle and completes the aggregation.
- *
- * This pattern will only match if all aggregate expressions can be computed partially and will
- * return the rewritten aggregation expressions for both phases.
- *
- * The returned values for this match are as follows:
- * - Grouping attributes for the final aggregation.
- * - Aggregates for the final aggregation.
- * - Grouping expressions for the partial aggregation.
- * - Partial aggregate expressions.
- * - Input to the aggregation.
- */
-object PartialAggregation {
- type ReturnType =
- (Seq[Attribute], Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan)
-
- def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
- case logical.Aggregate(groupingExpressions, aggregateExpressions, child) =>
- // Collect all aggregate expressions.
- val allAggregates =
- aggregateExpressions.flatMap(_ collect { case a: AggregateExpression1 => a})
- // Collect all aggregate expressions that can be computed partially.
- val partialAggregates =
- aggregateExpressions.flatMap(_ collect { case p: PartialAggregate1 => p})
-
- // Only do partial aggregation if supported by all aggregate expressions.
- if (allAggregates.size == partialAggregates.size) {
- // Create a map of expressions to their partial evaluations for all aggregate expressions.
- val partialEvaluations: Map[TreeNodeRef, SplitEvaluation] =
- partialAggregates.map(a => (new TreeNodeRef(a), a.asPartial)).toMap
-
- // We need to pass all grouping expressions though so the grouping can happen a second
- // time. However some of them might be unnamed so we alias them allowing them to be
- // referenced in the second aggregation.
- val namedGroupingExpressions: Seq[(Expression, NamedExpression)] =
- groupingExpressions.map {
- case n: NamedExpression => (n, n)
- case other => (other, Alias(other, "PartialGroup")())
- }
-
- // Replace aggregations with a new expression that computes the result from the already
- // computed partial evaluations and grouping values.
- val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformDown {
- case e: Expression if partialEvaluations.contains(new TreeNodeRef(e)) =>
- partialEvaluations(new TreeNodeRef(e)).finalEvaluation
-
- case e: Expression =>
- namedGroupingExpressions.collectFirst {
- case (expr, ne) if expr semanticEquals e => ne.toAttribute
- }.getOrElse(e)
- }).asInstanceOf[Seq[NamedExpression]]
-
- val partialComputation = namedGroupingExpressions.map(_._2) ++
- partialEvaluations.values.flatMap(_.partialEvaluations)
-
- val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
-
- Some(
- (namedGroupingAttributes,
- rewrittenAggregateExpressions,
- groupingExpressions,
- partialComputation,
- child))
- } else {
- None
- }
- case _ => None
- }
-}
-
-
-/**
* A pattern that finds joins with equality conditions that can be evaluated using equi-join.
*
* Null-safe equality will be transformed into equality as joining key (replace null with default
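The removed PartialAggregation extractor was the old planner's hook for splitting an aggregation into a per-partition partial phase and a post-shuffle final phase; that split is now set up by the aggregate.Utils planners invoked later in this patch. A tiny standalone sketch of the two-phase idea it encoded, using plain Scala collections rather than the planner API:

    // Phase 1: partial aggregation inside each "partition".
    val partitions = Seq(Seq(("a", 1), ("b", 2)), Seq(("a", 3), ("b", 4)))
    val partial = partitions.map(_.groupBy(_._1).mapValues(_.map(_._2).sum).toMap)
    // Phase 2: merge the partial results after the (simulated) shuffle.
    val merged = partial.flatten.groupBy(_._1).mapValues(_.map(_._2).sum).toMap
    println(merged)  // Map(a -> 4, b -> 6)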
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 0ec9f08571..b9db7838db 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -137,13 +137,17 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
/** Returns all of the expressions present in this query plan operator. */
def expressions: Seq[Expression] = {
+ // Recursively find all expressions from a traversable.
+ def seqToExpressions(seq: Traversable[Any]): Traversable[Expression] = seq.flatMap {
+ case e: Expression => e :: Nil
+ case s: Traversable[_] => seqToExpressions(s)
+ case other => Nil
+ }
+
productIterator.flatMap {
case e: Expression => e :: Nil
case Some(e: Expression) => e :: Nil
- case seq: Traversable[_] => seq.flatMap {
- case e: Expression => e :: Nil
- case other => Nil
- }
+ case seq: Traversable[_] => seqToExpressions(seq)
case other => Nil
}.toSeq
}
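The recursion matters because some operators keep expressions inside nested sequences (Expand's projections, a Seq of Seqs, is the case this patch needs); a flat flatMap would silently drop them. A standalone sketch of the same walk, with strings standing in for expressions:

    def collectStrings(seq: Traversable[Any]): Traversable[String] = seq.flatMap {
      case s: String => s :: Nil
      case t: Traversable[_] => collectStrings(t)
      case _ => Nil
    }

    // A non-recursive flatMap would only find "a"; the recursive walk also finds "b" and "c".
    println(collectStrings(Seq("a", Seq("b", Seq("c")), 42)).toList)  // List(a, b, c)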
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index d771088d69..764f8aaebd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.Utils
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashSet
@@ -219,8 +219,6 @@ case class Aggregate(
!expressions.exists(!_.resolved) && childrenResolved && !hasWindowExpressions
}
- lazy val newAggregation: Option[Aggregate] = Utils.tryConvert(this)
-
override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index fbdd3a7776..5a2368e329 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -171,16 +171,18 @@ class AnalysisErrorSuite extends AnalysisTest {
test("SPARK-6452 regression test") {
// CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s)
+ // Since we manually construct the logical plan here and Sum only accepts
+ // LongType, DoubleType, and DecimalType, we use LongType as the type of a.
val plan =
Aggregate(
Nil,
- Alias(Sum(AttributeReference("a", IntegerType)(exprId = ExprId(1))), "b")() :: Nil,
+ Alias(sum(AttributeReference("a", LongType)(exprId = ExprId(1))), "b")() :: Nil,
LocalRelation(
- AttributeReference("a", IntegerType)(exprId = ExprId(2))))
+ AttributeReference("a", LongType)(exprId = ExprId(2))))
assert(plan.resolved)
- assertAnalysisError(plan, "resolved attribute(s) a#1 missing from a#2" :: Nil)
+ assertAnalysisError(plan, "resolved attribute(s) a#1L missing from a#2L" :: Nil)
}
test("error test for self-join") {
@@ -196,7 +198,7 @@ class AnalysisErrorSuite extends AnalysisTest {
val plan =
Aggregate(
AttributeReference("a", BinaryType)(exprId = ExprId(2)) :: Nil,
- Alias(Sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
+ Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
LocalRelation(
AttributeReference("a", BinaryType)(exprId = ExprId(2)),
AttributeReference("b", IntegerType)(exprId = ExprId(1))))
@@ -207,13 +209,24 @@ class AnalysisErrorSuite extends AnalysisTest {
val plan2 =
Aggregate(
AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)) :: Nil,
- Alias(Sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
+ Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
LocalRelation(
AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)),
AttributeReference("b", IntegerType)(exprId = ExprId(1))))
assertAnalysisError(plan2,
"map type expression a cannot be used in grouping expression" :: Nil)
+
+ val plan3 =
+ Aggregate(
+ AttributeReference("a", ArrayType(IntegerType))(exprId = ExprId(2)) :: Nil,
+ Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
+ LocalRelation(
+ AttributeReference("a", ArrayType(IntegerType))(exprId = ExprId(2)),
+ AttributeReference("b", IntegerType)(exprId = ExprId(1))))
+
+ assertAnalysisError(plan3,
+ "array type expression a cannot be used in grouping expression" :: Nil)
}
test("Join can't work on binary and map types") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 71d2939ecf..65f09b46af 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -45,7 +45,7 @@ class AnalysisSuite extends AnalysisTest {
val explode = Explode(AttributeReference("a", IntegerType, nullable = true)())
assert(!Project(Seq(Alias(explode, "explode")()), testRelation).resolved)
- assert(!Project(Seq(Alias(Count(Literal(1)), "count")()), testRelation).resolved)
+ assert(!Project(Seq(Alias(count(Literal(1)), "count")()), testRelation).resolved)
}
test("analyze project") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
index 40c4ae7920..fed591fd90 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
@@ -21,6 +21,7 @@ import org.scalatest.BeforeAndAfter
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.{Union, Project, LocalRelation}
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.{TableIdentifier, SimpleCatalystConf}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index c9bcc68f02..b902982add 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.{TypeCollection, StringType}
@@ -140,15 +141,16 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
}
test("check types for aggregates") {
+ // We use AggregateFunction directly here because the error will be thrown from it
+ // instead of from AggregateExpression, which is the wrapper of an AggregateFunction.
+
// We will cast String to Double for sum and average
assertSuccess(Sum('stringField))
- assertSuccess(SumDistinct('stringField))
assertSuccess(Average('stringField))
assertError(Min('complexField), "min does not support ordering on type")
assertError(Max('complexField), "max does not support ordering on type")
assertError(Sum('booleanField), "function sum requires numeric type")
- assertError(SumDistinct('booleanField), "function sumDistinct requires numeric type")
assertError(Average('booleanField), "function average requires numeric type")
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
index e67606288f..8aaefa8493 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -162,7 +162,7 @@ class ConstantFoldingSuite extends PlanTest {
testRelation
.select(
Rand(5L) + Literal(1) as Symbol("c1"),
- Sum('a) as Symbol("c2"))
+ sum('a) as Symbol("c2"))
val optimized = Optimize.execute(originalQuery.analyze)
@@ -170,7 +170,7 @@ class ConstantFoldingSuite extends PlanTest {
testRelation
.select(
Rand(5L) + Literal(1.0) as Symbol("c1"),
- Sum('a) as Symbol("c2"))
+ sum('a) as Symbol("c2"))
.analyze
comparePlans(optimized, correctAnswer)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index ed810a1280..0290fafe87 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -68,7 +68,7 @@ class FilterPushdownSuite extends PlanTest {
test("column pruning for group") {
val originalQuery =
testRelation
- .groupBy('a)('a, Count('b))
+ .groupBy('a)('a, count('b))
.select('a)
val optimized = Optimize.execute(originalQuery.analyze)
@@ -84,7 +84,7 @@ class FilterPushdownSuite extends PlanTest {
test("column pruning for group with alias") {
val originalQuery =
testRelation
- .groupBy('a)('a as 'c, Count('b))
+ .groupBy('a)('a as 'c, count('b))
.select('c)
val optimized = Optimize.execute(originalQuery.analyze)
@@ -656,7 +656,7 @@ class FilterPushdownSuite extends PlanTest {
test("aggregate: push down filter when filter on group by expression") {
val originalQuery = testRelation
- .groupBy('a)('a, Count('b) as 'c)
+ .groupBy('a)('a, count('b) as 'c)
.select('a, 'c)
.where('a === 2)
@@ -664,7 +664,7 @@ class FilterPushdownSuite extends PlanTest {
val correctAnswer = testRelation
.where('a === 2)
- .groupBy('a)('a, Count('b) as 'c)
+ .groupBy('a)('a, count('b) as 'c)
.analyze
comparePlans(optimized, correctAnswer)
}
@@ -672,7 +672,7 @@ class FilterPushdownSuite extends PlanTest {
test("aggregate: don't push down filter when filter not on group by expression") {
val originalQuery = testRelation
.select('a, 'b)
- .groupBy('a)('a, Count('b) as 'c)
+ .groupBy('a)('a, count('b) as 'c)
.where('c === 2L)
val optimized = Optimize.execute(originalQuery.analyze)
@@ -683,7 +683,7 @@ class FilterPushdownSuite extends PlanTest {
test("aggregate: push down filters partially which are subset of group by expressions") {
val originalQuery = testRelation
.select('a, 'b)
- .groupBy('a)('a, Count('b) as 'c)
+ .groupBy('a)('a, count('b) as 'c)
.where('c === 2L && 'a === 3)
val optimized = Optimize.execute(originalQuery.analyze)
@@ -691,7 +691,7 @@ class FilterPushdownSuite extends PlanTest {
val correctAnswer = testRelation
.select('a, 'b)
.where('a === 3)
- .groupBy('a)('a, Count('b) as 'c)
+ .groupBy('a)('a, count('b) as 'c)
.where('c === 2L)
.analyze
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index d25807cf8d..3b69247dc5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.encoders.Encoder
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
@@ -1338,7 +1339,7 @@ class DataFrame private[sql](
if (groupColExprIds.contains(attr.exprId)) {
attr
} else {
- Alias(First(attr), attr.name)()
+ Alias(new First(attr).toAggregateExpression(), attr.name)()
}
}
Aggregate(groupCols, aggCols, logicalPlan)
@@ -1381,11 +1382,11 @@ class DataFrame private[sql](
// The list of summary statistics to compute, in the form of expressions.
val statistics = List[(String, Expression => Expression)](
- "count" -> Count,
- "mean" -> Average,
- "stddev" -> StddevSamp,
- "min" -> Min,
- "max" -> Max)
+ "count" -> ((child: Expression) => Count(child).toAggregateExpression()),
+ "mean" -> ((child: Expression) => Average(child).toAggregateExpression()),
+ "stddev" -> ((child: Expression) => StddevSamp(child).toAggregateExpression()),
+ "min" -> ((child: Expression) => Min(child).toAggregateExpression()),
+ "max" -> ((child: Expression) => Max(child).toAggregateExpression()))
val outputCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList
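Each summary statistic is now built by wrapping the corresponding AggregateFunction via toAggregateExpression(), but the public behaviour of describe is unchanged. A short usage sketch, assuming an existing DataFrame df with a numeric column age:

    // Produces one row per statistic: count, mean, stddev, min, max.
    df.describe("age").show()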
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index f9eab5c2e9..5babf2cc0c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -21,8 +21,9 @@ import scala.collection.JavaConverters._
import scala.language.implicitConversions
import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, Star}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAlias, UnresolvedAttribute, Star}
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate}
import org.apache.spark.sql.types.NumericType
@@ -70,7 +71,7 @@ class GroupedData protected[sql](
}
}
- private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => Expression)
+ private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction)
: DataFrame = {
val columnExprs = if (colNames.isEmpty) {
@@ -88,30 +89,28 @@ class GroupedData protected[sql](
namedExpr
}
}
- toDF(columnExprs.map(f))
+ toDF(columnExprs.map(expr => f(expr).toAggregateExpression()))
}
private[this] def strToExpr(expr: String): (Expression => Expression) = {
- expr.toLowerCase match {
- case "avg" | "average" | "mean" => Average
- case "max" => Max
- case "min" => Min
- case "stddev" | "std" => StddevSamp
- case "stddev_pop" => StddevPop
- case "stddev_samp" => StddevSamp
- case "variance" => VarianceSamp
- case "var_pop" => VariancePop
- case "var_samp" => VarianceSamp
- case "sum" => Sum
- case "skewness" => Skewness
- case "kurtosis" => Kurtosis
- case "count" | "size" =>
- // Turn count(*) into count(1)
- (inputExpr: Expression) => inputExpr match {
- case s: Star => Count(Literal(1))
- case _ => Count(inputExpr)
- }
+ val exprToFunc: (Expression => Expression) = {
+ (inputExpr: Expression) => expr.toLowerCase match {
+ // We specially handle a few cases whose aliases are not in the function registry.
+ case "avg" | "average" | "mean" =>
+ UnresolvedFunction("avg", inputExpr :: Nil, isDistinct = false)
+ case "stddev" | "std" =>
+ UnresolvedFunction("stddev", inputExpr :: Nil, isDistinct = false)
+ // Also specially handle count because we need to take care of count(*).
+ case "count" | "size" =>
+ // Turn count(*) into count(1)
+ inputExpr match {
+ case s: Star => Count(Literal(1)).toAggregateExpression()
+ case _ => Count(inputExpr).toAggregateExpression()
+ }
+ case name => UnresolvedFunction(name, inputExpr :: Nil, isDistinct = false)
+ }
}
+ (inputExpr: Expression) => exprToFunc(inputExpr)
}
/**
@@ -213,7 +212,7 @@ class GroupedData protected[sql](
*
* @since 1.3.0
*/
- def count(): DataFrame = toDF(Seq(Alias(Count(Literal(1)), "count")()))
+ def count(): DataFrame = toDF(Seq(Alias(Count(Literal(1)).toAggregateExpression(), "count")()))
/**
* Compute the average value for each numeric columns for each group. This is an alias for `avg`.
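With strToExpr routing most names through UnresolvedFunction, the string-based aggregation API keeps working as before; the analyzer resolves those names to the new AggregateFunction implementations. A short usage sketch, assuming an existing DataFrame df with columns dept, salary, and age:

    // "mean" and "max" travel through strToExpr as UnresolvedFunctions; "count" is special-cased.
    val summary = df.groupBy("dept").agg("salary" -> "mean", "age" -> "max")
    summary.show()

    val counts = df.groupBy("dept").count()  // built from Count(Literal(1)).toAggregateExpression()
    counts.show()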
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index ed8b634ad5..b7314189b5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -448,15 +448,24 @@ private[spark] object SQLConf {
defaultValue = Some(true),
isPublic = false)
- val USE_SQL_AGGREGATE2 = booleanConf("spark.sql.useAggregate2",
- defaultValue = Some(true), doc = "<TODO>")
-
val RUN_SQL_ON_FILES = booleanConf("spark.sql.runSQLOnFiles",
defaultValue = Some(true),
isPublic = false,
doc = "When true, we could use `datasource`.`path` as table in SQL query"
)
+ val SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING =
+ booleanConf("spark.sql.specializeSingleDistinctAggPlanning",
+ defaultValue = Some(true),
+ isPublic = false,
+ doc = "When true, if a query only has a single distinct column and it has " +
+ "grouping expressions, we will use our planner rule to handle this distinct " +
+ "column (other cases are handled by DistinctAggregationRewriter). " +
+ "When false, we will always use DistinctAggregationRewriter to plan " +
+ "aggregation queries with DISTINCT keyword. This is an internal flag that is " +
+ "used to benchmark the performance impact of using DistinctAggregationRewriter to " +
+ "plan aggregation queries with a single distinct column.")
+
object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
val EXTERNAL_SORT = "spark.sql.planner.externalSort"
@@ -532,8 +541,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED, getConf(TUNGSTEN_ENABLED))
- private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2)
-
private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD)
private[spark] def defaultSizeInBytes: Long =
@@ -575,6 +582,9 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
private[spark] def runSQLOnFile: Boolean = getConf(RUN_SQL_ON_FILES)
+ protected[spark] override def specializeSingleDistinctAggPlanning: Boolean =
+ getConf(SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING)
+
/** ********************** SQLConf functionality methods ************ */
/** Set Spark SQL configuration properties. */
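The new flag is internal, but it is registered like any other conf, so it can be toggled at runtime when benchmarking the two planning paths for single-distinct-column aggregations. A hedged usage sketch, assuming an existing SQLContext sqlContext:

    // Force DistinctAggregationRewriter to handle single-distinct-column queries as well.
    sqlContext.setConf("spark.sql.specializeSingleDistinctAggPlanning", "false")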
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
deleted file mode 100644
index 6f3f1bd97a..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
+++ /dev/null
@@ -1,205 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution
-
-import java.util.HashMap
-
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.errors._
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.metric.SQLMetrics
-
-/**
- * Groups input data by `groupingExpressions` and computes the `aggregateExpressions` for each
- * group.
- *
- * @param partial if true then aggregation is done partially on local data without shuffling to
- * ensure all values where `groupingExpressions` are equal are present.
- * @param groupingExpressions expressions that are evaluated to determine grouping.
- * @param aggregateExpressions expressions that are computed for each group.
- * @param child the input data source.
- */
-case class Aggregate(
- partial: Boolean,
- groupingExpressions: Seq[Expression],
- aggregateExpressions: Seq[NamedExpression],
- child: SparkPlan)
- extends UnaryNode {
-
- override private[sql] lazy val metrics = Map(
- "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
-
- override def requiredChildDistribution: List[Distribution] = {
- if (partial) {
- UnspecifiedDistribution :: Nil
- } else {
- if (groupingExpressions == Nil) {
- AllTuples :: Nil
- } else {
- ClusteredDistribution(groupingExpressions) :: Nil
- }
- }
- }
-
- override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
-
- /**
- * An aggregate that needs to be computed for each row in a group.
- *
- * @param unbound Unbound version of this aggregate, used for result substitution.
- * @param aggregate A bound copy of this aggregate used to create a new aggregation buffer.
- * @param resultAttribute An attribute used to refer to the result of this aggregate in the final
- * output.
- */
- case class ComputedAggregate(
- unbound: AggregateExpression1,
- aggregate: AggregateExpression1,
- resultAttribute: AttributeReference)
-
- /** A list of aggregates that need to be computed for each group. */
- private[this] val computedAggregates = aggregateExpressions.flatMap { agg =>
- agg.collect {
- case a: AggregateExpression1 =>
- ComputedAggregate(
- a,
- BindReferences.bindReference(a, child.output),
- AttributeReference(s"aggResult:$a", a.dataType, a.nullable)())
- }
- }.toArray
-
- /** The schema of the result of all aggregate evaluations */
- private[this] val computedSchema = computedAggregates.map(_.resultAttribute)
-
- /** Creates a new aggregate buffer for a group. */
- private[this] def newAggregateBuffer(): Array[AggregateFunction1] = {
- val buffer = new Array[AggregateFunction1](computedAggregates.length)
- var i = 0
- while (i < computedAggregates.length) {
- buffer(i) = computedAggregates(i).aggregate.newInstance()
- i += 1
- }
- buffer
- }
-
- /** Named attributes used to substitute grouping attributes into the final result. */
- private[this] val namedGroups = groupingExpressions.map {
- case ne: NamedExpression => ne -> ne.toAttribute
- case e => e -> Alias(e, s"groupingExpr:$e")().toAttribute
- }
-
- /**
- * A map of substitutions that are used to insert the aggregate expressions and grouping
- * expression into the final result expression.
- */
- private[this] val resultMap =
- (computedAggregates.map { agg => agg.unbound -> agg.resultAttribute } ++ namedGroups).toMap
-
- /**
- * Substituted version of aggregateExpressions expressions which are used to compute final
- * output rows given a group and the result of all aggregate computations.
- */
- private[this] val resultExpressions = aggregateExpressions.map { agg =>
- agg.transform {
- case e: Expression if resultMap.contains(e) => resultMap(e)
- }
- }
-
- protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
- val numInputRows = longMetric("numInputRows")
- val numOutputRows = longMetric("numOutputRows")
- if (groupingExpressions.isEmpty) {
- child.execute().mapPartitions { iter =>
- val buffer = newAggregateBuffer()
- var currentRow: InternalRow = null
- while (iter.hasNext) {
- currentRow = iter.next()
- numInputRows += 1
- var i = 0
- while (i < buffer.length) {
- buffer(i).update(currentRow)
- i += 1
- }
- }
- val resultProjection = new InterpretedProjection(resultExpressions, computedSchema)
- val aggregateResults = new GenericMutableRow(computedAggregates.length)
-
- var i = 0
- while (i < buffer.length) {
- aggregateResults(i) = buffer(i).eval(EmptyRow)
- i += 1
- }
-
- numOutputRows += 1
- Iterator(resultProjection(aggregateResults))
- }
- } else {
- child.execute().mapPartitions { iter =>
- val hashTable = new HashMap[InternalRow, Array[AggregateFunction1]]
- val groupingProjection = new InterpretedMutableProjection(groupingExpressions, child.output)
-
- var currentRow: InternalRow = null
- while (iter.hasNext) {
- currentRow = iter.next()
- numInputRows += 1
- val currentGroup = groupingProjection(currentRow)
- var currentBuffer = hashTable.get(currentGroup)
- if (currentBuffer == null) {
- currentBuffer = newAggregateBuffer()
- hashTable.put(currentGroup.copy(), currentBuffer)
- }
-
- var i = 0
- while (i < currentBuffer.length) {
- currentBuffer(i).update(currentRow)
- i += 1
- }
- }
-
- new Iterator[InternalRow] {
- private[this] val hashTableIter = hashTable.entrySet().iterator()
- private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length)
- private[this] val resultProjection =
- new InterpretedMutableProjection(
- resultExpressions, computedSchema ++ namedGroups.map(_._2))
- private[this] val joinedRow = new JoinedRow
-
- override final def hasNext: Boolean = hashTableIter.hasNext
-
- override final def next(): InternalRow = {
- val currentEntry = hashTableIter.next()
- val currentGroup = currentEntry.getKey
- val currentBuffer = currentEntry.getValue
- numOutputRows += 1
-
- var i = 0
- while (i < currentBuffer.length) {
- // Evaluating an aggregate buffer returns the result. No row is required since we
- // already added all rows in the group using update.
- aggregateResults(i) = currentBuffer(i).eval(EmptyRow)
- i += 1
- }
- resultProjection(joinedRow(aggregateResults, currentGroup))
- }
- }
- }
- }
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
index 55e95769d3..91530bd637 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
@@ -45,6 +45,9 @@ case class Expand(
override def canProcessUnsafeRows: Boolean = true
override def canProcessSafeRows: Boolean = true
+ override def references: AttributeSet =
+ AttributeSet(projections.flatten.flatMap(_.references))
+
private[this] val projection = {
if (outputsUnsafeRows) {
(exprs: Seq[Expression]) => UnsafeProjection.create(exprs, child.output)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
index 0f98fe88b2..a10d1edcc9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
@@ -38,7 +38,6 @@ class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies {
DataSourceStrategy ::
DDLStrategy ::
TakeOrderedAndProject ::
- HashAggregation ::
Aggregation ::
LeftSemiJoin ::
EquiJoinSelection ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index dd3bb33c57..d65cb1bae7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, Utils}
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan}
@@ -146,148 +146,104 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}
- object HashAggregation extends Strategy {
- def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- // Aggregations that can be performed in two phases, before and after the shuffle.
- case PartialAggregation(
- namedGroupingAttributes,
- rewrittenAggregateExpressions,
- groupingExpressions,
- partialComputation,
- child) if !canBeConvertedToNewAggregation(plan) =>
- execution.Aggregate(
- partial = false,
- namedGroupingAttributes,
- rewrittenAggregateExpressions,
- execution.Aggregate(
- partial = true,
- groupingExpressions,
- partialComputation,
- planLater(child))) :: Nil
-
- case _ => Nil
- }
-
- def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = plan match {
- case a: logical.Aggregate =>
- if (sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled) {
- a.newAggregation.isDefined
- } else {
- Utils.checkInvalidAggregateFunction2(a)
- false
- }
- case _ => false
- }
-
- def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression1] =
- exprs.flatMap(_.collect { case a: AggregateExpression1 => a })
- }
-
/**
* Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface.
*/
object Aggregation extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case p: logical.Aggregate if sqlContext.conf.useSqlAggregate2 &&
- sqlContext.conf.codegenEnabled =>
- val converted = p.newAggregation
- converted match {
- case None => Nil // Cannot convert to new aggregation code path.
- case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) =>
- // A single aggregate expression might appear multiple times in resultExpressions.
- // In order to avoid evaluating an individual aggregate function multiple times, we'll
- // build a set of the distinct aggregate expressions and build a function which can
- // be used to re-write expressions so that they reference the single copy of the
- // aggregate function which actually gets computed.
- val aggregateExpressions = resultExpressions.flatMap { expr =>
- expr.collect {
- case agg: AggregateExpression2 => agg
- }
- }.distinct
- // For those distinct aggregate expressions, we create a map from the
- // aggregate function to the corresponding attribute of the function.
- val aggregateFunctionToAttribute = aggregateExpressions.map { agg =>
- val aggregateFunction = agg.aggregateFunction
- val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
- (aggregateFunction, agg.isDistinct) -> attribute
- }.toMap
-
- val (functionsWithDistinct, functionsWithoutDistinct) =
- aggregateExpressions.partition(_.isDistinct)
- if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
- // This is a sanity check. We should not reach here when we have multiple distinct
- // column sets (aggregate.NewAggregation will not match).
- sys.error(
- "Multiple distinct column sets are not supported by the new aggregation" +
- "code path.")
- }
+ case logical.Aggregate(groupingExpressions, resultExpressions, child) =>
+ // A single aggregate expression might appear multiple times in resultExpressions.
+ // In order to avoid evaluating an individual aggregate function multiple times, we'll
+ // build a set of the distinct aggregate expressions and build a function which can
+ // be used to re-write expressions so that they reference the single copy of the
+ // aggregate function which actually gets computed.
+ val aggregateExpressions = resultExpressions.flatMap { expr =>
+ expr.collect {
+ case agg: AggregateExpression => agg
+ }
+ }.distinct
+ // For those distinct aggregate expressions, we create a map from the
+ // aggregate function to the corresponding attribute of the function.
+ val aggregateFunctionToAttribute = aggregateExpressions.map { agg =>
+ val aggregateFunction = agg.aggregateFunction
+ val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
+ (aggregateFunction, agg.isDistinct) -> attribute
+ }.toMap
+
+ val (functionsWithDistinct, functionsWithoutDistinct) =
+ aggregateExpressions.partition(_.isDistinct)
+ if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
+ // This is a sanity check. We should not reach here when we have multiple distinct
+ // column sets. Our MultipleDistinctRewriter should take care of this case.
+ sys.error("You hit a query analyzer bug. Please report your query to " +
+ "the Spark user mailing list.")
+ }
- val namedGroupingExpressions = groupingExpressions.map {
- case ne: NamedExpression => ne -> ne
- // If the expression is not a NamedExpressions, we add an alias.
- // So, when we generate the result of the operator, the Aggregate Operator
- // can directly get the Seq of attributes representing the grouping expressions.
- case other =>
- val withAlias = Alias(other, other.toString)()
- other -> withAlias
- }
- val groupExpressionMap = namedGroupingExpressions.toMap
-
- // The original `resultExpressions` are a set of expressions which may reference
- // aggregate expressions, grouping column values, and constants. When aggregate operator
- // emits output rows, we will use `resultExpressions` to generate an output projection
- // which takes the grouping columns and final aggregate result buffer as input.
- // Thus, we must re-write the result expressions so that their attributes match up with
- // the attributes of the final result projection's input row:
- val rewrittenResultExpressions = resultExpressions.map { expr =>
- expr.transformDown {
- case AggregateExpression2(aggregateFunction, _, isDistinct) =>
- // The final aggregation buffer's attributes will be `finalAggregationAttributes`,
- // so replace each aggregate expression by its corresponding attribute in the set:
- aggregateFunctionToAttribute(aggregateFunction, isDistinct)
- case expression =>
- // Since we're using `namedGroupingAttributes` to extract the grouping key
- // columns, we need to replace grouping key expressions with their corresponding
- // attributes. We do not rely on the equality check at here since attributes may
- // differ cosmetically. Instead, we use semanticEquals.
- groupExpressionMap.collectFirst {
- case (expr, ne) if expr semanticEquals expression => ne.toAttribute
- }.getOrElse(expression)
- }.asInstanceOf[NamedExpression]
+ val namedGroupingExpressions = groupingExpressions.map {
+ case ne: NamedExpression => ne -> ne
+ // If the expression is not a NamedExpressions, we add an alias.
+ // So, when we generate the result of the operator, the Aggregate Operator
+ // can directly get the Seq of attributes representing the grouping expressions.
+ case other =>
+ val withAlias = Alias(other, other.toString)()
+ other -> withAlias
+ }
+ val groupExpressionMap = namedGroupingExpressions.toMap
+
+ // The original `resultExpressions` are a set of expressions which may reference
+ // aggregate expressions, grouping column values, and constants. When aggregate operator
+ // emits output rows, we will use `resultExpressions` to generate an output projection
+ // which takes the grouping columns and final aggregate result buffer as input.
+ // Thus, we must re-write the result expressions so that their attributes match up with
+ // the attributes of the final result projection's input row:
+ val rewrittenResultExpressions = resultExpressions.map { expr =>
+ expr.transformDown {
+ case AggregateExpression(aggregateFunction, _, isDistinct) =>
+ // The final aggregation buffer's attributes will be `finalAggregationAttributes`,
+ // so replace each aggregate expression by its corresponding attribute in the set:
+ aggregateFunctionToAttribute(aggregateFunction, isDistinct)
+ case expression =>
+ // Since we're using `namedGroupingAttributes` to extract the grouping key
+ // columns, we need to replace grouping key expressions with their corresponding
+ // attributes. We do not rely on the equality check here since attributes may
+ // differ cosmetically. Instead, we use semanticEquals.
+ groupExpressionMap.collectFirst {
+ case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+ }.getOrElse(expression)
+ }.asInstanceOf[NamedExpression]
+ }
+
+ val aggregateOperator =
+ if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
+ if (functionsWithDistinct.nonEmpty) {
+ sys.error("Distinct columns cannot exist in Aggregate operator containing " +
+ "aggregate functions which don't support partial aggregation.")
+ } else {
+ aggregate.Utils.planAggregateWithoutPartial(
+ namedGroupingExpressions.map(_._2),
+ aggregateExpressions,
+ aggregateFunctionToAttribute,
+ rewrittenResultExpressions,
+ planLater(child))
}
+ } else if (functionsWithDistinct.isEmpty) {
+ aggregate.Utils.planAggregateWithoutDistinct(
+ namedGroupingExpressions.map(_._2),
+ aggregateExpressions,
+ aggregateFunctionToAttribute,
+ rewrittenResultExpressions,
+ planLater(child))
+ } else {
+ aggregate.Utils.planAggregateWithOneDistinct(
+ namedGroupingExpressions.map(_._2),
+ functionsWithDistinct,
+ functionsWithoutDistinct,
+ aggregateFunctionToAttribute,
+ rewrittenResultExpressions,
+ planLater(child))
+ }
- val aggregateOperator =
- if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
- if (functionsWithDistinct.nonEmpty) {
- sys.error("Distinct columns cannot exist in Aggregate operator containing " +
- "aggregate functions which don't support partial aggregation.")
- } else {
- aggregate.Utils.planAggregateWithoutPartial(
- namedGroupingExpressions.map(_._2),
- aggregateExpressions,
- aggregateFunctionToAttribute,
- rewrittenResultExpressions,
- planLater(child))
- }
- } else if (functionsWithDistinct.isEmpty) {
- aggregate.Utils.planAggregateWithoutDistinct(
- namedGroupingExpressions.map(_._2),
- aggregateExpressions,
- aggregateFunctionToAttribute,
- rewrittenResultExpressions,
- planLater(child))
- } else {
- aggregate.Utils.planAggregateWithOneDistinct(
- namedGroupingExpressions.map(_._2),
- functionsWithDistinct,
- functionsWithoutDistinct,
- aggregateFunctionToAttribute,
- rewrittenResultExpressions,
- planLater(child))
- }
-
- aggregateOperator
- }
+ aggregateOperator
case _ => Nil
}
@@ -422,18 +378,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.Filter(condition, planLater(child)) :: Nil
case e @ logical.Expand(_, _, child) =>
execution.Expand(e.projections, e.output, planLater(child)) :: Nil
- case a @ logical.Aggregate(group, agg, child) => {
- val useNewAggregation = sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled
- if (useNewAggregation && a.newAggregation.isDefined) {
- // If this logical.Aggregate can be planned to use new aggregation code path
- // (i.e. it can be planned by the Strategy Aggregation), we will not use the old
- // aggregation code path.
- Nil
- } else {
- Utils.checkInvalidAggregateFunction2(a)
- execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
- }
- }
case logical.Window(projectList, windowExprs, partitionSpec, orderSpec, child) =>
execution.Window(
projectList, windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil
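The Aggregation strategy above now makes a three-way choice: fall back to a non-partial plan when any function cannot be partially aggregated, otherwise plan with or without a distinct column set. A simplified, self-contained model of that decision, using hypothetical stand-in types rather than the planner API:

    case class Agg(name: String, supportsPartial: Boolean, isDistinct: Boolean)

    def choosePlanner(aggs: Seq[Agg]): String = {
      val withDistinct = aggs.filter(_.isDistinct)
      if (aggs.exists(!_.supportsPartial)) {
        if (withDistinct.nonEmpty) {
          sys.error("Distinct columns cannot exist in Aggregate operator containing " +
            "aggregate functions which don't support partial aggregation.")
        } else {
          "planAggregateWithoutPartial"
        }
      } else if (withDistinct.isEmpty) {
        "planAggregateWithoutDistinct"
      } else {
        "planAggregateWithOneDistinct"
      }
    }

    println(choosePlanner(Seq(Agg("sum", supportsPartial = true, isDistinct = false))))   // without distinct
    println(choosePlanner(Seq(Agg("count", supportsPartial = true, isDistinct = true))))  // one distinct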
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
index 99fb7a40b7..008478a6a0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -35,9 +35,9 @@ import scala.collection.mutable.ArrayBuffer
abstract class AggregationIterator(
groupingKeyAttributes: Seq[Attribute],
valueAttributes: Seq[Attribute],
- nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+ nonCompleteAggregateExpressions: Seq[AggregateExpression],
nonCompleteAggregateAttributes: Seq[Attribute],
- completeAggregateExpressions: Seq[AggregateExpression2],
+ completeAggregateExpressions: Seq[AggregateExpression],
completeAggregateAttributes: Seq[Attribute],
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
@@ -76,14 +76,14 @@ abstract class AggregationIterator(
// Initialize all AggregateFunctions by binding references if necessary,
// and set inputBufferOffset and mutableBufferOffset.
- protected val allAggregateFunctions: Array[AggregateFunction2] = {
+ protected val allAggregateFunctions: Array[AggregateFunction] = {
var mutableBufferOffset = 0
var inputBufferOffset: Int = initialInputBufferOffset
- val functions = new Array[AggregateFunction2](allAggregateExpressions.length)
+ val functions = new Array[AggregateFunction](allAggregateExpressions.length)
var i = 0
while (i < allAggregateExpressions.length) {
val func = allAggregateExpressions(i).aggregateFunction
- val funcWithBoundReferences: AggregateFunction2 = allAggregateExpressions(i).mode match {
+ val funcWithBoundReferences: AggregateFunction = allAggregateExpressions(i).mode match {
case Partial | Complete if func.isInstanceOf[ImperativeAggregate] =>
// We need to create BoundReferences if the function is not an
// expression-based aggregate function (it does not support code-gen) and the mode of
@@ -135,7 +135,7 @@ abstract class AggregationIterator(
}
// All AggregateFunctions functions with mode Partial, PartialMerge, or Final.
- private[this] val nonCompleteAggregateFunctions: Array[AggregateFunction2] =
+ private[this] val nonCompleteAggregateFunctions: Array[AggregateFunction] =
allAggregateFunctions.take(nonCompleteAggregateExpressions.length)
// All imperative aggregate functions with mode Partial, PartialMerge, or Final.
@@ -172,7 +172,7 @@ abstract class AggregationIterator(
case (Some(Partial), None) =>
val updateExpressions = nonCompleteAggregateFunctions.flatMap {
case ae: DeclarativeAggregate => ae.updateExpressions
- case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
+ case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
}
val expressionAggUpdateProjection =
newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)()
@@ -204,7 +204,7 @@ abstract class AggregationIterator(
// allAggregateFunctions.flatMap(_.cloneBufferAttributes)
val mergeExpressions = nonCompleteAggregateFunctions.flatMap {
case ae: DeclarativeAggregate => ae.mergeExpressions
- case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
+ case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
}
// This projection is used to merge buffer values for all expression-based aggregates.
val expressionAggMergeProjection =
@@ -225,7 +225,7 @@ abstract class AggregationIterator(
// Final-Complete
case (Some(Final), Some(Complete)) =>
- val completeAggregateFunctions: Array[AggregateFunction2] =
+ val completeAggregateFunctions: Array[AggregateFunction] =
allAggregateFunctions.takeRight(completeAggregateExpressions.length)
// All imperative aggregate functions with mode Complete.
val completeImperativeAggregateFunctions: Array[ImperativeAggregate] =
@@ -248,7 +248,7 @@ abstract class AggregationIterator(
val mergeExpressions =
nonCompleteAggregateFunctions.flatMap {
case ae: DeclarativeAggregate => ae.mergeExpressions
- case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
+ case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
} ++ completeOffsetExpressions
val finalExpressionAggMergeProjection =
newMutableProjection(mergeExpressions, mergeInputSchema)()
@@ -256,7 +256,7 @@ abstract class AggregationIterator(
val updateExpressions =
finalOffsetExpressions ++ completeAggregateFunctions.flatMap {
case ae: DeclarativeAggregate => ae.updateExpressions
- case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
+ case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
}
val completeExpressionAggUpdateProjection =
newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)()
@@ -282,7 +282,7 @@ abstract class AggregationIterator(
// Complete-only
case (None, Some(Complete)) =>
- val completeAggregateFunctions: Array[AggregateFunction2] =
+ val completeAggregateFunctions: Array[AggregateFunction] =
allAggregateFunctions.takeRight(completeAggregateExpressions.length)
// All imperative aggregate functions with mode Complete.
val completeImperativeAggregateFunctions: Array[ImperativeAggregate] =
@@ -291,7 +291,7 @@ abstract class AggregationIterator(
val updateExpressions =
completeAggregateFunctions.flatMap {
case ae: DeclarativeAggregate => ae.updateExpressions
- case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
+ case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
}
val completeExpressionAggUpdateProjection =
newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)()
@@ -353,7 +353,7 @@ abstract class AggregationIterator(
allAggregateFunctions.flatMap(_.aggBufferAttributes)
val evalExpressions = allAggregateFunctions.map {
case ae: DeclarativeAggregate => ae.evaluateExpression
- case agg: AggregateFunction2 => NoOp
+ case agg: AggregateFunction => NoOp
}
val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferSchemata)()
val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
index 4d37106e00..fb7f30c2ae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
@@ -29,9 +29,9 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
case class SortBasedAggregate(
requiredChildDistributionExpressions: Option[Seq[Expression]],
groupingExpressions: Seq[NamedExpression],
- nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+ nonCompleteAggregateExpressions: Seq[AggregateExpression],
nonCompleteAggregateAttributes: Seq[Attribute],
- completeAggregateExpressions: Seq[AggregateExpression2],
+ completeAggregateExpressions: Seq[AggregateExpression],
completeAggregateAttributes: Seq[Attribute],
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
index 64c673064f..fe5c3195f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
@@ -19,11 +19,11 @@ package org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, AggregateFunction2}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction}
import org.apache.spark.sql.execution.metric.LongSQLMetric
/**
- * An iterator used to evaluate [[AggregateFunction2]]. It assumes the input rows have been
+ * An iterator used to evaluate [[AggregateFunction]]. It assumes the input rows have been
* sorted by values of [[groupingKeyAttributes]].
*/
class SortBasedAggregationIterator(
@@ -31,9 +31,9 @@ class SortBasedAggregationIterator(
groupingKeyAttributes: Seq[Attribute],
valueAttributes: Seq[Attribute],
inputIterator: Iterator[InternalRow],
- nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+ nonCompleteAggregateExpressions: Seq[AggregateExpression],
nonCompleteAggregateAttributes: Seq[Attribute],
- completeAggregateExpressions: Seq[AggregateExpression2],
+ completeAggregateExpressions: Seq[AggregateExpression],
completeAggregateAttributes: Seq[Attribute],
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 15616915f7..1edde1e5a1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.execution.{SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap}
@@ -30,9 +30,9 @@ import org.apache.spark.sql.types.StructType
case class TungstenAggregate(
requiredChildDistributionExpressions: Option[Seq[Expression]],
groupingExpressions: Seq[NamedExpression],
- nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+ nonCompleteAggregateExpressions: Seq[AggregateExpression],
nonCompleteAggregateAttributes: Seq[Attribute],
- completeAggregateExpressions: Seq[AggregateExpression2],
+ completeAggregateExpressions: Seq[AggregateExpression],
completeAggregateAttributes: Seq[Attribute],
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index ce8d592c36..0439144392 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -64,12 +64,12 @@ import org.apache.spark.sql.types.StructType
* @param groupingExpressions
* expressions for grouping keys
* @param nonCompleteAggregateExpressions
- * [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Partial]],
- * [[PartialMerge]], or [[Final]].
+ * [[AggregateExpression]] containing [[AggregateFunction]]s with mode [[Partial]],
+ * [[PartialMerge]], or [[Final]].
* @param nonCompleteAggregateAttributes the attributes of the nonCompleteAggregateExpressions'
* outputs when they are stored in the final aggregation buffer.
* @param completeAggregateExpressions
- * [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Complete]].
+ * [[AggregateExpression]] containing [[AggregateFunction]]s with mode [[Complete]].
* @param completeAggregateAttributes the attributes of completeAggregateExpressions' outputs
* when they are stored in the final aggregation buffer.
* @param resultExpressions
@@ -83,9 +83,9 @@ import org.apache.spark.sql.types.StructType
*/
class TungstenAggregationIterator(
groupingExpressions: Seq[NamedExpression],
- nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+ nonCompleteAggregateExpressions: Seq[AggregateExpression],
nonCompleteAggregateAttributes: Seq[Attribute],
- completeAggregateExpressions: Seq[AggregateExpression2],
+ completeAggregateExpressions: Seq[AggregateExpression],
completeAggregateAttributes: Seq[Attribute],
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
@@ -106,7 +106,7 @@ class TungstenAggregationIterator(
// A Seq containing all AggregateExpressions.
// It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final
// are at the beginning of the allAggregateExpressions.
- private[this] val allAggregateExpressions: Seq[AggregateExpression2] =
+ private[this] val allAggregateExpressions: Seq[AggregateExpression] =
nonCompleteAggregateExpressions ++ completeAggregateExpressions
// Check to make sure we do not have more than three modes in our AggregateExpressions.
@@ -150,10 +150,10 @@ class TungstenAggregationIterator(
// Initialize all AggregateFunctions by binding references, if necessary,
// and setting inputBufferOffset and mutableBufferOffset.
private def initializeAllAggregateFunctions(
- startingInputBufferOffset: Int): Array[AggregateFunction2] = {
+ startingInputBufferOffset: Int): Array[AggregateFunction] = {
var mutableBufferOffset = 0
var inputBufferOffset: Int = startingInputBufferOffset
- val functions = new Array[AggregateFunction2](allAggregateExpressions.length)
+ val functions = new Array[AggregateFunction](allAggregateExpressions.length)
var i = 0
while (i < allAggregateExpressions.length) {
val func = allAggregateExpressions(i).aggregateFunction
@@ -195,7 +195,7 @@ class TungstenAggregationIterator(
functions
}
- private[this] var allAggregateFunctions: Array[AggregateFunction2] =
+ private[this] var allAggregateFunctions: Array[AggregateFunction] =
initializeAllAggregateFunctions(initialInputBufferOffset)
// Positions of those imperative aggregate functions in allAggregateFunctions.
@@ -263,7 +263,7 @@ class TungstenAggregationIterator(
case (Some(Partial), None) =>
val updateExpressions = allAggregateFunctions.flatMap {
case ae: DeclarativeAggregate => ae.updateExpressions
- case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
+ case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
}
val imperativeAggregateFunctions: Array[ImperativeAggregate] =
allAggregateFunctions.collect { case func: ImperativeAggregate => func}
@@ -286,7 +286,7 @@ class TungstenAggregationIterator(
case (Some(PartialMerge), None) | (Some(Final), None) =>
val mergeExpressions = allAggregateFunctions.flatMap {
case ae: DeclarativeAggregate => ae.mergeExpressions
- case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
+ case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
}
val imperativeAggregateFunctions: Array[ImperativeAggregate] =
allAggregateFunctions.collect { case func: ImperativeAggregate => func}
@@ -307,11 +307,11 @@ class TungstenAggregationIterator(
// Final-Complete
case (Some(Final), Some(Complete)) =>
- val completeAggregateFunctions: Array[AggregateFunction2] =
+ val completeAggregateFunctions: Array[AggregateFunction] =
allAggregateFunctions.takeRight(completeAggregateExpressions.length)
val completeImperativeAggregateFunctions: Array[ImperativeAggregate] =
completeAggregateFunctions.collect { case func: ImperativeAggregate => func }
- val nonCompleteAggregateFunctions: Array[AggregateFunction2] =
+ val nonCompleteAggregateFunctions: Array[AggregateFunction] =
allAggregateFunctions.take(nonCompleteAggregateExpressions.length)
val nonCompleteImperativeAggregateFunctions: Array[ImperativeAggregate] =
nonCompleteAggregateFunctions.collect { case func: ImperativeAggregate => func }
@@ -321,7 +321,7 @@ class TungstenAggregationIterator(
val mergeExpressions =
nonCompleteAggregateFunctions.flatMap {
case ae: DeclarativeAggregate => ae.mergeExpressions
- case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
+ case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
} ++ completeOffsetExpressions
val finalMergeProjection =
newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)()
@@ -331,7 +331,7 @@ class TungstenAggregationIterator(
Seq.fill(nonCompleteAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp)
val updateExpressions = finalOffsetExpressions ++ completeAggregateFunctions.flatMap {
case ae: DeclarativeAggregate => ae.updateExpressions
- case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
+ case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
}
val completeUpdateProjection =
newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
@@ -358,7 +358,7 @@ class TungstenAggregationIterator(
// Complete-only
case (None, Some(Complete)) =>
- val completeAggregateFunctions: Array[AggregateFunction2] =
+ val completeAggregateFunctions: Array[AggregateFunction] =
allAggregateFunctions.takeRight(completeAggregateExpressions.length)
// All imperative aggregate functions with mode Complete.
val completeImperativeAggregateFunctions: Array[ImperativeAggregate] =
@@ -366,7 +366,7 @@ class TungstenAggregationIterator(
val updateExpressions = completeAggregateFunctions.flatMap {
case ae: DeclarativeAggregate => ae.updateExpressions
- case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
+ case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
}
val completeExpressionAggUpdateProjection =
newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
@@ -414,7 +414,7 @@ class TungstenAggregationIterator(
val joinedRow = new JoinedRow()
val evalExpressions = allAggregateFunctions.map {
case ae: DeclarativeAggregate => ae.evaluateExpression
- case agg: AggregateFunction2 => NoOp
+ case agg: AggregateFunction => NoOp
}
val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes)()
// These are the attributes of the row produced by `expressionAggEvalProjection`
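For illustration only (not part of this patch): the TungstenAggregationIterator hunks above repeat one pattern when building update, merge, and evaluate projections — declarative aggregates contribute their own catalyst expressions, while every other AggregateFunction is an ImperativeAggregate, so its buffer slots are filled with NoOp and the function itself is collected separately and driven imperatively. A minimal sketch of that shape, using only names that appear in the diff:

// Declarative aggregates expand into expressions; imperative ones get NoOp
// placeholders so the aggregation-buffer slots stay aligned.
val updateExpressions = allAggregateFunctions.flatMap {
  case ae: DeclarativeAggregate => ae.updateExpressions
  case agg: AggregateFunction   => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
}
// Imperative functions are updated through their own callbacks instead.
val imperativeAggregateFunctions: Array[ImperativeAggregate] =
  allAggregateFunctions.collect { case func: ImperativeAggregate => func }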
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index d2f56e0fc1..20359c1e54 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
import org.apache.spark.sql.catalyst.expressions.{MutableRow, InterpretedMutableProjection, AttributeReference, Expression}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, AggregateFunction2}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, AggregateFunction}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index eaafd83158..79abf2d592 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -28,8 +28,8 @@ object Utils {
def planAggregateWithoutPartial(
groupingExpressions: Seq[NamedExpression],
- aggregateExpressions: Seq[AggregateExpression2],
- aggregateFunctionToAttribute: Map[(AggregateFunction2, Boolean), Attribute],
+ aggregateExpressions: Seq[AggregateExpression],
+ aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute],
resultExpressions: Seq[NamedExpression],
child: SparkPlan): Seq[SparkPlan] = {
@@ -54,8 +54,8 @@ object Utils {
def planAggregateWithoutDistinct(
groupingExpressions: Seq[NamedExpression],
- aggregateExpressions: Seq[AggregateExpression2],
- aggregateFunctionToAttribute: Map[(AggregateFunction2, Boolean), Attribute],
+ aggregateExpressions: Seq[AggregateExpression],
+ aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute],
resultExpressions: Seq[NamedExpression],
child: SparkPlan): Seq[SparkPlan] = {
// Check if we can use TungstenAggregate.
@@ -137,9 +137,9 @@ object Utils {
def planAggregateWithOneDistinct(
groupingExpressions: Seq[NamedExpression],
- functionsWithDistinct: Seq[AggregateExpression2],
- functionsWithoutDistinct: Seq[AggregateExpression2],
- aggregateFunctionToAttribute: Map[(AggregateFunction2, Boolean), Attribute],
+ functionsWithDistinct: Seq[AggregateExpression],
+ functionsWithoutDistinct: Seq[AggregateExpression],
+ aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute],
resultExpressions: Seq[NamedExpression],
child: SparkPlan): Seq[SparkPlan] = {
@@ -253,16 +253,16 @@ object Utils {
// Children of an AggregateFunction with the DISTINCT keyword have already
// been evaluated. Here, we need to replace the original children
// with AttributeReferences.
- case agg @ AggregateExpression2(aggregateFunction, mode, true) =>
+ case agg @ AggregateExpression(aggregateFunction, mode, true) =>
val rewrittenAggregateFunction = aggregateFunction.transformDown {
case expr if expr == distinctColumnExpression => distinctColumnAttribute
- }.asInstanceOf[AggregateFunction2]
+ }.asInstanceOf[AggregateFunction]
// We rewrite the aggregate function to a non-distinct aggregation because
// its input will have distinct arguments.
// We keep the isDistinct setting as true so that, when users look at the query plan,
// they can still see the distinct aggregations.
val rewrittenAggregateExpression =
- AggregateExpression2(rewrittenAggregateFunction, Complete, isDistinct = true)
+ AggregateExpression(rewrittenAggregateFunction, Complete, isDistinct = true)
val aggregateFunctionAttribute = aggregateFunctionToAttribute(agg.aggregateFunction, true)
(rewrittenAggregateExpression, aggregateFunctionAttribute)
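A hedged illustration of the case above (not part of the patch): for a query such as sum(DISTINCT value), the distinct values of the child have already been produced by an earlier aggregate, so the original child is swapped for the distinct column's attribute and a Complete-mode expression is built; isDistinct = true is kept only so the plan still reads as a distinct aggregation.

// `distinctColumnAttribute` is the attribute produced for the de-duplicated child,
// as in the case above; Sum stands in for whatever aggregate function was used.
val rewritten = AggregateExpression(
  Sum(distinctColumnAttribute),  // original child replaced by the attribute
  Complete,
  isDistinct = true)             // kept for display in the query plan only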
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 0b3192a6da..8cc25c2440 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.expressions
import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete}
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
import org.apache.spark.sql.{Dataset, DataFrame, TypedColumn}
@@ -70,7 +70,7 @@ abstract class Aggregator[-A, B, C] {
implicit bEncoder: Encoder[B],
cEncoder: Encoder[C]): TypedColumn[A, C] = {
val expr =
- new AggregateExpression2(
+ new AggregateExpression(
TypedAggregateExpression(this),
Complete,
false)
@@ -78,4 +78,3 @@ abstract class Aggregator[-A, B, C] {
new TypedColumn[A, C](expr, encoderFor[C])
}
}
-
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
index 8b9247adea..fc873c04f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
@@ -18,9 +18,9 @@
package org.apache.spark.sql.expressions
import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.types.BooleanType
import org.apache.spark.sql.{Column, catalyst}
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
/**
@@ -141,40 +141,56 @@ class WindowSpec private[sql](
*/
private[sql] def withAggregate(aggregate: Column): Column = {
val windowExpr = aggregate.expr match {
- case Average(child) => WindowExpression(
- UnresolvedWindowFunction("avg", child :: Nil),
- WindowSpecDefinition(partitionSpec, orderSpec, frame))
- case Sum(child) => WindowExpression(
- UnresolvedWindowFunction("sum", child :: Nil),
- WindowSpecDefinition(partitionSpec, orderSpec, frame))
- case Count(child) => WindowExpression(
- UnresolvedWindowFunction("count", child :: Nil),
- WindowSpecDefinition(partitionSpec, orderSpec, frame))
- case First(child, ignoreNulls) => WindowExpression(
- // TODO this is a hack for Hive UDAF first_value
- UnresolvedWindowFunction(
- "first_value",
- child :: ignoreNulls :: Nil),
- WindowSpecDefinition(partitionSpec, orderSpec, frame))
- case Last(child, ignoreNulls) => WindowExpression(
- // TODO this is a hack for Hive UDAF last_value
- UnresolvedWindowFunction(
- "last_value",
- child :: ignoreNulls :: Nil),
- WindowSpecDefinition(partitionSpec, orderSpec, frame))
- case Min(child) => WindowExpression(
- UnresolvedWindowFunction("min", child :: Nil),
- WindowSpecDefinition(partitionSpec, orderSpec, frame))
- case Max(child) => WindowExpression(
- UnresolvedWindowFunction("max", child :: Nil),
- WindowSpecDefinition(partitionSpec, orderSpec, frame))
- case wf: WindowFunction => WindowExpression(
- wf,
- WindowSpecDefinition(partitionSpec, orderSpec, frame))
+ // First, we check if we get an aggregate function without the DISTINCT keyword.
+ // Right now, we do not support using a DISTINCT aggregate function as a
+ // window function.
+ case AggregateExpression(aggregateFunction, _, isDistinct) if !isDistinct =>
+ aggregateFunction match {
+ case Average(child) => WindowExpression(
+ UnresolvedWindowFunction("avg", child :: Nil),
+ WindowSpecDefinition(partitionSpec, orderSpec, frame))
+ case Sum(child) => WindowExpression(
+ UnresolvedWindowFunction("sum", child :: Nil),
+ WindowSpecDefinition(partitionSpec, orderSpec, frame))
+ case Count(child) => WindowExpression(
+ UnresolvedWindowFunction("count", child :: Nil),
+ WindowSpecDefinition(partitionSpec, orderSpec, frame))
+ case First(child, ignoreNulls) => WindowExpression(
+ // TODO this is a hack for Hive UDAF first_value
+ UnresolvedWindowFunction(
+ "first_value",
+ child :: ignoreNulls :: Nil),
+ WindowSpecDefinition(partitionSpec, orderSpec, frame))
+ case Last(child, ignoreNulls) => WindowExpression(
+ // TODO this is a hack for Hive UDAF last_value
+ UnresolvedWindowFunction(
+ "last_value",
+ child :: ignoreNulls :: Nil),
+ WindowSpecDefinition(partitionSpec, orderSpec, frame))
+ case Min(child) => WindowExpression(
+ UnresolvedWindowFunction("min", child :: Nil),
+ WindowSpecDefinition(partitionSpec, orderSpec, frame))
+ case Max(child) => WindowExpression(
+ UnresolvedWindowFunction("max", child :: Nil),
+ WindowSpecDefinition(partitionSpec, orderSpec, frame))
+ case x =>
+ throw new UnsupportedOperationException(s"$x is not supported in a window operation.")
+ }
+
+ case AggregateExpression(aggregateFunction, _, isDistinct) if isDistinct =>
+ throw new UnsupportedOperationException(
+ s"Distinct aggregate function ${aggregateFunction} is not supported " +
+ s"in window operation.")
+
+ case wf: WindowFunction =>
+ WindowExpression(
+ wf,
+ WindowSpecDefinition(partitionSpec, orderSpec, frame))
+
case x =>
- throw new UnsupportedOperationException(s"$x is not supported in window operation.")
+ throw new UnsupportedOperationException(s"$x is not supported in a window operation.")
}
+
new Column(windowExpr)
}
-
}
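A hedged usage sketch of the rewritten withAggregate path (the DataFrame, column, and window names below are illustrative, not from this patch): a plain aggregate over a window is unwrapped from its AggregateExpression and mapped onto the corresponding UnresolvedWindowFunction, while a DISTINCT aggregate now fails fast.

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val w = Window.partitionBy("key").orderBy("ts")   // hypothetical window spec

df.select(avg("value").over(w))               // AggregateExpression(Average(...), _, false)
                                              //   -> UnresolvedWindowFunction("avg", ...)
// df.select(countDistinct("value").over(w))  // isDistinct = true -> UnsupportedOperationException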
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
index 258afadc76..11dbf391cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.expressions
-import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression}
import org.apache.spark.sql.execution.aggregate.ScalaUDAF
import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.types._
@@ -109,7 +109,7 @@ abstract class UserDefinedAggregateFunction extends Serializable {
@scala.annotation.varargs
def apply(exprs: Column*): Column = {
val aggregateExpression =
- AggregateExpression2(
+ AggregateExpression(
ScalaUDAF(exprs.map(_.expr), this),
Complete,
isDistinct = false)
@@ -123,7 +123,7 @@ abstract class UserDefinedAggregateFunction extends Serializable {
@scala.annotation.varargs
def distinct(exprs: Column*): Column = {
val aggregateExpression =
- AggregateExpression2(
+ AggregateExpression(
ScalaUDAF(exprs.map(_.expr), this),
Complete,
isDistinct = true)
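Usage sketch for the UDAF entry points above (hedged; `myUdaf`, `df`, and the column names are illustrative): apply wraps the ScalaUDAF in a Complete-mode AggregateExpression, and distinct produces the same expression with the distinct flag set.

// `myUdaf` is any UserDefinedAggregateFunction instance (hypothetical here).
val plain    = myUdaf(col("value"))            // AggregateExpression(ScalaUDAF(...), Complete, isDistinct = false)
val distinct = myUdaf.distinct(col("value"))   // AggregateExpression(ScalaUDAF(...), Complete, isDistinct = true)
df.groupBy(col("key")).agg(plain, distinct)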
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 6d56542ee0..22104e4d48 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star}
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, Encoder}
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -76,6 +77,12 @@ object functions extends LegacyFunctions {
private def withExpr(expr: Expression): Column = Column(expr)
+ private def withAggregateFunction(
+ func: AggregateFunction,
+ isDistinct: Boolean = false): Column = {
+ Column(func.toAggregateExpression(isDistinct))
+ }
+
private implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true)
@@ -154,7 +161,9 @@ object functions extends LegacyFunctions {
* @group agg_funcs
* @since 1.3.0
*/
- def approxCountDistinct(e: Column): Column = withExpr { ApproxCountDistinct(e.expr) }
+ def approxCountDistinct(e: Column): Column = withAggregateFunction {
+ HyperLogLogPlusPlus(e.expr)
+ }
/**
* Aggregate function: returns the approximate number of distinct items in a group.
@@ -170,8 +179,8 @@ object functions extends LegacyFunctions {
* @group agg_funcs
* @since 1.3.0
*/
- def approxCountDistinct(e: Column, rsd: Double): Column = withExpr {
- ApproxCountDistinct(e.expr, rsd)
+ def approxCountDistinct(e: Column, rsd: Double): Column = withAggregateFunction {
+ HyperLogLogPlusPlus(e.expr, rsd, 0, 0)
}
/**
@@ -190,7 +199,7 @@ object functions extends LegacyFunctions {
* @group agg_funcs
* @since 1.3.0
*/
- def avg(e: Column): Column = withExpr { Average(e.expr) }
+ def avg(e: Column): Column = withAggregateFunction { Average(e.expr) }
/**
* Aggregate function: returns the average of the values in a group.
@@ -226,7 +235,7 @@ object functions extends LegacyFunctions {
* @group agg_funcs
* @since 1.6.0
*/
- def corr(column1: Column, column2: Column): Column = withExpr {
+ def corr(column1: Column, column2: Column): Column = withAggregateFunction {
Corr(column1.expr, column2.expr)
}
@@ -246,7 +255,7 @@ object functions extends LegacyFunctions {
* @group agg_funcs
* @since 1.3.0
*/
- def count(e: Column): Column = withExpr {
+ def count(e: Column): Column = withAggregateFunction {
e.expr match {
// Turn count(*) into count(1)
case s: Star => Count(Literal(1))
@@ -269,8 +278,8 @@ object functions extends LegacyFunctions {
* @since 1.3.0
*/
@scala.annotation.varargs
- def countDistinct(expr: Column, exprs: Column*): Column = withExpr {
- CountDistinct((expr +: exprs).map(_.expr))
+ def countDistinct(expr: Column, exprs: Column*): Column = {
+ withAggregateFunction(Count.apply((expr +: exprs).map(_.expr)), isDistinct = true)
}
/**
@@ -289,7 +298,7 @@ object functions extends LegacyFunctions {
* @group agg_funcs
* @since 1.3.0
*/
- def first(e: Column): Column = withExpr { First(e.expr) }
+ def first(e: Column): Column = withAggregateFunction { new First(e.expr) }
/**
* Aggregate function: returns the first value of a column in a group.
@@ -305,7 +314,7 @@ object functions extends LegacyFunctions {
* @group agg_funcs
* @since 1.6.0
*/
- def kurtosis(e: Column): Column = withExpr { Kurtosis(e.expr) }
+ def kurtosis(e: Column): Column = withAggregateFunction { Kurtosis(e.expr) }
/**
* Aggregate function: returns the last value in a group.
@@ -313,7 +322,7 @@ object functions extends LegacyFunctions {
* @group agg_funcs
* @since 1.3.0
*/
- def last(e: Column): Column = withExpr { Last(e.expr) }
+ def last(e: Column): Column = withAggregateFunction { new Last(e.expr) }
/**
* Aggregate function: returns the last value of the column in a group.
@@ -329,7 +338,7 @@ object functions extends LegacyFunctions {
* @group agg_funcs
* @since 1.3.0
*/
- def max(e: Column): Column = withExpr { Max(e.expr) }
+ def max(e: Column): Column = withAggregateFunction { Max(e.expr) }
/**
* Aggregate function: returns the maximum value of the column in a group.
@@ -363,7 +372,7 @@ object functions extends LegacyFunctions {
* @group agg_funcs
* @since 1.3.0
*/
- def min(e: Column): Column = withExpr { Min(e.expr) }
+ def min(e: Column): Column = withAggregateFunction { Min(e.expr) }
/**
* Aggregate function: returns the minimum value of the column in a group.
@@ -379,7 +388,7 @@ object functions extends LegacyFunctions {
* @group agg_funcs
* @since 1.6.0
*/
- def skewness(e: Column): Column = withExpr { Skewness(e.expr) }
+ def skewness(e: Column): Column = withAggregateFunction { Skewness(e.expr) }
/**
* Aggregate function: alias for [[stddev_samp]].
@@ -387,7 +396,7 @@ object functions extends LegacyFunctions {
* @group agg_funcs
* @since 1.6.0
*/
- def stddev(e: Column): Column = withExpr { StddevSamp(e.expr) }
+ def stddev(e: Column): Column = withAggregateFunction { StddevSamp(e.expr) }
/**
* Aggregate function: returns the unbiased sample standard deviation of
@@ -396,7 +405,7 @@ object functions extends LegacyFunctions {
* @group agg_funcs
* @since 1.6.0
*/
- def stddev_samp(e: Column): Column = withExpr { StddevSamp(e.expr) }
+ def stddev_samp(e: Column): Column = withAggregateFunction { StddevSamp(e.expr) }
/**
* Aggregate function: returns the population standard deviation of
@@ -405,7 +414,7 @@ object functions extends LegacyFunctions {
* @group agg_funcs
* @since 1.6.0
*/
- def stddev_pop(e: Column): Column = withExpr { StddevPop(e.expr) }
+ def stddev_pop(e: Column): Column = withAggregateFunction { StddevPop(e.expr) }
/**
* Aggregate function: returns the sum of all values in the expression.
@@ -413,7 +422,7 @@ object functions extends LegacyFunctions {
* @group agg_funcs
* @since 1.3.0
*/
- def sum(e: Column): Column = withExpr { Sum(e.expr) }
+ def sum(e: Column): Column = withAggregateFunction { Sum(e.expr) }
/**
* Aggregate function: returns the sum of all values in the given column.
@@ -429,7 +438,7 @@ object functions extends LegacyFunctions {
* @group agg_funcs
* @since 1.3.0
*/
- def sumDistinct(e: Column): Column = withExpr { SumDistinct(e.expr) }
+ def sumDistinct(e: Column): Column = withAggregateFunction(Sum(e.expr), isDistinct = true)
/**
* Aggregate function: returns the sum of distinct values in the expression.
@@ -445,7 +454,7 @@ object functions extends LegacyFunctions {
* @group agg_funcs
* @since 1.6.0
*/
- def variance(e: Column): Column = withExpr { VarianceSamp(e.expr) }
+ def variance(e: Column): Column = withAggregateFunction { VarianceSamp(e.expr) }
/**
* Aggregate function: returns the unbiased variance of the values in a group.
@@ -453,7 +462,7 @@ object functions extends LegacyFunctions {
* @group agg_funcs
* @since 1.6.0
*/
- def var_samp(e: Column): Column = withExpr { VarianceSamp(e.expr) }
+ def var_samp(e: Column): Column = withAggregateFunction { VarianceSamp(e.expr) }
/**
* Aggregate function: returns the population variance of the values in a group.
@@ -461,7 +470,7 @@ object functions extends LegacyFunctions {
* @group agg_funcs
* @since 1.6.0
*/
- def var_pop(e: Column): Column = withExpr { VariancePop(e.expr) }
+ def var_pop(e: Column): Column = withAggregateFunction { VariancePop(e.expr) }
//////////////////////////////////////////////////////////////////////////////////////////////
// Window functions
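The net effect of the functions.scala changes, restated as a hedged sketch (the DataFrame and column names are illustrative): every aggregate builder now goes through withAggregateFunction, which wraps the AggregateFunction via toAggregateExpression(isDistinct) into a Column, and the dedicated distinct/approximate expressions are replaced by the distinct flag or by HyperLogLogPlusPlus.

df.groupBy(col("key")).agg(
  sum(col("value")),                        // withAggregateFunction { Sum(e.expr) }
  sumDistinct(col("value")),                // withAggregateFunction(Sum(e.expr), isDistinct = true)
  countDistinct(col("value")),              // withAggregateFunction(Count(...), isDistinct = true)
  approxCountDistinct(col("value"), 0.05))  // withAggregateFunction { HyperLogLogPlusPlus(e.expr, 0.05, 0, 0) }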
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 3de277a79a..441a0c6d0e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -237,34 +237,10 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("SPARK-8828 sum should return null if all input values are null") {
- withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "true") {
- withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") {
- checkAnswer(
- sql("select sum(a), avg(a) from allNulls"),
- Seq(Row(null, null))
- )
- }
- withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") {
- checkAnswer(
- sql("select sum(a), avg(a) from allNulls"),
- Seq(Row(null, null))
- )
- }
- }
- withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") {
- withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") {
- checkAnswer(
- sql("select sum(a), avg(a) from allNulls"),
- Seq(Row(null, null))
- )
- }
- withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") {
- checkAnswer(
- sql("select sum(a), avg(a) from allNulls"),
- Seq(Row(null, null))
- )
- }
- }
+ checkAnswer(
+ sql("select sum(a), avg(a) from allNulls"),
+ Seq(Row(null, null))
+ )
}
private def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = {
@@ -507,29 +483,22 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("literal in agg grouping expressions") {
- def literalInAggTest(): Unit = {
- checkAnswer(
- sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"),
- Seq(Row(1, 2), Row(2, 2), Row(3, 2)))
- checkAnswer(
- sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"),
- Seq(Row(1, 2), Row(2, 2), Row(3, 2)))
-
- checkAnswer(
- sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"),
- sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
- checkAnswer(
- sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"),
- sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
- checkAnswer(
- sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"),
- sql("SELECT 1, 2, sum(b) FROM testData2"))
- }
+ checkAnswer(
+ sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"),
+ Seq(Row(1, 2), Row(2, 2), Row(3, 2)))
+ checkAnswer(
+ sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"),
+ Seq(Row(1, 2), Row(2, 2), Row(3, 2)))
- literalInAggTest()
- withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") {
- literalInAggTest()
- }
+ checkAnswer(
+ sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"),
+ sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
+ checkAnswer(
+ sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"),
+ sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
+ checkAnswer(
+ sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"),
+ sql("SELECT 1, 2, sum(b) FROM testData2"))
}
test("aggregates with nulls") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index a229e5814d..e31c528f3a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -21,16 +21,13 @@ import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData}
import scala.beans.{BeanInfo, BeanProperty}
-import com.clearspring.analytics.stream.cardinality.HyperLogLog
-
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.CatalystTypeConverters
-import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT}
+import org.apache.spark.sql.catalyst.expressions.OpenHashSetUDT
import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
-import org.apache.spark.util.Utils
import org.apache.spark.util.collection.OpenHashSet
@@ -134,16 +131,6 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0)
}
- test("HyperLogLogUDT") {
- val hyperLogLogUDT = HyperLogLogUDT
- val hyperLogLog = new HyperLogLog(0.4)
- (1 to 10).foreach(i => hyperLogLog.offer(Row(i)))
-
- val actual = hyperLogLogUDT.deserialize(hyperLogLogUDT.serialize(hyperLogLog))
- assert(actual.cardinality() === hyperLogLog.cardinality())
- assert(java.util.Arrays.equals(actual.getBytes, hyperLogLog.getBytes))
- }
-
test("OpenHashSetUDT") {
val openHashSetUDT = new OpenHashSetUDT(IntegerType)
val set = new OpenHashSet[Int]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 2076c573b5..44634dacbd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -38,7 +38,7 @@ class PlannerSuite extends SharedSQLContext {
private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
val planner = sqlContext.planner
import planner._
- val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption)
+ val plannedOption = Aggregation(query).headOption
val planned =
plannedOption.getOrElse(
fail(s"Could query play aggregation query $query. Is it an aggregation query?"))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index cdd885ba14..4b4f5c6c45 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -152,36 +152,6 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
)
}
- test("Aggregate metrics") {
- withSQLConf(
- SQLConf.UNSAFE_ENABLED.key -> "false",
- SQLConf.CODEGEN_ENABLED.key -> "false",
- SQLConf.TUNGSTEN_ENABLED.key -> "false") {
- // Assume the execution plan is
- // ... -> Aggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) -> Aggregate(nodeId = 0)
- val df = testData2.groupBy().count() // 2 partitions
- testSparkPlanMetrics(df, 1, Map(
- 2L -> ("Aggregate", Map(
- "number of input rows" -> 6L,
- "number of output rows" -> 2L)),
- 0L -> ("Aggregate", Map(
- "number of input rows" -> 2L,
- "number of output rows" -> 1L)))
- )
-
- // 2 partitions and each partition contains 2 keys
- val df2 = testData2.groupBy('a).count()
- testSparkPlanMetrics(df2, 1, Map(
- 2L -> ("Aggregate", Map(
- "number of input rows" -> 6L,
- "number of output rows" -> 4L)),
- 0L -> ("Aggregate", Map(
- "number of input rows" -> 4L,
- "number of output rows" -> 3L)))
- )
- }
- }
-
test("SortBasedAggregate metrics") {
// Because SortBasedAggregate may skip different rows if the number of partitions is different,
// this test should use the deterministic number of partitions.
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index c5f69657f5..ba6204633b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -584,7 +584,6 @@ class HiveContext private[hive](
HiveTableScans,
DataSinks,
Scripts,
- HashAggregation,
Aggregation,
LeftSemiJoin,
EquiJoinSelection,
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index ab88c1e68f..6f8ed413a0 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -38,6 +38,7 @@ import org.apache.spark.Logging
import org.apache.spark.sql.{AnalysisException, catalyst}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.{logical, _}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
@@ -1508,9 +1509,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
UnresolvedStar(Some(UnresolvedAttribute.parseAttributeName(name)))
/* Aggregate Functions */
- case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => Count(Literal(1))
- case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => CountDistinct(args.map(nodeToExpr))
- case Token("TOK_FUNCTIONDI", Token(SUM(), Nil) :: arg :: Nil) => SumDistinct(nodeToExpr(arg))
+ case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) =>
+ Count(args.map(nodeToExpr)).toAggregateExpression(isDistinct = true)
+ case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) =>
+ Count(Literal(1)).toAggregateExpression()
/* Casts */
case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) =>
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index ea36c132bb..6bf2c53440 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -69,11 +69,7 @@ class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFun
abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
import testImplicits._
- var originalUseAggregate2: Boolean = _
-
override def beforeAll(): Unit = {
- originalUseAggregate2 = sqlContext.conf.useSqlAggregate2
- sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2.key, "true")
val data1 = Seq[(Integer, Integer)](
(1, 10),
(null, -60),
@@ -120,7 +116,6 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
sqlContext.sql("DROP TABLE IF EXISTS agg1")
sqlContext.sql("DROP TABLE IF EXISTS agg2")
sqlContext.dropTempTable("emptyTable")
- sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2.key, originalUseAggregate2.toString)
}
test("empty table") {
@@ -447,73 +442,80 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
}
test("single distinct column set") {
- // DISTINCT is not meaningful with Max and Min, so we just ignore the DISTINCT keyword.
- checkAnswer(
- sqlContext.sql(
- """
- |SELECT
- | min(distinct value1),
- | sum(distinct value1),
- | avg(value1),
- | avg(value2),
- | max(distinct value1)
- |FROM agg2
- """.stripMargin),
- Row(-60, 70.0, 101.0/9.0, 5.6, 100))
-
- checkAnswer(
- sqlContext.sql(
- """
- |SELECT
- | mydoubleavg(distinct value1),
- | avg(value1),
- | avg(value2),
- | key,
- | mydoubleavg(value1 - 1),
- | mydoubleavg(distinct value1) * 0.1,
- | avg(value1 + value2)
- |FROM agg2
- |GROUP BY key
- """.stripMargin),
- Row(120.0, 70.0/3.0, -10.0/3.0, 1, 67.0/3.0 + 100.0, 12.0, 20.0) ::
- Row(100.0, 1.0/3.0, 1.0, 2, -2.0/3.0 + 100.0, 10.0, 2.0) ::
- Row(null, null, 3.0, 3, null, null, null) ::
- Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil)
-
- checkAnswer(
- sqlContext.sql(
- """
- |SELECT
- | key,
- | mydoubleavg(distinct value1),
- | mydoublesum(value2),
- | mydoublesum(distinct value1),
- | mydoubleavg(distinct value1),
- | mydoubleavg(value1)
- |FROM agg2
- |GROUP BY key
- """.stripMargin),
- Row(1, 120.0, -10.0, 40.0, 120.0, 70.0/3.0 + 100.0) ::
- Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) ::
- Row(3, null, 3.0, null, null, null) ::
- Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil)
-
- checkAnswer(
- sqlContext.sql(
- """
- |SELECT
- | count(value1),
- | count(*),
- | count(1),
- | count(DISTINCT value1),
- | key
- |FROM agg2
- |GROUP BY key
- """.stripMargin),
- Row(3, 3, 3, 2, 1) ::
- Row(3, 4, 4, 2, 2) ::
- Row(0, 2, 2, 0, 3) ::
- Row(3, 4, 4, 3, null) :: Nil)
+ Seq(true, false).foreach { specializeSingleDistinctAgg =>
+ val conf =
+ (SQLConf.SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING.key,
+ specializeSingleDistinctAgg.toString)
+ withSQLConf(conf) {
+ // DISTINCT is not meaningful with Max and Min, so we just ignore the DISTINCT keyword.
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | min(distinct value1),
+ | sum(distinct value1),
+ | avg(value1),
+ | avg(value2),
+ | max(distinct value1)
+ |FROM agg2
+ """.stripMargin),
+ Row(-60, 70.0, 101.0/9.0, 5.6, 100))
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | mydoubleavg(distinct value1),
+ | avg(value1),
+ | avg(value2),
+ | key,
+ | mydoubleavg(value1 - 1),
+ | mydoubleavg(distinct value1) * 0.1,
+ | avg(value1 + value2)
+ |FROM agg2
+ |GROUP BY key
+ """.stripMargin),
+ Row(120.0, 70.0/3.0, -10.0/3.0, 1, 67.0/3.0 + 100.0, 12.0, 20.0) ::
+ Row(100.0, 1.0/3.0, 1.0, 2, -2.0/3.0 + 100.0, 10.0, 2.0) ::
+ Row(null, null, 3.0, 3, null, null, null) ::
+ Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | key,
+ | mydoubleavg(distinct value1),
+ | mydoublesum(value2),
+ | mydoublesum(distinct value1),
+ | mydoubleavg(distinct value1),
+ | mydoubleavg(value1)
+ |FROM agg2
+ |GROUP BY key
+ """.stripMargin),
+ Row(1, 120.0, -10.0, 40.0, 120.0, 70.0/3.0 + 100.0) ::
+ Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) ::
+ Row(3, null, 3.0, null, null, null) ::
+ Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | count(value1),
+ | count(*),
+ | count(1),
+ | count(DISTINCT value1),
+ | key
+ |FROM agg2
+ |GROUP BY key
+ """.stripMargin),
+ Row(3, 3, 3, 2, 1) ::
+ Row(3, 4, 4, 2, 2) ::
+ Row(0, 2, 2, 0, 3) ::
+ Row(3, 4, 4, 3, null) :: Nil)
+ }
+ }
}
test("single distinct multiple columns set") {
@@ -699,48 +701,6 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
val corr7 = sqlContext.sql("SELECT corr(b, c) FROM covar_tab").collect()(0).getDouble(0)
assert(math.abs(corr7 - 0.6633880657639323) < 1e-12)
-
- withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") {
- val errorMessage = intercept[SparkException] {
- val df = Seq.tabulate(10)(i => (1.0 * i, 2.0 * i, i * -1.0)).toDF("a", "b", "c")
- val corr1 = df.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
- }.getMessage
- assert(errorMessage.contains("java.lang.UnsupportedOperationException: " +
- "Corr only supports the new AggregateExpression2"))
- }
- }
-
- test("test Last implemented based on AggregateExpression1") {
- // TODO: Remove this test once we remove AggregateExpression1.
- import org.apache.spark.sql.functions._
- val df = Seq((1, 1), (2, 2), (3, 3)).toDF("i", "j").repartition(1)
- withSQLConf(
- SQLConf.SHUFFLE_PARTITIONS.key -> "1",
- SQLConf.USE_SQL_AGGREGATE2.key -> "false") {
-
- checkAnswer(
- df.groupBy("i").agg(last("j")),
- df
- )
- }
- }
-
- test("error handling") {
- withSQLConf("spark.sql.useAggregate2" -> "false") {
- val errorMessage = intercept[AnalysisException] {
- sqlContext.sql(
- """
- |SELECT
- | key,
- | sum(value + 1.5 * key),
- | mydoublesum(value),
- | mydoubleavg(value)
- |FROM agg1
- |GROUP BY key
- """.stripMargin).collect()
- }.getMessage
- assert(errorMessage.contains("implemented based on the new Aggregate Function interface"))
- }
}
test("no aggregation function (SPARK-11486)") {