Diffstat (limited to 'sql/catalyst/src')
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala | 10
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala | 14
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 26
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 46
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala (renamed from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala) | 235
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala | 2
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 20
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala | 4
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala | 22
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala | 31
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala | 13
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala | 15
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala | 28
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala | 14
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala | 17
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala | 2
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala | 12
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala | 17
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala | 17
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala | 2
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala | 31
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala | 29
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala | 7
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala | 57
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala | 1073
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 23
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala | 74
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala | 12
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala | 4
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala | 23
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala | 2
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala | 1
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala | 6
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala | 4
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala | 14
 35 files changed, 364 insertions(+), 1543 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
index 3f351b07b3..7c2b8a9407 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst
private[spark] trait CatalystConf {
def caseSensitiveAnalysis: Boolean
+
+ protected[spark] def specializeSingleDistinctAggPlanning: Boolean
}
/**
@@ -29,7 +31,13 @@ object EmptyConf extends CatalystConf {
override def caseSensitiveAnalysis: Boolean = {
throw new UnsupportedOperationException
}
+
+ protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = {
+ throw new UnsupportedOperationException
+ }
}
/** A CatalystConf that can be used for local testing. */
-case class SimpleCatalystConf(caseSensitiveAnalysis: Boolean) extends CatalystConf
+case class SimpleCatalystConf(caseSensitiveAnalysis: Boolean) extends CatalystConf {
+ protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = true
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index cd717c09f8..2a132d8b82 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -22,6 +22,7 @@ import scala.language.implicitConversions
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.DataTypeParser
@@ -272,7 +273,7 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser {
protected lazy val function: Parser[Expression] =
( ident <~ ("(" ~ "*" ~ ")") ^^ { case udfName =>
if (lexical.normalizeKeyword(udfName) == "count") {
- Count(Literal(1))
+ AggregateExpression(Count(Literal(1)), mode = Complete, isDistinct = false)
} else {
throw new AnalysisException(s"invalid expression $udfName(*)")
}
@@ -281,14 +282,14 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser {
{ case udfName ~ exprs => UnresolvedFunction(udfName, exprs, isDistinct = false) }
| ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^ { case udfName ~ exprs =>
lexical.normalizeKeyword(udfName) match {
- case "sum" => SumDistinct(exprs.head)
- case "count" => CountDistinct(exprs)
+ case "count" =>
+ aggregate.Count(exprs).toAggregateExpression(isDistinct = true)
case _ => UnresolvedFunction(udfName, exprs, isDistinct = true)
}
}
| APPROXIMATE ~> ident ~ ("(" ~ DISTINCT ~> expression <~ ")") ^^ { case udfName ~ exp =>
if (lexical.normalizeKeyword(udfName) == "count") {
- ApproxCountDistinct(exp)
+ AggregateExpression(new HyperLogLogPlusPlus(exp), mode = Complete, isDistinct = false)
} else {
throw new AnalysisException(s"invalid function approximate $udfName")
}
@@ -296,7 +297,10 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser {
| APPROXIMATE ~> "(" ~> unsignedFloat ~ ")" ~ ident ~ "(" ~ DISTINCT ~ expression <~ ")" ^^
{ case s ~ _ ~ udfName ~ _ ~ _ ~ exp =>
if (lexical.normalizeKeyword(udfName) == "count") {
- ApproxCountDistinct(exp, s.toDouble)
+ AggregateExpression(
+ HyperLogLogPlusPlus(exp, s.toDouble, 0, 0),
+ mode = Complete,
+ isDistinct = false)
} else {
throw new AnalysisException(s"invalid function approximate($s) $udfName")
}
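
For reference, a minimal sketch of the expression trees the parser rules above now build (hand-constructed here, with UnresolvedAttribute standing in for parsed columns; the shapes follow the hunks in this diff):

    import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.expressions.aggregate._

    // COUNT(*) becomes a Complete, non-distinct count of the literal 1.
    val countStar = AggregateExpression(Count(Literal(1)), mode = Complete, isDistinct = false)

    // COUNT(DISTINCT a, b) goes through the Count companion object, which packs
    // multiple children into a struct (see the Count.scala hunk below); the
    // resulting AggregateExpression carries the distinct flag.
    val cols = Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b"))
    val countDistinct = Count(cols).toAggregateExpression(isDistinct = true)

    // APPROXIMATE COUNT(DISTINCT a) becomes HyperLogLog++; it is not marked
    // distinct because the sketch is inherently distinct-aware.
    val approx = AggregateExpression(
      new HyperLogLogPlusPlus(UnresolvedAttribute("a")), mode = Complete, isDistinct = false)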
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 899ee67352..b1e14390b7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.analysis
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2}
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
@@ -79,6 +79,7 @@ class Analyzer(
ExtractWindowExpressions ::
GlobalAggregates ::
ResolveAggregateFunctions ::
+ DistinctAggregationRewriter(conf) ::
HiveTypeCoercion.typeCoercionRules ++
extendedResolutionRules : _*),
Batch("Nondeterministic", Once,
@@ -525,21 +526,14 @@ class Analyzer(
case u @ UnresolvedFunction(name, children, isDistinct) =>
withPosition(u) {
registry.lookupFunction(name, children) match {
- // We get an aggregate function built based on AggregateFunction2 interface.
- // So, we wrap it in AggregateExpression2.
- case agg2: AggregateFunction2 => AggregateExpression2(agg2, Complete, isDistinct)
- // Currently, our old aggregate function interface supports SUM(DISTINCT ...)
- // and COUNT(DISTINCT ...).
- case sumDistinct: SumDistinct => sumDistinct
- case countDistinct: CountDistinct => countDistinct
- // DISTINCT is not meaningful with Max and Min.
- case max: Max if isDistinct => max
- case min: Min if isDistinct => min
- // For other aggregate functions, DISTINCT keyword is not supported for now.
- // Once we converted to the new code path, we will allow using DISTINCT keyword.
- case other: AggregateExpression1 if isDistinct =>
- failAnalysis(s"$name does not support DISTINCT keyword.")
- // If it does not have DISTINCT keyword, we will return it as is.
+ // DISTINCT is not meaningful for a Max or a Min.
+ case max: Max if isDistinct =>
+ AggregateExpression(max, Complete, isDistinct = false)
+ case min: Min if isDistinct =>
+ AggregateExpression(min, Complete, isDistinct = false)
+ // We get an aggregate function; wrap it in an AggregateExpression.
+ case agg2: AggregateFunction => AggregateExpression(agg2, Complete, isDistinct)
+ // This function is not an aggregate function, just return the resolved one.
case other => other
}
}
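
In effect, DISTINCT is silently normalized away for Max and Min (it cannot change their results), while every other aggregate function keeps the flag the user wrote. A small illustrative sketch:

    import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
    import org.apache.spark.sql.catalyst.expressions.aggregate._

    val x = UnresolvedAttribute("x")
    // MAX(DISTINCT x) resolves as if it were MAX(x):
    val resolvedMax = AggregateExpression(Max(x), Complete, isDistinct = false)
    // SUM(DISTINCT x) keeps its distinct flag:
    val resolvedSum = AggregateExpression(Sum(x), Complete, isDistinct = true)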
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 98d6637c06..8322e9930c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, AggregateExpression}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._
@@ -108,7 +109,19 @@ trait CheckAnalysis {
case Aggregate(groupingExprs, aggregateExprs, child) =>
def checkValidAggregateExpression(expr: Expression): Unit = expr match {
- case _: AggregateExpression => // OK
+ case aggExpr: AggregateExpression =>
+ // TODO: Is it possible that the child of an aggregate function is another
+ // aggregate function?
+ aggExpr.aggregateFunction.children.foreach {
+ // This is just a sanity check; our analysis rule PullOutNondeterministic should
+ // already pull out those nondeterministic expressions and evaluate them in
+ // a Project node.
+ case child if !child.deterministic =>
+ failAnalysis(
+ s"nondeterministic expression ${expr.prettyString} should not " +
+ s"appear in the arguments of an aggregate function.")
+ case child => // OK
+ }
case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) =>
failAnalysis(
s"expression '${e.prettyString}' is neither present in the group by, " +
@@ -120,14 +133,26 @@ trait CheckAnalysis {
case e => e.children.foreach(checkValidAggregateExpression)
}
- def checkValidGroupingExprs(expr: Expression): Unit = expr.dataType match {
- case BinaryType =>
- failAnalysis(s"binary type expression ${expr.prettyString} cannot be used " +
- "in grouping expression")
- case m: MapType =>
- failAnalysis(s"map type expression ${expr.prettyString} cannot be used " +
- "in grouping expression")
- case _ => // OK
+ def checkValidGroupingExprs(expr: Expression): Unit = {
+ expr.dataType match {
+ case BinaryType =>
+ failAnalysis(s"binary type expression ${expr.prettyString} cannot be used " +
+ "in grouping expression")
+ case a: ArrayType =>
+ failAnalysis(s"array type expression ${expr.prettyString} cannot be used " +
+ "in grouping expression")
+ case m: MapType =>
+ failAnalysis(s"map type expression ${expr.prettyString} cannot be used " +
+ "in grouping expression")
+ case _ => // OK
+ }
+ if (!expr.deterministic) {
+ // This is just a sanity check; our analysis rule PullOutNondeterministic should
+ // already pull out those nondeterministic expressions and evaluate them in
+ // a Project node.
+ failAnalysis(s"nondeterministic expression ${expr.prettyString} should not " +
+ s"appear in grouping expression.")
+ }
}
aggregateExprs.foreach(checkValidAggregateExpression)
@@ -179,7 +204,8 @@ trait CheckAnalysis {
s"unresolved operator ${operator.simpleString}")
case o if o.expressions.exists(!_.deterministic) &&
- !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] =>
+ !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] && !o.isInstanceOf[Aggregate] =>
+ // Aggregate operators are covered by the dedicated checks above.
failAnalysis(
s"""nondeterministic expressions are only allowed in Project or Filter, found:
| ${o.expressions.map(_.prettyString).mkString(",")}
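
A sketch of plans the strengthened checks reject, built with the catalyst test DSL (Rand stands in for any nondeterministic expression; the DSL helpers are assumed from the dsl/package.scala hunk below):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.expressions.Rand
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
    import org.apache.spark.sql.types.IntegerType

    val rel = LocalRelation('a.int, 'arr.array(IntegerType))
    // Nondeterministic expression inside an aggregate function argument:
    val badAgg = rel.groupBy('a)(max(Rand(7)).as("m"))
    // Array-typed (and, likewise, nondeterministic) grouping expressions fail:
    val badGroup = rel.groupBy('arr)(count('a).as("c"))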
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
index 9b22ce2619..397eff0568 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
@@ -15,215 +15,17 @@
* limitations under the License.
*/
-package org.apache.spark.sql.catalyst.expressions.aggregate
+package org.apache.spark.sql.catalyst.analysis
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst._
+import org.apache.spark.sql.catalyst.CatalystConf
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{Expand, Aggregate, LogicalPlan}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.IntegerType
/**
- * Utility functions used by the query planner to convert our plan to new aggregation code path.
- */
-object Utils {
-
- // Check if the DataType given cannot be part of a group by clause.
- private def isUnGroupable(dt: DataType): Boolean = dt match {
- case _: ArrayType | _: MapType => true
- case s: StructType => s.fields.exists(f => isUnGroupable(f.dataType))
- case _ => false
- }
-
- // Right now, we do not support complex types in the grouping key schema.
- private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean =
- !aggregate.groupingExpressions.exists(e => isUnGroupable(e.dataType))
-
- private def doConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
- case p: Aggregate if supportsGroupingKeySchema(p) =>
-
- val converted = MultipleDistinctRewriter.rewrite(p.transformExpressionsDown {
- case expressions.Average(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Average(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.Count(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Count(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.CountDistinct(children) =>
- val child = if (children.size > 1) {
- DropAnyNull(CreateStruct(children))
- } else {
- children.head
- }
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Count(child),
- mode = aggregate.Complete,
- isDistinct = true)
-
- case expressions.First(child, ignoreNulls) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.First(child, ignoreNulls),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.Kurtosis(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Kurtosis(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.Last(child, ignoreNulls) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Last(child, ignoreNulls),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.Max(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Max(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.Min(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Min(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.Skewness(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Skewness(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.StddevPop(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.StddevPop(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.StddevSamp(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.StddevSamp(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.Sum(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Sum(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.SumDistinct(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Sum(child),
- mode = aggregate.Complete,
- isDistinct = true)
-
- case expressions.Corr(left, right) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Corr(left, right),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.ApproxCountDistinct(child, rsd) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.HyperLogLogPlusPlus(child, rsd),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.VariancePop(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.VariancePop(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.VarianceSamp(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.VarianceSamp(child),
- mode = aggregate.Complete,
- isDistinct = false)
- })
-
- // Check if there is any expressions.AggregateExpression1 left.
- // If so, we cannot convert this plan.
- val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr =>
- // For every expressions, check if it contains AggregateExpression1.
- expr.find {
- case agg: expressions.AggregateExpression1 => true
- case other => false
- }.isDefined
- }
-
- // Check if there are multiple distinct columns.
- // TODO remove this.
- val aggregateExpressions = converted.aggregateExpressions.flatMap { expr =>
- expr.collect {
- case agg: AggregateExpression2 => agg
- }
- }.toSet.toSeq
- val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct)
- val hasMultipleDistinctColumnSets =
- if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
- true
- } else {
- false
- }
-
- if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None
-
- case other => None
- }
-
- def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = {
- // If the plan cannot be converted, we will do a final round check to see if the original
- // logical.Aggregate contains both AggregateExpression1 and AggregateExpression2. If so,
- // we need to throw an exception.
- val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr =>
- expr.collect {
- case agg: AggregateExpression2 => agg.aggregateFunction
- }
- }.distinct
- if (aggregateFunction2s.nonEmpty) {
- // For functions implemented based on the new interface, prepare a list of function names.
- val invalidFunctions = {
- if (aggregateFunction2s.length > 1) {
- s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " +
- s"and ${aggregateFunction2s.head.nodeName} are"
- } else {
- s"${aggregateFunction2s.head.nodeName} is"
- }
- }
- val errorMessage =
- s"${invalidFunctions} implemented based on the new Aggregate Function " +
- s"interface and it cannot be used with functions implemented based on " +
- s"the old Aggregate Function interface."
- throw new AnalysisException(errorMessage)
- }
- }
-
- def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
- case p: Aggregate =>
- val converted = doConvert(p)
- if (converted.isDefined) {
- converted
- } else {
- checkInvalidAggregateFunction2(p)
- None
- }
- case other => None
- }
-}
-
-/**
- * This rule rewrites an aggregate query with multiple distinct clauses into an expanded double
+ * This rule rewrites an aggregate query with distinct aggregations into an expanded double
* aggregation in which the regular aggregation expressions and every distinct clause is aggregated
* in a separate group. The results are then combined in a second aggregate.
*
@@ -298,9 +100,11 @@ object Utils {
* we could improve this in the current rule by applying more advanced expression canonicalization
* techniques.
*/
-object MultipleDistinctRewriter extends Rule[LogicalPlan] {
+case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case p if !p.resolved => p
+ // We need to wait until this Aggregate operator is resolved.
case a: Aggregate => rewrite(a)
case p => p
}
@@ -310,7 +114,7 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
// Collect all aggregate expressions.
val aggExpressions = a.aggregateExpressions.flatMap { e =>
e.collect {
- case ae: AggregateExpression2 => ae
+ case ae: AggregateExpression => ae
}
}
@@ -319,8 +123,15 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
.filter(_.isDistinct)
.groupBy(_.aggregateFunction.children.toSet)
- // Only continue to rewrite if there is more than one distinct group.
- if (distinctAggGroups.size > 1) {
+ val shouldRewrite = if (conf.specializeSingleDistinctAggPlanning) {
+ // When the flag is set, we rely on the Aggregation strategy to handle a
+ // query with a single distinct group as long as the aggregate also has
+ // grouping expressions; otherwise we rewrite it here.
+ distinctAggGroups.size > 1 || (distinctAggGroups.size == 1 && a.groupingExpressions.isEmpty)
+ } else {
+ distinctAggGroups.size >= 1
+ }
+ if (shouldRewrite) {
// Create the attributes for the grouping id and the group by clause.
val gid = new AttributeReference("gid", IntegerType, false)()
val groupByMap = a.groupingExpressions.collect {
@@ -332,11 +143,11 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
// Functions used to modify aggregate functions and their inputs.
def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e))
def patchAggregateFunctionChildren(
- af: AggregateFunction2)(
- attrs: Expression => Expression): AggregateFunction2 = {
+ af: AggregateFunction)(
+ attrs: Expression => Expression): AggregateFunction = {
af.withNewChildren(af.children.map {
case afc => attrs(afc)
- }).asInstanceOf[AggregateFunction2]
+ }).asInstanceOf[AggregateFunction]
}
// Setup unique distinct aggregate children.
@@ -381,7 +192,7 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
val operator = Alias(e.copy(aggregateFunction = af), e.prettyString)()
// Select the result of the first aggregate in the last aggregate.
- val result = AggregateExpression2(
+ val result = AggregateExpression(
aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), Literal(true)),
mode = Complete,
isDistinct = false)
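
To make the rewrite concrete, here is a minimal sketch that triggers it directly (the attributes are resolved by hand because the rule skips unresolved plans; in the analyzer it simply runs after resolution):

    import org.apache.spark.sql.catalyst.SimpleCatalystConf
    import org.apache.spark.sql.catalyst.analysis.DistinctAggregationRewriter
    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
    import org.apache.spark.sql.types.{IntegerType, LongType}

    val cat1 = AttributeReference("cat1", IntegerType)()
    val cat2 = AttributeReference("cat2", IntegerType)()
    val value = AttributeReference("value", LongType)()
    val data = LocalRelation(cat1, cat2, value)

    // SELECT COUNT(DISTINCT cat1), COUNT(DISTINCT cat2), SUM(value) FROM data:
    val query = data.groupBy()(
      countDistinct(cat1).as("c1"),
      countDistinct(cat2).as("c2"),
      sum(value).as("s"))

    // Two distinct column sets, so the rule fires regardless of the conf flag.
    // The result is an Expand that emits each input row once per group (tagged
    // with the synthetic integer column `gid`), followed by two Aggregates.
    val rewritten = DistinctAggregationRewriter(SimpleCatalystConf(true))(query)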
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index d4334d1628..dfa749d1af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -24,6 +24,7 @@ import scala.util.{Failure, Success, Try}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.util.StringKeyHashMap
@@ -177,6 +178,7 @@ object FunctionRegistry {
expression[ToRadians]("radians"),
// aggregate functions
+ expression[HyperLogLogPlusPlus]("approx_count_distinct"),
expression[Average]("avg"),
expression[Corr]("corr"),
expression[Count]("count"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 84e2b1366f..bf2bff0243 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
import javax.annotation.Nullable
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types._
@@ -295,14 +296,17 @@ object HiveTypeCoercion {
i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
- case SumDistinct(e @ StringType()) => Sum(Cast(e, DoubleType))
case Average(e @ StringType()) => Average(Cast(e, DoubleType))
case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType))
case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType))
- case VariancePop(e @ StringType()) => VariancePop(Cast(e, DoubleType))
- case VarianceSamp(e @ StringType()) => VarianceSamp(Cast(e, DoubleType))
- case Skewness(e @ StringType()) => Skewness(Cast(e, DoubleType))
- case Kurtosis(e @ StringType()) => Kurtosis(Cast(e, DoubleType))
+ case VariancePop(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
+ VariancePop(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
+ case VarianceSamp(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
+ VarianceSamp(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
+ case Skewness(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
+ Skewness(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
+ case Kurtosis(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
+ Kurtosis(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
}
}
@@ -562,12 +566,6 @@ object HiveTypeCoercion {
case Sum(e @ IntegralType()) if e.dataType != LongType => Sum(Cast(e, LongType))
case Sum(e @ FractionalType()) if e.dataType != DoubleType => Sum(Cast(e, DoubleType))
- case s @ SumDistinct(e @ DecimalType()) => s // Decimal is already the biggest.
- case SumDistinct(e @ IntegralType()) if e.dataType != LongType =>
- SumDistinct(Cast(e, LongType))
- case SumDistinct(e @ FractionalType()) if e.dataType != DoubleType =>
- SumDistinct(Cast(e, DoubleType))
-
case s @ Average(e @ DecimalType()) => s // Decimal is already the biggest.
case Average(e @ IntegralType()) if e.dataType != LongType =>
Average(Cast(e, LongType))
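
The net effect on string inputs is unchanged: they are still coerced to double. The only difference is that the CentralMomentAgg functions now thread their buffer offsets through the copy. A hedged sketch of the coerced shapes:

    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast}
    import org.apache.spark.sql.catalyst.expressions.aggregate.{Kurtosis, Sum}
    import org.apache.spark.sql.types.{DoubleType, StringType}

    val s = AttributeReference("s", StringType)()
    // SUM over a string column is rewritten to SUM(CAST(s AS DOUBLE)):
    val coercedSum = Sum(Cast(s, DoubleType))
    // Kurtosis and the other central-moment aggregates keep their offsets:
    val coercedKurtosis =
      Kurtosis(Cast(s, DoubleType), mutableAggBufferOffset = 0, inputAggBufferOffset = 0)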
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index eae17c86dd..6485bdfb30 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -141,6 +141,10 @@ case class UnresolvedFunction(
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override lazy val resolved = false
+ override def prettyString: String = {
+ s"${name}(${children.map(_.prettyString).mkString(",")})"
+ }
+
override def toString: String = s"'$name(${children.mkString(",")})"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index d8df66430a..af594c25c5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -23,6 +23,7 @@ import scala.language.implicitConversions
import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedExtractValue, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
import org.apache.spark.sql.types._
@@ -144,17 +145,18 @@ package object dsl {
}
}
- def sum(e: Expression): Expression = Sum(e)
- def sumDistinct(e: Expression): Expression = SumDistinct(e)
- def count(e: Expression): Expression = Count(e)
- def countDistinct(e: Expression*): Expression = CountDistinct(e)
+ def sum(e: Expression): Expression = Sum(e).toAggregateExpression()
+ def sumDistinct(e: Expression): Expression = Sum(e).toAggregateExpression(isDistinct = true)
+ def count(e: Expression): Expression = Count(e).toAggregateExpression()
+ def countDistinct(e: Expression*): Expression =
+ Count(e).toAggregateExpression(isDistinct = true)
def approxCountDistinct(e: Expression, rsd: Double = 0.05): Expression =
- ApproxCountDistinct(e, rsd)
- def avg(e: Expression): Expression = Average(e)
- def first(e: Expression): Expression = First(e)
- def last(e: Expression): Expression = Last(e)
- def min(e: Expression): Expression = Min(e)
- def max(e: Expression): Expression = Max(e)
+ HyperLogLogPlusPlus(e, rsd).toAggregateExpression()
+ def avg(e: Expression): Expression = Average(e).toAggregateExpression()
+ def first(e: Expression): Expression = new First(e).toAggregateExpression()
+ def last(e: Expression): Expression = new Last(e).toAggregateExpression()
+ def min(e: Expression): Expression = Min(e).toAggregateExpression()
+ def max(e: Expression): Expression = Max(e).toAggregateExpression()
def upper(e: Expression): Expression = Upper(e)
def lower(e: Expression): Expression = Lower(e)
def sqrt(e: Expression): Expression = Sqrt(e)
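
Caller-side usage is unchanged; the helpers now just return fully wrapped AggregateExpressions. A small sketch with the test DSL:

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    val t = LocalRelation('key.int, 'value.int)
    val plan = t.groupBy('key)(
      count('value).as("cnt"),
      sumDistinct('value).as("sd"),
      approxCountDistinct('value).as("approx"))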
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index c8c20ada5f..7f9e503470 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -17,8 +17,10 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
case class Average(child: Expression) extends DeclarativeAggregate {
@@ -32,36 +34,33 @@ case class Average(child: Expression) extends DeclarativeAggregate {
// Return data type.
override def dataType: DataType = resultType
- // Expected input data type.
- // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the
- // new version at planning time (after analysis phase). For now, NullType is added at here
- // to make it resolved when we have cases like `select avg(null)`.
- // We can use our analyzer to cast NullType to the default data type of the NumericType once
- // we remove the old aggregate functions. Then, we will not need NullType at here.
- override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))
+ override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType))
- private val resultType = child.dataType match {
+ override def checkInputDataTypes(): TypeCheckResult =
+ TypeUtils.checkForNumericExpr(child.dataType, "function average")
+
+ private lazy val resultType = child.dataType match {
case DecimalType.Fixed(p, s) =>
DecimalType.bounded(p + 4, s + 4)
case _ => DoubleType
}
- private val sumDataType = child.dataType match {
+ private lazy val sumDataType = child.dataType match {
case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s)
case _ => DoubleType
}
- private val sum = AttributeReference("sum", sumDataType)()
- private val count = AttributeReference("count", LongType)()
+ private lazy val sum = AttributeReference("sum", sumDataType)()
+ private lazy val count = AttributeReference("count", LongType)()
- override val aggBufferAttributes = sum :: count :: Nil
+ override lazy val aggBufferAttributes = sum :: count :: Nil
- override val initialValues = Seq(
+ override lazy val initialValues = Seq(
/* sum = */ Cast(Literal(0), sumDataType),
/* count = */ Literal(0L)
)
- override val updateExpressions = Seq(
+ override lazy val updateExpressions = Seq(
/* sum = */
Add(
sum,
@@ -69,13 +68,13 @@ case class Average(child: Expression) extends DeclarativeAggregate {
/* count = */ If(IsNull(child), count, count + 1L)
)
- override val mergeExpressions = Seq(
+ override lazy val mergeExpressions = Seq(
/* sum = */ sum.left + sum.right,
/* count = */ count.left + count.right
)
// If all inputs are null, count will be 0 and we will get null after the division.
- override val evaluateExpression = child.dataType match {
+ override lazy val evaluateExpression = child.dataType match {
case DecimalType.Fixed(p, s) =>
// increase the precision and scale to prevent precision loss
val dt = DecimalType.bounded(p + 14, s + 4)
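
The val-to-lazy-val change here (repeated in the other aggregate classes below) is what keeps construction safe before analysis: the buffer attributes capture child.dataType, which an unresolved child cannot provide yet. A minimal sketch:

    import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
    import org.apache.spark.sql.catalyst.expressions.aggregate.Average

    // With eager vals, building the "sum" buffer attribute would call
    // child.dataType right here and throw; with lazy vals the access is
    // deferred until after the child is resolved.
    val avg = Average(UnresolvedAttribute("a"))
    assert(!avg.resolved)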
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
index ef08b025ff..984ce7f24d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
@@ -18,7 +18,9 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
/**
@@ -55,13 +57,10 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w
override def dataType: DataType = DoubleType
- // Expected input data type.
- // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the
- // new version at planning time (after analysis phase). For now, NullType is added at here
- // to make it resolved when we have cases like `select avg(null)`.
- // We can use our analyzer to cast NullType to the default data type of the NumericType once
- // we remove the old aggregate functions. Then, we will not need NullType at here.
- override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))
+ override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType))
+
+ override def checkInputDataTypes(): TypeCheckResult =
+ TypeUtils.checkForNumericExpr(child.dataType, s"function $prettyName")
override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
index 832338378f..00d7436b71 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
@@ -18,7 +18,9 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
/**
@@ -35,6 +37,9 @@ case class Corr(
inputAggBufferOffset: Int = 0)
extends ImperativeAggregate {
+ def this(left: Expression, right: Expression) =
+ this(left, right, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
+
override def children: Seq[Expression] = Seq(left, right)
override def nullable: Boolean = false
@@ -43,6 +48,16 @@ case class Corr(
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (left.dataType.isInstanceOf[DoubleType] && right.dataType.isInstanceOf[DoubleType]) {
+ TypeCheckResult.TypeCheckSuccess
+ } else {
+ TypeCheckResult.TypeCheckFailure(
+ s"corr requires that both arguments are double type, " +
+ s"not (${left.dataType}, ${right.dataType}).")
+ }
+ }
+
override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
override def inputAggBufferAttributes: Seq[AttributeReference] = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
index ec0c8b483a..09a1da9200 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
@@ -32,23 +32,39 @@ case class Count(child: Expression) extends DeclarativeAggregate {
// Expected input data type.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
- private val count = AttributeReference("count", LongType)()
+ private lazy val count = AttributeReference("count", LongType)()
- override val aggBufferAttributes = count :: Nil
+ override lazy val aggBufferAttributes = count :: Nil
- override val initialValues = Seq(
+ override lazy val initialValues = Seq(
/* count = */ Literal(0L)
)
- override val updateExpressions = Seq(
+ override lazy val updateExpressions = Seq(
/* count = */ If(IsNull(child), count, count + 1L)
)
- override val mergeExpressions = Seq(
+ override lazy val mergeExpressions = Seq(
/* count = */ count.left + count.right
)
- override val evaluateExpression = Cast(count, LongType)
+ override lazy val evaluateExpression = Cast(count, LongType)
override def defaultResult: Option[Literal] = Option(Literal(0L))
}
+
+object Count {
+ def apply(children: Seq[Expression]): Count = {
+ // This is used to deal with COUNT DISTINCT. When there are multiple
+ // children (COUNT(DISTINCT col1, col2, ...)), we wrap them in a STRUCT (i.e. a Row).
+ // The semantics of COUNT(DISTINCT col1, col2, ...) are that a row is not counted
+ // if any of the arguments is null, so we use DropAnyNull here to return a null
+ // when any field of the created STRUCT is null.
+ val child = if (children.size > 1) {
+ DropAnyNull(CreateStruct(children))
+ } else {
+ children.head
+ }
+ Count(child)
+ }
+}
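
A short example of what the new companion object produces (UnresolvedAttribute used purely for illustration):

    import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
    import org.apache.spark.sql.catalyst.expressions.aggregate.Count

    val a = UnresolvedAttribute("a")
    val b = UnresolvedAttribute("b")

    // COUNT(DISTINCT a, b): the pair is packed into a struct, and DropAnyNull
    // nulls the whole struct when any field is null, so Count skips that row.
    val cd = Count(Seq(a, b)).toAggregateExpression(isDistinct = true)
    // cd.aggregateFunction is Count(DropAnyNull(CreateStruct(Seq(a, b))))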
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
index 9028143015..35f57426fe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
@@ -51,18 +51,18 @@ case class First(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate {
// Expected input data type.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
- private val first = AttributeReference("first", child.dataType)()
+ private lazy val first = AttributeReference("first", child.dataType)()
- private val valueSet = AttributeReference("valueSet", BooleanType)()
+ private lazy val valueSet = AttributeReference("valueSet", BooleanType)()
- override val aggBufferAttributes: Seq[AttributeReference] = first :: valueSet :: Nil
+ override lazy val aggBufferAttributes: Seq[AttributeReference] = first :: valueSet :: Nil
- override val initialValues: Seq[Literal] = Seq(
+ override lazy val initialValues: Seq[Literal] = Seq(
/* first = */ Literal.create(null, child.dataType),
/* valueSet = */ Literal.create(false, BooleanType)
)
- override val updateExpressions: Seq[Expression] = {
+ override lazy val updateExpressions: Seq[Expression] = {
if (ignoreNulls) {
Seq(
/* first = */ If(Or(valueSet, IsNull(child)), first, child),
@@ -76,7 +76,7 @@ case class First(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate {
}
}
- override val mergeExpressions: Seq[Expression] = {
+ override lazy val mergeExpressions: Seq[Expression] = {
// For first, we can just check if valueSet.left is set to true. If it is set
// to true, we use first.left. If not, we use first.right (even if valueSet.right is
// false, we are safe to do so because first.right will be null in this case).
@@ -86,7 +86,7 @@ case class First(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate {
)
}
- override val evaluateExpression: AttributeReference = first
+ override lazy val evaluateExpression: AttributeReference = first
override def toString: String = s"first($child)${if (ignoreNulls) " ignore nulls"}"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
index 8d341ee630..8a95c541f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
@@ -22,6 +22,7 @@ import java.util
import com.clearspring.analytics.hash.MurmurHash
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
@@ -55,6 +56,22 @@ case class HyperLogLogPlusPlus(
extends ImperativeAggregate {
import HyperLogLogPlusPlus._
+ def this(child: Expression) = {
+ this(child = child, relativeSD = 0.05, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
+ }
+
+ def this(child: Expression, relativeSD: Expression) = {
+ this(
+ child = child,
+ relativeSD = relativeSD match {
+ case Literal(d: Double, DoubleType) => d
+ case _ =>
+ throw new AnalysisException("The second argument should be a double literal.")
+ },
+ mutableAggBufferOffset = 0,
+ inputAggBufferOffset = 0)
+ }
+
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
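
The two auxiliary constructors mirror the SQL-facing forms of approx_count_distinct; a brief sketch:

    import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus

    val col = UnresolvedAttribute("a")
    // approx_count_distinct(a): default relativeSD of 0.05.
    val hll1 = new HyperLogLogPlusPlus(col)
    // approx_count_distinct(a, 0.01): the second argument must be a double
    // literal; anything else throws AnalysisException at construction time.
    val hll2 = new HyperLogLogPlusPlus(col, Literal(0.01))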
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
index 6da39e7143..bae78d9849 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
@@ -24,6 +24,8 @@ case class Kurtosis(child: Expression,
inputAggBufferOffset: Int = 0)
extends CentralMomentAgg(child) {
+ def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
+
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
index 8636bfe8d0..be7e12d7a2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
@@ -51,15 +51,15 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate {
// Expected input data type.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
- private val last = AttributeReference("last", child.dataType)()
+ private lazy val last = AttributeReference("last", child.dataType)()
- override val aggBufferAttributes: Seq[AttributeReference] = last :: Nil
+ override lazy val aggBufferAttributes: Seq[AttributeReference] = last :: Nil
- override val initialValues: Seq[Literal] = Seq(
+ override lazy val initialValues: Seq[Literal] = Seq(
/* last = */ Literal.create(null, child.dataType)
)
- override val updateExpressions: Seq[Expression] = {
+ override lazy val updateExpressions: Seq[Expression] = {
if (ignoreNulls) {
Seq(
/* last = */ If(IsNull(child), last, child)
@@ -71,7 +71,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate {
}
}
- override val mergeExpressions: Seq[Expression] = {
+ override lazy val mergeExpressions: Seq[Expression] = {
if (ignoreNulls) {
Seq(
/* last = */ If(IsNull(last.right), last.left, last.right)
@@ -83,7 +83,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate {
}
}
- override val evaluateExpression: AttributeReference = last
+ override lazy val evaluateExpression: AttributeReference = last
override def toString: String = s"last($child)${if (ignoreNulls) " ignore nulls"}"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala
index b9d75ad452..61cae44cd0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala
@@ -17,7 +17,9 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
case class Max(child: Expression) extends DeclarativeAggregate {
@@ -32,24 +34,27 @@ case class Max(child: Expression) extends DeclarativeAggregate {
// Expected input data type.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
- private val max = AttributeReference("max", child.dataType)()
+ override def checkInputDataTypes(): TypeCheckResult =
+ TypeUtils.checkForOrderingExpr(child.dataType, "function max")
- override val aggBufferAttributes: Seq[AttributeReference] = max :: Nil
+ private lazy val max = AttributeReference("max", child.dataType)()
- override val initialValues: Seq[Literal] = Seq(
+ override lazy val aggBufferAttributes: Seq[AttributeReference] = max :: Nil
+
+ override lazy val initialValues: Seq[Literal] = Seq(
/* max = */ Literal.create(null, child.dataType)
)
- override val updateExpressions: Seq[Expression] = Seq(
+ override lazy val updateExpressions: Seq[Expression] = Seq(
/* max = */ If(IsNull(child), max, If(IsNull(max), child, Greatest(Seq(max, child))))
)
- override val mergeExpressions: Seq[Expression] = {
+ override lazy val mergeExpressions: Seq[Expression] = {
val greatest = Greatest(Seq(max.left, max.right))
Seq(
/* max = */ If(IsNull(max.right), max.left, If(IsNull(max.left), max.right, greatest))
)
}
- override val evaluateExpression: AttributeReference = max
+ override lazy val evaluateExpression: AttributeReference = max
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala
index 5ed9cd348d..242456d9e2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala
@@ -17,7 +17,9 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
@@ -33,24 +35,27 @@ case class Min(child: Expression) extends DeclarativeAggregate {
// Expected input data type.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
- private val min = AttributeReference("min", child.dataType)()
+ override def checkInputDataTypes(): TypeCheckResult =
+ TypeUtils.checkForOrderingExpr(child.dataType, "function min")
- override val aggBufferAttributes: Seq[AttributeReference] = min :: Nil
+ private lazy val min = AttributeReference("min", child.dataType)()
- override val initialValues: Seq[Expression] = Seq(
+ override lazy val aggBufferAttributes: Seq[AttributeReference] = min :: Nil
+
+ override lazy val initialValues: Seq[Expression] = Seq(
/* min = */ Literal.create(null, child.dataType)
)
- override val updateExpressions: Seq[Expression] = Seq(
+ override lazy val updateExpressions: Seq[Expression] = Seq(
/* min = */ If(IsNull(child), min, If(IsNull(min), child, Least(Seq(min, child))))
)
- override val mergeExpressions: Seq[Expression] = {
+ override lazy val mergeExpressions: Seq[Expression] = {
val least = Least(Seq(min.left, min.right))
Seq(
/* min = */ If(IsNull(min.right), min.left, If(IsNull(min.left), min.right, least))
)
}
- override val evaluateExpression: AttributeReference = min
+ override lazy val evaluateExpression: AttributeReference = min
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
index 0def7ddfd9..c593074fa2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
@@ -24,6 +24,8 @@ case class Skewness(child: Expression,
inputAggBufferOffset: Int = 0)
extends CentralMomentAgg(child) {
+ def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
+
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
index 3f47ffe13c..5b9eb7ae02 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
@@ -17,8 +17,10 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
@@ -48,29 +50,26 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate {
override def dataType: DataType = resultType
- // Expected input data type.
- // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the
- // new version at planning time (after analysis phase). For now, NullType is added at here
- // to make it resolved when we have cases like `select stddev(null)`.
- // We can use our analyzer to cast NullType to the default data type of the NumericType once
- // we remove the old aggregate functions. Then, we will not need NullType at here.
- override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))
+ override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType))
- private val resultType = DoubleType
+ override def checkInputDataTypes(): TypeCheckResult =
+ TypeUtils.checkForNumericExpr(child.dataType, "function stddev")
- private val count = AttributeReference("count", resultType)()
- private val avg = AttributeReference("avg", resultType)()
- private val mk = AttributeReference("mk", resultType)()
+ private lazy val resultType = DoubleType
- override val aggBufferAttributes = count :: avg :: mk :: Nil
+ private lazy val count = AttributeReference("count", resultType)()
+ private lazy val avg = AttributeReference("avg", resultType)()
+ private lazy val mk = AttributeReference("mk", resultType)()
- override val initialValues: Seq[Expression] = Seq(
+ override lazy val aggBufferAttributes = count :: avg :: mk :: Nil
+
+ override lazy val initialValues: Seq[Expression] = Seq(
/* count = */ Cast(Literal(0), resultType),
/* avg = */ Cast(Literal(0), resultType),
/* mk = */ Cast(Literal(0), resultType)
)
- override val updateExpressions: Seq[Expression] = {
+ override lazy val updateExpressions: Seq[Expression] = {
val value = Cast(child, resultType)
val newCount = count + Cast(Literal(1), resultType)
@@ -89,7 +88,7 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate {
)
}
- override val mergeExpressions: Seq[Expression] = {
+ override lazy val mergeExpressions: Seq[Expression] = {
// count merge
val newCount = count.left + count.right
@@ -114,7 +113,7 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate {
)
}
- override val evaluateExpression: Expression = {
+ override lazy val evaluateExpression: Expression = {
// when count == 0, return null
// when count == 1, return 0
// when count >1
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index 7f8adbc56a..c005ec9657 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -17,7 +17,9 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
case class Sum(child: Expression) extends DeclarativeAggregate {
@@ -29,16 +31,13 @@ case class Sum(child: Expression) extends DeclarativeAggregate {
// Return data type.
override def dataType: DataType = resultType
- // Expected input data type.
- // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the
- // new version at planning time (after analysis phase). For now, NullType is added at here
- // to make it resolved when we have cases like `select sum(null)`.
- // We can use our analyzer to cast NullType to the default data type of the NumericType once
- // we remove the old aggregate functions. Then, we will not need NullType at here.
override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType))
- private val resultType = child.dataType match {
+ override def checkInputDataTypes(): TypeCheckResult =
+ TypeUtils.checkForNumericExpr(child.dataType, "function sum")
+
+ private lazy val resultType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
DecimalType.bounded(precision + 10, scale)
// TODO: Remove this line once we remove the NullType from inputTypes.
@@ -46,24 +45,24 @@ case class Sum(child: Expression) extends DeclarativeAggregate {
case _ => child.dataType
}
- private val sumDataType = resultType
+ private lazy val sumDataType = resultType
- private val sum = AttributeReference("sum", sumDataType)()
+ private lazy val sum = AttributeReference("sum", sumDataType)()
- private val zero = Cast(Literal(0), sumDataType)
+ private lazy val zero = Cast(Literal(0), sumDataType)
- override val aggBufferAttributes = sum :: Nil
+ override lazy val aggBufferAttributes = sum :: Nil
- override val initialValues: Seq[Expression] = Seq(
+ override lazy val initialValues: Seq[Expression] = Seq(
/* sum = */ Literal.create(null, sumDataType)
)
- override val updateExpressions: Seq[Expression] = Seq(
+ override lazy val updateExpressions: Seq[Expression] = Seq(
/* sum = */
Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum))
)
- override val mergeExpressions: Seq[Expression] = {
+ override lazy val mergeExpressions: Seq[Expression] = {
val add = Add(Coalesce(Seq(sum.left, zero)), Cast(sum.right, sumDataType))
Seq(
/* sum = */
@@ -71,5 +70,5 @@ case class Sum(child: Expression) extends DeclarativeAggregate {
)
}
- override val evaluateExpression: Expression = Cast(sum, resultType)
+ override lazy val evaluateExpression: Expression = Cast(sum, resultType)
}
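To make the decimal branch of `resultType` concrete, here is a small worked example (the input type is assumed; `DecimalType.bounded` is the same helper the diff calls):

    import org.apache.spark.sql.types._

    val input: DataType = DecimalType(10, 2)
    val result = input match {
      case DecimalType.Fixed(precision, scale) => DecimalType.bounded(precision + 10, scale)
      case other => other
    }
    // result == DecimalType(20, 2): ten extra integral digits of headroom
    // before a long-running sum overflows, matching the "like Hive" precision
    // rule used elsewhere in this patch.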
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala
index ec63534e52..ede2da2805 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala
@@ -24,6 +24,8 @@ case class VarianceSamp(child: Expression,
inputAggBufferOffset: Int = 0)
extends CentralMomentAgg(child) {
+ def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
+
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
@@ -42,11 +44,14 @@ case class VarianceSamp(child: Expression,
}
}
-case class VariancePop(child: Expression,
+case class VariancePop(
+ child: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends CentralMomentAgg(child) {
+ def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
+
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
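The auxiliary constructors added to both variance aggregates simply default the two buffer offsets to zero; a sketch of the equivalent call sites (the `child` expression is an assumed stand-in):

    // These two lines construct the same aggregate; the new single-argument
    // constructor zeroes both buffer offsets.
    val viaAuxiliary = new VarianceSamp(child)
    val viaPrimary   = VarianceSamp(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)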
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 5c5b3d1ccd..3b441de34a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -17,23 +17,24 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
-/** The mode of an [[AggregateFunction2]]. */
+/** The mode of an [[AggregateFunction]]. */
private[sql] sealed trait AggregateMode
/**
- * An [[AggregateFunction2]] with [[Partial]] mode is used for partial aggregation.
+ * An [[AggregateFunction]] with [[Partial]] mode is used for partial aggregation.
* This function updates the given aggregation buffer with the original input of this
* function. When it has processed all input rows, the aggregation buffer is returned.
*/
private[sql] case object Partial extends AggregateMode
/**
- * An [[AggregateFunction2]] with [[PartialMerge]] mode is used to merge aggregation buffers
+ * An [[AggregateFunction]] with [[PartialMerge]] mode is used to merge aggregation buffers
* containing intermediate results for this function.
* This function updates the given aggregation buffer by merging multiple aggregation buffers.
* When it has processed all input rows, the aggregation buffer is returned.
@@ -41,7 +42,7 @@ private[sql] case object Partial extends AggregateMode
private[sql] case object PartialMerge extends AggregateMode
/**
- * An [[AggregateFunction2]] with [[Final]] mode is used to merge aggregation buffers
+ * An [[AggregateFunction]] with [[Final]] mode is used to merge aggregation buffers
* containing intermediate results for this function and then generate final result.
* This function updates the given aggregation buffer by merging multiple aggregation buffers.
* When it has processed all input rows, the final result of this function is returned.
@@ -49,7 +50,7 @@ private[sql] case object PartialMerge extends AggregateMode
private[sql] case object Final extends AggregateMode
/**
- * An [[AggregateFunction2]] with [[Complete]] mode is used to evaluate this function directly
+ * An [[AggregateFunction]] with [[Complete]] mode is used to evaluate this function directly
* from original input rows without any partial aggregation.
* This function updates the given aggregation buffer with the original input of this
* function. When it has processed all input rows, the final result of this function is returned.
@@ -67,13 +68,15 @@ private[sql] case object NoOp extends Expression with Unevaluable {
}
/**
- * A container for an [[AggregateFunction2]] with its [[AggregateMode]] and a field
+ * A container for an [[AggregateFunction]] with its [[AggregateMode]] and a field
* (`isDistinct`) indicating if DISTINCT keyword is specified for this function.
*/
-private[sql] case class AggregateExpression2(
- aggregateFunction: AggregateFunction2,
+private[sql] case class AggregateExpression(
+ aggregateFunction: AggregateFunction,
mode: AggregateMode,
- isDistinct: Boolean) extends AggregateExpression {
+ isDistinct: Boolean)
+ extends Expression
+ with Unevaluable {
override def children: Seq[Expression] = aggregateFunction :: Nil
override def dataType: DataType = aggregateFunction.dataType
@@ -89,6 +92,8 @@ private[sql] case class AggregateExpression2(
AttributeSet(childReferences)
}
+ override def prettyString: String = aggregateFunction.prettyString
+
override def toString: String = s"(${aggregateFunction},mode=$mode,isDistinct=$isDistinct)"
}
@@ -106,10 +111,10 @@ private[sql] case class AggregateExpression2(
* combined aggregation buffer which concatenates the aggregation buffers of the individual
* aggregate functions.
*
- * Code which accepts [[AggregateFunction2]] instances should be prepared to handle both types of
+ * Code which accepts [[AggregateFunction]] instances should be prepared to handle both types of
* aggregate functions.
*/
-sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInputTypes {
+sealed abstract class AggregateFunction extends Expression with ImplicitCastInputTypes {
/** An aggregate function is not foldable. */
final override def foldable: Boolean = false
@@ -141,6 +146,27 @@ sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInp
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
+
+ /**
+ * Wraps this [[AggregateFunction]] in an [[AggregateExpression]], the container that
+ * carries an [[AggregateFunction]], its aggregation mode, and the flag indicating
+ * whether the aggregation is distinct. An [[AggregateFunction]] should not be used
+ * without being wrapped in an [[AggregateExpression]].
+ */
+ def toAggregateExpression(): AggregateExpression = toAggregateExpression(isDistinct = false)
+
+ /**
+ * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] and sets the
+ * `isDistinct` field of the [[AggregateExpression]] to the given value.
+ * An [[AggregateFunction]] should not be used without being wrapped in
+ * an [[AggregateExpression]].
+ */
+ def toAggregateExpression(isDistinct: Boolean): AggregateExpression = {
+ AggregateExpression(aggregateFunction = this, mode = Complete, isDistinct = isDistinct)
+ }
}
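A minimal usage sketch for the two wrappers just introduced (`Sum` and `column` are assumed stand-ins):

    // An AggregateFunction must be wrapped before it can appear in a plan.
    val fn: AggregateFunction = Sum(column)
    val plain    = fn.toAggregateExpression()                  // mode = Complete, isDistinct = false
    val distinct = fn.toAggregateExpression(isDistinct = true) // mode = Complete, isDistinct = true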
/**
@@ -161,7 +187,7 @@ sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInp
* `inputAggBufferOffset`, but not on the correctness of the attribute ids in `aggBufferAttributes`
* and `inputAggBufferAttributes`.
*/
-abstract class ImperativeAggregate extends AggregateFunction2 {
+abstract class ImperativeAggregate extends AggregateFunction {
/**
* The offset of this function's first buffer value in the underlying shared mutable aggregation
@@ -258,9 +284,14 @@ abstract class ImperativeAggregate extends AggregateFunction2 {
* `bufferAttributes`, defining attributes for the fields of the mutable aggregation buffer. You
* can then use these attributes when defining `updateExpressions`, `mergeExpressions`, and
* `evaluateExpressions`.
+ *
+ * Please note that the children of an aggregate function can be unresolved (this
+ * happens when the function is created in the DataFrame API). So, if any fields of
+ * the implementing class need to access fields of its children, make those
+ * fields `lazy val`s.
*/
abstract class DeclarativeAggregate
- extends AggregateFunction2
+ extends AggregateFunction
with Serializable
with Unevaluable {
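A toy illustration of the `lazy val` guidance above, in plain Scala with no Catalyst types (all names invented):

    class Child { def dataType: String = throw new IllegalStateException("unresolved") }
    // With a plain `val`, the field is computed at construction time and fails
    // while the child is still unresolved (as in the DataFrame API).
    class EagerAgg(child: Child) { val resultType: String = child.dataType }
    // With `lazy val`, the access is deferred until analysis has resolved the child.
    class LazyAgg(child: Child) { lazy val resultType: String = child.dataType }

    new LazyAgg(new Child())     // succeeds: dataType is never touched
    // new EagerAgg(new Child()) // would throw at construction time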
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
deleted file mode 100644
index 3dcf7915d7..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ /dev/null
@@ -1,1073 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import com.clearspring.analytics.stream.cardinality.HyperLogLog
-
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
-import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData, TypeUtils}
-import org.apache.spark.sql.types._
-import org.apache.spark.util.collection.OpenHashSet
-
-
-trait AggregateExpression extends Expression with Unevaluable
-
-trait AggregateExpression1 extends AggregateExpression {
-
- /**
- * Aggregate expressions should not be foldable.
- */
- override def foldable: Boolean = false
-
- /**
- * Creates a new instance that can be used to compute this aggregate expression for a group
- * of input rows/
- */
- def newInstance(): AggregateFunction1
-}
-
-/**
- * Represents an aggregation that has been rewritten to be performed in two steps.
- *
- * @param finalEvaluation an aggregate expression that evaluates to same final result as the
- * original aggregation.
- * @param partialEvaluations A sequence of [[NamedExpression]]s that can be computed on partial
- * data sets and are required to compute the `finalEvaluation`.
- */
-case class SplitEvaluation(
- finalEvaluation: Expression,
- partialEvaluations: Seq[NamedExpression])
-
-/**
- * An [[AggregateExpression1]] that can be partially computed without seeing all relevant tuples.
- * These partial evaluations can then be combined to compute the actual answer.
- */
-trait PartialAggregate1 extends AggregateExpression1 {
-
- /**
- * Returns a [[SplitEvaluation]] that computes this aggregation using partial aggregation.
- */
- def asPartial: SplitEvaluation
-}
-
-/**
- * A specific implementation of an aggregate function. Used to wrap a generic
- * [[AggregateExpression1]] with an algorithm that will be used to compute one specific result.
- */
-abstract class AggregateFunction1 extends LeafExpression with Serializable {
-
- /** Base should return the generic aggregate expression that this function is computing */
- val base: AggregateExpression1
-
- override def nullable: Boolean = base.nullable
- override def dataType: DataType = base.dataType
-
- def update(input: InternalRow): Unit
-
- override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- throw new UnsupportedOperationException(
- "AggregateFunction1 should not be used for generated aggregates")
- }
-}
-
-case class Min(child: Expression) extends UnaryExpression with PartialAggregate1 {
-
- override def nullable: Boolean = true
- override def dataType: DataType = child.dataType
-
- override def asPartial: SplitEvaluation = {
- val partialMin = Alias(Min(child), "PartialMin")()
- SplitEvaluation(Min(partialMin.toAttribute), partialMin :: Nil)
- }
-
- override def newInstance(): MinFunction = new MinFunction(child, this)
-
- override def checkInputDataTypes(): TypeCheckResult =
- TypeUtils.checkForOrderingExpr(child.dataType, "function min")
-}
-
-case class MinFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
- def this() = this(null, null) // Required for serialization.
-
- val currentMin: MutableLiteral = MutableLiteral(null, expr.dataType)
- val cmp = GreaterThan(currentMin, expr)
-
- override def update(input: InternalRow): Unit = {
- if (currentMin.value == null) {
- currentMin.value = expr.eval(input)
- } else if (cmp.eval(input) == true) {
- currentMin.value = expr.eval(input)
- }
- }
-
- override def eval(input: InternalRow): Any = currentMin.value
-}
-
-case class Max(child: Expression) extends UnaryExpression with PartialAggregate1 {
-
- override def nullable: Boolean = true
- override def dataType: DataType = child.dataType
-
- override def asPartial: SplitEvaluation = {
- val partialMax = Alias(Max(child), "PartialMax")()
- SplitEvaluation(Max(partialMax.toAttribute), partialMax :: Nil)
- }
-
- override def newInstance(): MaxFunction = new MaxFunction(child, this)
-
- override def checkInputDataTypes(): TypeCheckResult =
- TypeUtils.checkForOrderingExpr(child.dataType, "function max")
-}
-
-case class MaxFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
- def this() = this(null, null) // Required for serialization.
-
- val currentMax: MutableLiteral = MutableLiteral(null, expr.dataType)
- val cmp = LessThan(currentMax, expr)
-
- override def update(input: InternalRow): Unit = {
- if (currentMax.value == null) {
- currentMax.value = expr.eval(input)
- } else if (cmp.eval(input) == true) {
- currentMax.value = expr.eval(input)
- }
- }
-
- override def eval(input: InternalRow): Any = currentMax.value
-}
-
-case class Count(child: Expression) extends UnaryExpression with PartialAggregate1 {
-
- override def nullable: Boolean = false
- override def dataType: LongType.type = LongType
-
- override def asPartial: SplitEvaluation = {
- val partialCount = Alias(Count(child), "PartialCount")()
- SplitEvaluation(Coalesce(Seq(Sum(partialCount.toAttribute), Literal(0L))), partialCount :: Nil)
- }
-
- override def newInstance(): CountFunction = new CountFunction(child, this)
-}
-
-case class CountFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
- def this() = this(null, null) // Required for serialization.
-
- var count: Long = _
-
- override def update(input: InternalRow): Unit = {
- val evaluatedExpr = expr.eval(input)
- if (evaluatedExpr != null) {
- count += 1L
- }
- }
-
- override def eval(input: InternalRow): Any = count
-}
-
-case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate1 {
- def this() = this(null)
-
- override def children: Seq[Expression] = expressions
-
- override def nullable: Boolean = false
- override def dataType: DataType = LongType
- override def toString: String = s"COUNT(DISTINCT ${expressions.mkString(",")})"
- override def newInstance(): CountDistinctFunction = new CountDistinctFunction(expressions, this)
-
- override def asPartial: SplitEvaluation = {
- val partialSet = Alias(CollectHashSet(expressions), "partialSets")()
- SplitEvaluation(
- CombineSetsAndCount(partialSet.toAttribute),
- partialSet :: Nil)
- }
-}
-
-case class CountDistinctFunction(
- @transient expr: Seq[Expression],
- @transient base: AggregateExpression1)
- extends AggregateFunction1 {
-
- def this() = this(null, null) // Required for serialization.
-
- val seen = new OpenHashSet[Any]()
-
- @transient
- val distinctValue = new InterpretedProjection(expr)
-
- override def update(input: InternalRow): Unit = {
- val evaluatedExpr = distinctValue(input)
- if (!evaluatedExpr.anyNull) {
- seen.add(evaluatedExpr)
- }
- }
-
- override def eval(input: InternalRow): Any = seen.size.toLong
-}
-
-case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression1 {
- def this() = this(null)
-
- override def children: Seq[Expression] = expressions
- override def nullable: Boolean = false
- override def dataType: OpenHashSetUDT = new OpenHashSetUDT(expressions.head.dataType)
- override def toString: String = s"AddToHashSet(${expressions.mkString(",")})"
- override def newInstance(): CollectHashSetFunction =
- new CollectHashSetFunction(expressions, this)
-}
-
-case class CollectHashSetFunction(
- @transient expr: Seq[Expression],
- @transient base: AggregateExpression1)
- extends AggregateFunction1 {
-
- def this() = this(null, null) // Required for serialization.
-
- val seen = new OpenHashSet[Any]()
-
- @transient
- val distinctValue = new InterpretedProjection(expr)
-
- override def update(input: InternalRow): Unit = {
- val evaluatedExpr = distinctValue(input)
- if (!evaluatedExpr.anyNull) {
- seen.add(evaluatedExpr)
- }
- }
-
- override def eval(input: InternalRow): Any = {
- seen
- }
-}
-
-case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression1 {
- def this() = this(null)
-
- override def children: Seq[Expression] = inputSet :: Nil
- override def nullable: Boolean = false
- override def dataType: DataType = LongType
- override def toString: String = s"CombineAndCount($inputSet)"
- override def newInstance(): CombineSetsAndCountFunction = {
- new CombineSetsAndCountFunction(inputSet, this)
- }
-}
-
-case class CombineSetsAndCountFunction(
- @transient inputSet: Expression,
- @transient base: AggregateExpression1)
- extends AggregateFunction1 {
-
- def this() = this(null, null) // Required for serialization.
-
- val seen = new OpenHashSet[Any]()
-
- override def update(input: InternalRow): Unit = {
- val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]]
- val inputIterator = inputSetEval.iterator
- while (inputIterator.hasNext) {
- seen.add(inputIterator.next)
- }
- }
-
- override def eval(input: InternalRow): Any = seen.size.toLong
-}
-
-/** The data type of ApproxCountDistinctPartition since its output is a HyperLogLog object. */
-private[sql] case object HyperLogLogUDT extends UserDefinedType[HyperLogLog] {
-
- override def sqlType: DataType = BinaryType
-
- /** Since we are using HyperLogLog internally, usually it will not be called. */
- override def serialize(obj: Any): Array[Byte] =
- obj.asInstanceOf[HyperLogLog].getBytes
-
-
- /** Since we are using HyperLogLog internally, usually it will not be called. */
- override def deserialize(datum: Any): HyperLogLog =
- HyperLogLog.Builder.build(datum.asInstanceOf[Array[Byte]])
-
- override def userClass: Class[HyperLogLog] = classOf[HyperLogLog]
-}
-
-case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
- extends UnaryExpression with AggregateExpression1 {
-
- override def nullable: Boolean = false
- override def dataType: DataType = HyperLogLogUDT
- override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)"
- override def newInstance(): ApproxCountDistinctPartitionFunction = {
- new ApproxCountDistinctPartitionFunction(child, this, relativeSD)
- }
-}
-
-case class ApproxCountDistinctPartitionFunction(
- expr: Expression,
- base: AggregateExpression1,
- relativeSD: Double)
- extends AggregateFunction1 {
- def this() = this(null, null, 0) // Required for serialization.
-
- private val hyperLogLog = new HyperLogLog(relativeSD)
-
- override def update(input: InternalRow): Unit = {
- val evaluatedExpr = expr.eval(input)
- if (evaluatedExpr != null) {
- hyperLogLog.offer(evaluatedExpr)
- }
- }
-
- override def eval(input: InternalRow): Any = hyperLogLog
-}
-
-case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
- extends UnaryExpression with AggregateExpression1 {
-
- override def nullable: Boolean = false
- override def dataType: LongType.type = LongType
- override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)"
- override def newInstance(): ApproxCountDistinctMergeFunction = {
- new ApproxCountDistinctMergeFunction(child, this, relativeSD)
- }
-}
-
-case class ApproxCountDistinctMergeFunction(
- expr: Expression,
- base: AggregateExpression1,
- relativeSD: Double)
- extends AggregateFunction1 {
- def this() = this(null, null, 0) // Required for serialization.
-
- private val hyperLogLog = new HyperLogLog(relativeSD)
-
- override def update(input: InternalRow): Unit = {
- val evaluatedExpr = expr.eval(input)
- hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog])
- }
-
- override def eval(input: InternalRow): Any = hyperLogLog.cardinality()
-}
-
-case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
- extends UnaryExpression with PartialAggregate1 {
-
- override def nullable: Boolean = false
- override def dataType: LongType.type = LongType
- override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)"
-
- override def asPartial: SplitEvaluation = {
- val partialCount =
- Alias(ApproxCountDistinctPartition(child, relativeSD), "PartialApproxCountDistinct")()
-
- SplitEvaluation(
- ApproxCountDistinctMerge(partialCount.toAttribute, relativeSD),
- partialCount :: Nil)
- }
-
- override def newInstance(): CountDistinctFunction = new CountDistinctFunction(child :: Nil, this)
-}
-
-case class Average(child: Expression) extends UnaryExpression with PartialAggregate1 {
-
- override def prettyName: String = "avg"
-
- override def nullable: Boolean = true
-
- override def dataType: DataType = child.dataType match {
- case DecimalType.Fixed(precision, scale) =>
- // Add 4 digits after decimal point, like Hive
- DecimalType.bounded(precision + 4, scale + 4)
- case _ =>
- DoubleType
- }
-
- override def asPartial: SplitEvaluation = {
- child.dataType match {
- case DecimalType.Fixed(precision, scale) =>
- val partialSum = Alias(Sum(child), "PartialSum")()
- val partialCount = Alias(Count(child), "PartialCount")()
-
- // partialSum already increase the precision by 10
- val castedSum = Cast(Sum(partialSum.toAttribute), partialSum.dataType)
- val castedCount = Cast(Sum(partialCount.toAttribute), partialSum.dataType)
- SplitEvaluation(
- Cast(Divide(castedSum, castedCount), dataType),
- partialCount :: partialSum :: Nil)
-
- case _ =>
- val partialSum = Alias(Sum(child), "PartialSum")()
- val partialCount = Alias(Count(child), "PartialCount")()
-
- val castedSum = Cast(Sum(partialSum.toAttribute), dataType)
- val castedCount = Cast(Sum(partialCount.toAttribute), dataType)
- SplitEvaluation(
- Divide(castedSum, castedCount),
- partialCount :: partialSum :: Nil)
- }
- }
-
- override def newInstance(): AverageFunction = new AverageFunction(child, this)
-
- override def checkInputDataTypes(): TypeCheckResult =
- TypeUtils.checkForNumericExpr(child.dataType, "function average")
-}
-
-case class AverageFunction(expr: Expression, base: AggregateExpression1)
- extends AggregateFunction1 {
-
- def this() = this(null, null) // Required for serialization.
-
- private val calcType =
- expr.dataType match {
- case DecimalType.Fixed(precision, scale) =>
- DecimalType.bounded(precision + 10, scale)
- case _ =>
- expr.dataType
- }
-
- private val zero = Cast(Literal(0), calcType)
-
- private var count: Long = _
- private val sum = MutableLiteral(zero.eval(null), calcType)
-
- private def addFunction(value: Any) = Add(sum,
- Cast(Literal.create(value, expr.dataType), calcType))
-
- override def eval(input: InternalRow): Any = {
- if (count == 0L) {
- null
- } else {
- expr.dataType match {
- case DecimalType.Fixed(precision, scale) =>
- val dt = DecimalType.bounded(precision + 14, scale + 4)
- Cast(Divide(Cast(sum, dt), Cast(Literal(count), dt)), dataType).eval(null)
- case _ =>
- Divide(
- Cast(sum, dataType),
- Cast(Literal(count), dataType)).eval(null)
- }
- }
- }
-
- override def update(input: InternalRow): Unit = {
- val evaluatedExpr = expr.eval(input)
- if (evaluatedExpr != null) {
- count += 1
- sum.update(addFunction(evaluatedExpr), input)
- }
- }
-}
-
-case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1 {
-
- override def nullable: Boolean = true
-
- override def dataType: DataType = child.dataType match {
- case DecimalType.Fixed(precision, scale) =>
- // Add 10 digits left of decimal point, like Hive
- DecimalType.bounded(precision + 10, scale)
- case _ =>
- child.dataType
- }
-
- override def asPartial: SplitEvaluation = {
- child.dataType match {
- case DecimalType.Fixed(_, _) =>
- val partialSum = Alias(Sum(child), "PartialSum")()
- SplitEvaluation(
- Cast(Sum(partialSum.toAttribute), dataType),
- partialSum :: Nil)
-
- case _ =>
- val partialSum = Alias(Sum(child), "PartialSum")()
- SplitEvaluation(
- Sum(partialSum.toAttribute),
- partialSum :: Nil)
- }
- }
-
- override def newInstance(): SumFunction = new SumFunction(child, this)
-
- override def checkInputDataTypes(): TypeCheckResult =
- TypeUtils.checkForNumericExpr(child.dataType, "function sum")
-}
-
-case class SumFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
- def this() = this(null, null) // Required for serialization.
-
- private val calcType =
- expr.dataType match {
- case DecimalType.Fixed(precision, scale) =>
- DecimalType.bounded(precision + 10, scale)
- case _ =>
- expr.dataType
- }
-
- private val zero = Cast(Literal(0), calcType)
-
- private val sum = MutableLiteral(null, calcType)
-
- private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum))
-
- override def update(input: InternalRow): Unit = {
- sum.update(addFunction, input)
- }
-
- override def eval(input: InternalRow): Any = {
- expr.dataType match {
- case DecimalType.Fixed(_, _) =>
- Cast(sum, dataType).eval(null)
- case _ => sum.eval(null)
- }
- }
-}
-
-case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate1 {
-
- def this() = this(null)
- override def nullable: Boolean = true
- override def dataType: DataType = child.dataType match {
- case DecimalType.Fixed(precision, scale) =>
- // Add 10 digits left of decimal point, like Hive
- DecimalType.bounded(precision + 10, scale)
- case _ =>
- child.dataType
- }
- override def toString: String = s"sum(distinct $child)"
- override def newInstance(): SumDistinctFunction = new SumDistinctFunction(child, this)
-
- override def asPartial: SplitEvaluation = {
- val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")()
- SplitEvaluation(
- CombineSetsAndSum(partialSet.toAttribute, this),
- partialSet :: Nil)
- }
-
- override def checkInputDataTypes(): TypeCheckResult =
- TypeUtils.checkForNumericExpr(child.dataType, "function sumDistinct")
-}
-
-case class SumDistinctFunction(expr: Expression, base: AggregateExpression1)
- extends AggregateFunction1 {
-
- def this() = this(null, null) // Required for serialization.
-
- private val seen = new scala.collection.mutable.HashSet[Any]()
-
- override def update(input: InternalRow): Unit = {
- val evaluatedExpr = expr.eval(input)
- if (evaluatedExpr != null) {
- seen += evaluatedExpr
- }
- }
-
- override def eval(input: InternalRow): Any = {
- if (seen.size == 0) {
- null
- } else {
- Cast(Literal(
- seen.reduceLeft(
- dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)),
- dataType).eval(null)
- }
- }
-}
-
-case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression1 {
- def this() = this(null, null)
-
- override def children: Seq[Expression] = inputSet :: Nil
- override def nullable: Boolean = true
- override def dataType: DataType = base.dataType
- override def toString: String = s"CombineAndSum($inputSet)"
- override def newInstance(): CombineSetsAndSumFunction = {
- new CombineSetsAndSumFunction(inputSet, this)
- }
-}
-
-case class CombineSetsAndSumFunction(
- @transient inputSet: Expression,
- @transient base: AggregateExpression1)
- extends AggregateFunction1 {
-
- def this() = this(null, null) // Required for serialization.
-
- val seen = new OpenHashSet[Any]()
-
- override def update(input: InternalRow): Unit = {
- val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]]
- val inputIterator = inputSetEval.iterator
- while (inputIterator.hasNext) {
- seen.add(inputIterator.next())
- }
- }
-
- override def eval(input: InternalRow): Any = {
- val casted = seen.asInstanceOf[OpenHashSet[InternalRow]]
- if (casted.size == 0) {
- null
- } else {
- Cast(Literal(
- casted.iterator.map(f => f.get(0, null)).reduceLeft(
- base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)),
- base.dataType).eval(null)
- }
- }
-}
-
-case class First(
- child: Expression,
- ignoreNullsExpr: Expression)
- extends UnaryExpression with PartialAggregate1 {
-
- def this(child: Expression) = this(child, Literal.create(false, BooleanType))
-
- private val ignoreNulls: Boolean = ignoreNullsExpr match {
- case Literal(b: Boolean, BooleanType) => b
- case _ =>
- throw new AnalysisException("The second argument of First should be a boolean literal.")
- }
-
- override def nullable: Boolean = true
- override def dataType: DataType = child.dataType
- override def toString: String = s"first(${child}${if (ignoreNulls) " ignore nulls"})"
-
- override def asPartial: SplitEvaluation = {
- val partialFirst = Alias(First(child, ignoreNulls), "PartialFirst")()
- SplitEvaluation(
- First(partialFirst.toAttribute, ignoreNulls),
- partialFirst :: Nil)
- }
- override def newInstance(): FirstFunction = new FirstFunction(child, ignoreNulls, this)
-}
-
-object First {
- def apply(child: Expression): First = First(child, ignoreNulls = false)
-
- def apply(child: Expression, ignoreNulls: Boolean): First =
- First(child, Literal.create(ignoreNulls, BooleanType))
-}
-
-case class FirstFunction(
- expr: Expression,
- ignoreNulls: Boolean,
- base: AggregateExpression1)
- extends AggregateFunction1 {
-
- def this() = this(null, null.asInstanceOf[Boolean], null) // Required for serialization.
-
- private[this] var result: Any = null
-
- private[this] var valueSet: Boolean = false
-
- override def update(input: InternalRow): Unit = {
- if (!valueSet) {
- val value = expr.eval(input)
- // When we have not set the result, we will set the result if we respect nulls
- // (i.e. ignoreNulls is false), or we ignore nulls and the evaluated value is not null.
- if (!ignoreNulls || (ignoreNulls && value != null)) {
- result = value
- valueSet = true
- }
- }
- }
-
- override def eval(input: InternalRow): Any = result
-}
-
-case class Last(
- child: Expression,
- ignoreNullsExpr: Expression)
- extends UnaryExpression with PartialAggregate1 {
-
- def this(child: Expression) = this(child, Literal.create(false, BooleanType))
-
- private val ignoreNulls: Boolean = ignoreNullsExpr match {
- case Literal(b: Boolean, BooleanType) => b
- case _ =>
- throw new AnalysisException("The second argument of First should be a boolean literal.")
- }
-
- override def references: AttributeSet = child.references
- override def nullable: Boolean = true
- override def dataType: DataType = child.dataType
- override def toString: String = s"last($child)${if (ignoreNulls) " ignore nulls"}"
-
- override def asPartial: SplitEvaluation = {
- val partialLast = Alias(Last(child, ignoreNulls), "PartialLast")()
- SplitEvaluation(
- Last(partialLast.toAttribute, ignoreNulls),
- partialLast :: Nil)
- }
- override def newInstance(): LastFunction = new LastFunction(child, ignoreNulls, this)
-}
-
-object Last {
- def apply(child: Expression): Last = Last(child, ignoreNulls = false)
-
- def apply(child: Expression, ignoreNulls: Boolean): Last =
- Last(child, Literal.create(ignoreNulls, BooleanType))
-}
-
-case class LastFunction(
- expr: Expression,
- ignoreNulls: Boolean,
- base: AggregateExpression1)
- extends AggregateFunction1 {
-
- def this() = this(null, null.asInstanceOf[Boolean], null) // Required for serialization.
-
- var result: Any = null
-
- override def update(input: InternalRow): Unit = {
- val value = expr.eval(input)
- if (!ignoreNulls || (ignoreNulls && value != null)) {
- result = value
- }
- }
-
- override def eval(input: InternalRow): Any = {
- result
- }
-}
-
-/**
- * Calculate Pearson Correlation Coefficient for the given columns.
- * Only support AggregateExpression2.
- *
- */
-case class Corr(left: Expression, right: Expression)
- extends BinaryExpression with AggregateExpression1 with ImplicitCastInputTypes {
- override def nullable: Boolean = false
- override def dataType: DoubleType.type = DoubleType
- override def toString: String = s"corr($left, $right)"
- override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
- override def newInstance(): AggregateFunction1 = {
- throw new UnsupportedOperationException(
- "Corr only supports the new AggregateExpression2 and can only be used " +
- "when spark.sql.useAggregate2 = true")
- }
-}
-
-// Compute standard deviation based on online algorithm specified here:
-// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
-abstract class StddevAgg1(child: Expression) extends UnaryExpression with PartialAggregate1 {
- override def nullable: Boolean = true
- override def dataType: DataType = DoubleType
-
- def isSample: Boolean
-
- override def asPartial: SplitEvaluation = {
- val partialStd = Alias(ComputePartialStd(child), "PartialStddev")()
- SplitEvaluation(MergePartialStd(partialStd.toAttribute, isSample), partialStd :: Nil)
- }
-
- override def newInstance(): StddevFunction = new StddevFunction(child, this, isSample)
-
- override def checkInputDataTypes(): TypeCheckResult =
- TypeUtils.checkForNumericExpr(child.dataType, "function stddev")
-
-}
-
-// Compute the population standard deviation of a column
-case class StddevPop(child: Expression) extends StddevAgg1(child) {
-
- override def toString: String = s"stddev_pop($child)"
- override def isSample: Boolean = false
-}
-
-// Compute the sample standard deviation of a column
-case class StddevSamp(child: Expression) extends StddevAgg1(child) {
-
- override def toString: String = s"stddev_samp($child)"
- override def isSample: Boolean = true
-}
-
-case class ComputePartialStd(child: Expression) extends UnaryExpression with AggregateExpression1 {
- def this() = this(null)
-
- override def children: Seq[Expression] = child :: Nil
- override def nullable: Boolean = false
- override def dataType: DataType = ArrayType(DoubleType)
- override def toString: String = s"computePartialStddev($child)"
- override def newInstance(): ComputePartialStdFunction =
- new ComputePartialStdFunction(child, this)
-}
-
-case class ComputePartialStdFunction (
- expr: Expression,
- base: AggregateExpression1
- ) extends AggregateFunction1 {
-
- def this() = this(null, null) // Required for serialization
-
- private val computeType = DoubleType
- private val zero = Cast(Literal(0), computeType)
- private var partialCount: Long = 0L
-
- // the mean of data processed so far
- private val partialAvg: MutableLiteral = MutableLiteral(zero.eval(null), computeType)
-
- // update average based on this formula:
- // avg = avg + (value - avg)/count
- private def avgAddFunction (value: Literal): Expression = {
- val delta = Subtract(Cast(value, computeType), partialAvg)
- Add(partialAvg, Divide(delta, Cast(Literal(partialCount), computeType)))
- }
-
- // the sum of squares of difference from mean
- private val partialMk: MutableLiteral = MutableLiteral(zero.eval(null), computeType)
-
- // update sum of square of difference from mean based on following formula:
- // Mk = Mk + (value - preAvg) * (value - updatedAvg)
- private def mkAddFunction(value: Literal, prePartialAvg: MutableLiteral): Expression = {
- val delta1 = Subtract(Cast(value, computeType), prePartialAvg)
- val delta2 = Subtract(Cast(value, computeType), partialAvg)
- Add(partialMk, Multiply(delta1, delta2))
- }
-
- override def update(input: InternalRow): Unit = {
- val evaluatedExpr = expr.eval(input)
- if (evaluatedExpr != null) {
- val exprValue = Literal.create(evaluatedExpr, expr.dataType)
- val prePartialAvg = partialAvg.copy()
- partialCount += 1
- partialAvg.update(avgAddFunction(exprValue), input)
- partialMk.update(mkAddFunction(exprValue, prePartialAvg), input)
- }
- }
-
- override def eval(input: InternalRow): Any = {
- new GenericArrayData(Array(Cast(Literal(partialCount), computeType).eval(null),
- partialAvg.eval(null),
- partialMk.eval(null)))
- }
-}
-
-case class MergePartialStd(
- child: Expression,
- isSample: Boolean
-) extends UnaryExpression with AggregateExpression1 {
- def this() = this(null, false) // required for serialization
-
- override def children: Seq[Expression] = child:: Nil
- override def nullable: Boolean = false
- override def dataType: DataType = DoubleType
- override def toString: String = s"MergePartialStd($child)"
- override def newInstance(): MergePartialStdFunction = {
- new MergePartialStdFunction(child, this, isSample)
- }
-}
-
-case class MergePartialStdFunction(
- expr: Expression,
- base: AggregateExpression1,
- isSample: Boolean
-) extends AggregateFunction1 {
- def this() = this (null, null, false) // Required for serialization
-
- private val computeType = DoubleType
- private val zero = Cast(Literal(0), computeType)
- private val combineCount = MutableLiteral(zero.eval(null), computeType)
- private val combineAvg = MutableLiteral(zero.eval(null), computeType)
- private val combineMk = MutableLiteral(zero.eval(null), computeType)
-
- private def avgUpdateFunction(preCount: Expression,
- partialCount: Expression,
- partialAvg: Expression): Expression = {
- Divide(Add(Multiply(combineAvg, preCount),
- Multiply(partialAvg, partialCount)),
- Add(preCount, partialCount))
- }
-
- override def update(input: InternalRow): Unit = {
- val evaluatedExpr = expr.eval(input).asInstanceOf[ArrayData]
-
- if (evaluatedExpr != null) {
- val exprValue = evaluatedExpr.toArray(computeType)
- val (partialCount, partialAvg, partialMk) =
- (Literal.create(exprValue(0), computeType),
- Literal.create(exprValue(1), computeType),
- Literal.create(exprValue(2), computeType))
-
- if (Cast(partialCount, LongType).eval(null).asInstanceOf[Long] > 0) {
- val preCount = combineCount.copy()
- combineCount.update(Add(combineCount, partialCount), input)
-
- val preAvg = combineAvg.copy()
- val avgDelta = Subtract(partialAvg, preAvg)
- val mkDelta = Multiply(Multiply(avgDelta, avgDelta),
- Divide(Multiply(preCount, partialCount),
- combineCount))
-
- // update average based on following formula
- // (combineAvg * preCount + partialAvg * partialCount) / (preCount + partialCount)
- combineAvg.update(avgUpdateFunction(preCount, partialCount, partialAvg), input)
-
- // update sum of square differences from mean based on following formula
- // (combineMk + partialMk + (avgDelta * avgDelta) * (preCount * partialCount/combineCount)
- combineMk.update(Add(combineMk, Add(partialMk, mkDelta)), input)
- }
- }
- }
-
- override def eval(input: InternalRow): Any = {
- val count: Long = Cast(combineCount, LongType).eval(null).asInstanceOf[Long]
-
- if (count == 0) null
- else if (count < 2) zero.eval(null)
- else {
- // when total count > 2
- // stddev_samp = sqrt (combineMk/(combineCount -1))
- // stddev_pop = sqrt (combineMk/combineCount)
- val varCol = {
- if (isSample) {
- Divide(combineMk, Cast(Literal(count - 1), computeType))
- }
- else {
- Divide(combineMk, Cast(Literal(count), computeType))
- }
- }
- Sqrt(varCol).eval(null)
- }
- }
-}
-
-case class StddevFunction(
- expr: Expression,
- base: AggregateExpression1,
- isSample: Boolean
-) extends AggregateFunction1 {
-
- def this() = this(null, null, false) // Required for serialization
-
- private val computeType = DoubleType
- private var curCount: Long = 0L
- private val zero = Cast(Literal(0), computeType)
- private val curAvg = MutableLiteral(zero.eval(null), computeType)
- private val curMk = MutableLiteral(zero.eval(null), computeType)
-
- private def curAvgAddFunction(value: Literal): Expression = {
- val delta = Subtract(Cast(value, computeType), curAvg)
- Add(curAvg, Divide(delta, Cast(Literal(curCount), computeType)))
- }
- private def curMkAddFunction(value: Literal, preAvg: MutableLiteral): Expression = {
- val delta1 = Subtract(Cast(value, computeType), preAvg)
- val delta2 = Subtract(Cast(value, computeType), curAvg)
- Add(curMk, Multiply(delta1, delta2))
- }
-
- override def update(input: InternalRow): Unit = {
- val evaluatedExpr = expr.eval(input)
- if (evaluatedExpr != null) {
- val preAvg: MutableLiteral = curAvg.copy()
- val exprValue = Literal.create(evaluatedExpr, expr.dataType)
- curCount += 1L
- curAvg.update(curAvgAddFunction(exprValue), input)
- curMk.update(curMkAddFunction(exprValue, preAvg), input)
- }
- }
-
- override def eval(input: InternalRow): Any = {
- if (curCount == 0) null
- else if (curCount < 2) zero.eval(null)
- else {
- // when total count > 2,
- // stddev_samp = sqrt(curMk/(curCount - 1))
- // stddev_pop = sqrt(curMk/curCount)
- val varCol = {
- if (isSample) {
- Divide(curMk, Cast(Literal(curCount - 1), computeType))
- }
- else {
- Divide(curMk, Cast(Literal(curCount), computeType))
- }
- }
- Sqrt(varCol).eval(null)
- }
- }
-}
-
-// placeholder
-case class Kurtosis(child: Expression) extends UnaryExpression with AggregateExpression1 {
-
- override def newInstance(): AggregateFunction1 = {
- throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " +
- "please set spark.sql.useAggregate2 = true")
- }
-
- override def nullable: Boolean = false
-
- override def dataType: DoubleType.type = DoubleType
-
- override def foldable: Boolean = false
-
- override def prettyName: String = "kurtosis"
-}
-
-// placeholder
-case class Skewness(child: Expression) extends UnaryExpression with AggregateExpression1 {
-
- override def newInstance(): AggregateFunction1 = {
- throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " +
- "please set spark.sql.useAggregate2 = true")
- }
-
- override def nullable: Boolean = false
-
- override def dataType: DoubleType.type = DoubleType
-
- override def foldable: Boolean = false
-
- override def prettyName: String = "skewness"
-}
-
-// placeholder
-case class VariancePop(child: Expression) extends UnaryExpression with AggregateExpression1 {
-
- override def newInstance(): AggregateFunction1 = {
- throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " +
- "please set spark.sql.useAggregate2 = true")
- }
-
- override def nullable: Boolean = false
-
- override def dataType: DoubleType.type = DoubleType
-
- override def foldable: Boolean = false
-
- override def prettyName: String = "var_pop"
-}
-
-// placeholder
-case class VarianceSamp(child: Expression) extends UnaryExpression with AggregateExpression1 {
-
- override def newInstance(): AggregateFunction1 = {
- throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " +
- "please set spark.sql.useAggregate2 = true")
- }
-
- override def nullable: Boolean = false
-
- override def dataType: DoubleType.type = DoubleType
-
- override def foldable: Boolean = false
-
- override def prettyName: String = "var_samp"
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index d222dfa33a..f4dba67f13 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
import scala.collection.immutable.HashSet
import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries}
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.FullOuter
import org.apache.spark.sql.catalyst.plans.LeftOuter
@@ -201,8 +202,8 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
object ColumnPruning extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case a @ Aggregate(_, _, e @ Expand(_, _, child))
- if (child.outputSet -- AttributeSet(e.output) -- a.references).nonEmpty =>
- a.copy(child = e.copy(child = prunedChild(child, AttributeSet(e.output) ++ a.references)))
+ if (child.outputSet -- e.references -- a.references).nonEmpty =>
+ a.copy(child = e.copy(child = prunedChild(child, e.references ++ a.references)))
// Eliminate attributes that are not needed to calculate the specified aggregates.
case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
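As I read the rewritten guard, the point of switching from `e.output` to `e.references` is the following (a sketch inferred from the rule itself):

    // For  a @ Aggregate(_, _, e @ Expand(_, _, child)),  the columns the child
    // must still produce are what the Expand's projections and the Aggregate read:
    //   e.references ++ a.references
    // not  AttributeSet(e.output) ++ a.references,  since Expand.output contains
    // fresh attributes that the child never produces.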
@@ -363,7 +364,8 @@ object LikeSimplification extends Rule[LogicalPlan] {
object NullPropagation extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
- case e @ Count(Literal(null, _)) => Cast(Literal(0L), e.dataType)
+ case e @ AggregateExpression(Count(Literal(null, _)), _, _) =>
+ Cast(Literal(0L), e.dataType)
case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
case e @ GetArrayItem(Literal(null, _), _) => Literal.create(null, e.dataType)
@@ -375,7 +377,9 @@ object NullPropagation extends Rule[LogicalPlan] {
Literal.create(null, e.dataType)
case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
- case e @ Count(expr) if !expr.nullable => Count(Literal(1))
+ case e @ AggregateExpression(Count(expr), mode, false) if !expr.nullable =>
+        // This rule should only be triggered when the isDistinct field is false.
+ AggregateExpression(Count(Literal(1)), mode, isDistinct = false)
// For Coalesce, remove null literals.
case e @ Coalesce(children) =>
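The net effect of the two rewritten Count rules, sketched (the DISTINCT caveat is why the second rule checks `isDistinct = false`):

    // count(NULL)               ==> Cast(Literal(0L), LongType)  // counting a null literal is always 0
    // count(x), x non-nullable  ==> count(1)                     // safe only without DISTINCT, since
    //                                                            // count(DISTINCT x) counts distinct
    //                                                            // values, not rows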
@@ -857,12 +861,15 @@ object DecimalAggregates extends Rule[LogicalPlan] {
private val MAX_DOUBLE_DIGITS = 15
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
- MakeDecimal(Sum(UnscaledValue(e)), prec + 10, scale)
+ case AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), mode, isDistinct)
+ if prec + 10 <= MAX_LONG_DIGITS =>
+ MakeDecimal(AggregateExpression(Sum(UnscaledValue(e)), mode, isDistinct), prec + 10, scale)
- case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS =>
+ case AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), mode, isDistinct)
+ if prec + 4 <= MAX_DOUBLE_DIGITS =>
+ val newAggExpr = AggregateExpression(Average(UnscaledValue(e)), mode, isDistinct)
Cast(
- Divide(Average(UnscaledValue(e)), Literal.create(math.pow(10.0, scale), DoubleType)),
+ Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
DecimalType(prec + 4, scale + 4))
}
}
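A worked instance of the Sum rewrite above, assuming a column of DecimalType(8, 2) and taking MAX_LONG_DIGITS to be 18 as defined earlier in this file:

    // prec + 10 = 18 <= MAX_LONG_DIGITS, so the rule fires and
    //   sum(e)  ==>  MakeDecimal(sum(UnscaledValue(e)), 18, 2)
    // The aggregation then runs over unscaled Long values, and the result is
    // reinterpreted as a decimal only once at the end.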
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 3b975b904a..6f4f11406d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -85,80 +85,6 @@ object PhysicalOperation extends PredicateHelper {
}
/**
- * Matches a logical aggregation that can be performed on distributed data in two steps. The first
- * operates on the data in each partition performing partial aggregation for each group. The second
- * occurs after the shuffle and completes the aggregation.
- *
- * This pattern will only match if all aggregate expressions can be computed partially and will
- * return the rewritten aggregation expressions for both phases.
- *
- * The returned values for this match are as follows:
- * - Grouping attributes for the final aggregation.
- * - Aggregates for the final aggregation.
- * - Grouping expressions for the partial aggregation.
- * - Partial aggregate expressions.
- * - Input to the aggregation.
- */
-object PartialAggregation {
- type ReturnType =
- (Seq[Attribute], Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan)
-
- def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
- case logical.Aggregate(groupingExpressions, aggregateExpressions, child) =>
- // Collect all aggregate expressions.
- val allAggregates =
- aggregateExpressions.flatMap(_ collect { case a: AggregateExpression1 => a})
- // Collect all aggregate expressions that can be computed partially.
- val partialAggregates =
- aggregateExpressions.flatMap(_ collect { case p: PartialAggregate1 => p})
-
- // Only do partial aggregation if supported by all aggregate expressions.
- if (allAggregates.size == partialAggregates.size) {
- // Create a map of expressions to their partial evaluations for all aggregate expressions.
- val partialEvaluations: Map[TreeNodeRef, SplitEvaluation] =
- partialAggregates.map(a => (new TreeNodeRef(a), a.asPartial)).toMap
-
- // We need to pass all grouping expressions though so the grouping can happen a second
- // time. However some of them might be unnamed so we alias them allowing them to be
- // referenced in the second aggregation.
- val namedGroupingExpressions: Seq[(Expression, NamedExpression)] =
- groupingExpressions.map {
- case n: NamedExpression => (n, n)
- case other => (other, Alias(other, "PartialGroup")())
- }
-
- // Replace aggregations with a new expression that computes the result from the already
- // computed partial evaluations and grouping values.
- val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformDown {
- case e: Expression if partialEvaluations.contains(new TreeNodeRef(e)) =>
- partialEvaluations(new TreeNodeRef(e)).finalEvaluation
-
- case e: Expression =>
- namedGroupingExpressions.collectFirst {
- case (expr, ne) if expr semanticEquals e => ne.toAttribute
- }.getOrElse(e)
- }).asInstanceOf[Seq[NamedExpression]]
-
- val partialComputation = namedGroupingExpressions.map(_._2) ++
- partialEvaluations.values.flatMap(_.partialEvaluations)
-
- val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
-
- Some(
- (namedGroupingAttributes,
- rewrittenAggregateExpressions,
- groupingExpressions,
- partialComputation,
- child))
- } else {
- None
- }
- case _ => None
- }
-}
-
-
-/**
* A pattern that finds joins with equality conditions that can be evaluated using equi-join.
*
* Null-safe equality will be transformed into equality as joining key (replace null with default
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 0ec9f08571..b9db7838db 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -137,13 +137,17 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
/** Returns all of the expressions present in this query plan operator. */
def expressions: Seq[Expression] = {
+ // Recursively find all expressions from a traversable.
+ def seqToExpressions(seq: Traversable[Any]): Traversable[Expression] = seq.flatMap {
+ case e: Expression => e :: Nil
+ case s: Traversable[_] => seqToExpressions(s)
+ case other => Nil
+ }
+
productIterator.flatMap {
case e: Expression => e :: Nil
case Some(e: Expression) => e :: Nil
- case seq: Traversable[_] => seq.flatMap {
- case e: Expression => e :: Nil
- case other => Nil
- }
+ case seq: Traversable[_] => seqToExpressions(seq)
case other => Nil
}.toSeq
}
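Why the recursion matters, with a hypothetical nested field:

    // An operator can hold nested collections, e.g. Expand's
    //   projections: Seq[Seq[Expression]]
    // The old one-level flatMap saw the inner Seq[Expression] values, matched
    // neither case, and silently dropped their expressions; seqToExpressions
    // now descends into Traversables of Traversables.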
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index d771088d69..764f8aaebd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.Utils
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashSet
@@ -219,8 +219,6 @@ case class Aggregate(
!expressions.exists(!_.resolved) && childrenResolved && !hasWindowExpressions
}
- lazy val newAggregation: Option[Aggregate] = Utils.tryConvert(this)
-
override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index fbdd3a7776..5a2368e329 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -171,16 +171,18 @@ class AnalysisErrorSuite extends AnalysisTest {
test("SPARK-6452 regression test") {
// CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s)
+    // Since we manually construct the logical plan here and Sum only accepts
+    // LongType, DoubleType, and DecimalType, we use LongType as the type of a.
val plan =
Aggregate(
Nil,
- Alias(Sum(AttributeReference("a", IntegerType)(exprId = ExprId(1))), "b")() :: Nil,
+ Alias(sum(AttributeReference("a", LongType)(exprId = ExprId(1))), "b")() :: Nil,
LocalRelation(
- AttributeReference("a", IntegerType)(exprId = ExprId(2))))
+ AttributeReference("a", LongType)(exprId = ExprId(2))))
assert(plan.resolved)
- assertAnalysisError(plan, "resolved attribute(s) a#1 missing from a#2" :: Nil)
+ assertAnalysisError(plan, "resolved attribute(s) a#1L missing from a#2L" :: Nil)
}
test("error test for self-join") {
@@ -196,7 +198,7 @@ class AnalysisErrorSuite extends AnalysisTest {
val plan =
Aggregate(
AttributeReference("a", BinaryType)(exprId = ExprId(2)) :: Nil,
- Alias(Sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
+ Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
LocalRelation(
AttributeReference("a", BinaryType)(exprId = ExprId(2)),
AttributeReference("b", IntegerType)(exprId = ExprId(1))))
@@ -207,13 +209,24 @@ class AnalysisErrorSuite extends AnalysisTest {
val plan2 =
Aggregate(
AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)) :: Nil,
- Alias(Sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
+ Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
LocalRelation(
AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)),
AttributeReference("b", IntegerType)(exprId = ExprId(1))))
assertAnalysisError(plan2,
"map type expression a cannot be used in grouping expression" :: Nil)
+
+ val plan3 =
+ Aggregate(
+ AttributeReference("a", ArrayType(IntegerType))(exprId = ExprId(2)) :: Nil,
+ Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
+ LocalRelation(
+ AttributeReference("a", ArrayType(IntegerType))(exprId = ExprId(2)),
+ AttributeReference("b", IntegerType)(exprId = ExprId(1))))
+
+ assertAnalysisError(plan3,
+ "array type expression a cannot be used in grouping expression" :: Nil)
}
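The new `plan3` case extends the existing map-type check to arrays. Both errors come from CheckAnalysis, which rejects grouping expressions whose types lack usable ordering/equality semantics. An illustrative sketch of the kind of check these tests exercise (names are illustrative, not the exact CheckAnalysis source):

    import org.apache.spark.sql.catalyst.expressions.Expression
    import org.apache.spark.sql.types.{ArrayType, MapType}

    // Map- and array-typed expressions are rejected as grouping keys;
    // failAnalysis stands in for CheckAnalysis's error helper.
    def checkValidGroupingExprs(expr: Expression): Unit = expr.dataType match {
      case _: MapType =>
        failAnalysis(s"map type expression ${expr.prettyString} cannot be used in grouping expression")
      case _: ArrayType =>
        failAnalysis(s"array type expression ${expr.prettyString} cannot be used in grouping expression")
      case _ => // other types are valid grouping keys
    }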
test("Join can't work on binary and map types") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 71d2939ecf..65f09b46af 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -45,7 +45,7 @@ class AnalysisSuite extends AnalysisTest {
val explode = Explode(AttributeReference("a", IntegerType, nullable = true)())
assert(!Project(Seq(Alias(explode, "explode")()), testRelation).resolved)
- assert(!Project(Seq(Alias(Count(Literal(1)), "count")()), testRelation).resolved)
+ assert(!Project(Seq(Alias(count(Literal(1)), "count")()), testRelation).resolved)
}
test("analyze project") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
index 40c4ae7920..fed591fd90 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
@@ -21,6 +21,7 @@ import org.scalatest.BeforeAndAfter
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.{Union, Project, LocalRelation}
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.{TableIdentifier, SimpleCatalystConf}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index c9bcc68f02..b902982add 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.{TypeCollection, StringType}
@@ -140,15 +141,16 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
}
test("check types for aggregates") {
+ // We use AggregateFunction directly here because the error is thrown from it
+ // rather than from AggregateExpression, the wrapper around an AggregateFunction.
+
// We will cast String to Double for sum and average
assertSuccess(Sum('stringField))
- assertSuccess(SumDistinct('stringField))
assertSuccess(Average('stringField))
assertError(Min('complexField), "min does not support ordering on type")
assertError(Max('complexField), "max does not support ordering on type")
assertError(Sum('booleanField), "function sum requires numeric type")
- assertError(SumDistinct('booleanField), "function sumDistinct requires numeric type")
assertError(Average('booleanField), "function average requires numeric type")
}
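The `SumDistinct` assertions are deleted rather than rewritten: under the new interface there is no separate distinct function class, because distinctness is a flag on the wrapping `AggregateExpression`. A minimal sketch of how SUM(DISTINCT x) is expressed now (constructor shape as assumed above; illustrative):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Sum}

    // SUM(DISTINCT booleanField) under the new interface: the same Sum
    // function with isDistinct = true on its wrapper, so the type-check
    // error is raised once, by Sum itself.
    val distinctSum = AggregateExpression(
      Sum('booleanField),
      mode = Complete,
      isDistinct = true)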
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
index e67606288f..8aaefa8493 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -162,7 +162,7 @@ class ConstantFoldingSuite extends PlanTest {
testRelation
.select(
Rand(5L) + Literal(1) as Symbol("c1"),
- Sum('a) as Symbol("c2"))
+ sum('a) as Symbol("c2"))
val optimized = Optimize.execute(originalQuery.analyze)
@@ -170,7 +170,7 @@ class ConstantFoldingSuite extends PlanTest {
testRelation
.select(
Rand(5L) + Literal(1.0) as Symbol("c1"),
- Sum('a) as Symbol("c2"))
+ sum('a) as Symbol("c2"))
.analyze
comparePlans(optimized, correctAnswer)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index ed810a1280..0290fafe87 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -68,7 +68,7 @@ class FilterPushdownSuite extends PlanTest {
test("column pruning for group") {
val originalQuery =
testRelation
- .groupBy('a)('a, Count('b))
+ .groupBy('a)('a, count('b))
.select('a)
val optimized = Optimize.execute(originalQuery.analyze)
@@ -84,7 +84,7 @@ class FilterPushdownSuite extends PlanTest {
test("column pruning for group with alias") {
val originalQuery =
testRelation
- .groupBy('a)('a as 'c, Count('b))
+ .groupBy('a)('a as 'c, count('b))
.select('c)
val optimized = Optimize.execute(originalQuery.analyze)
@@ -656,7 +656,7 @@ class FilterPushdownSuite extends PlanTest {
test("aggregate: push down filter when filter on group by expression") {
val originalQuery = testRelation
- .groupBy('a)('a, Count('b) as 'c)
+ .groupBy('a)('a, count('b) as 'c)
.select('a, 'c)
.where('a === 2)
@@ -664,7 +664,7 @@ class FilterPushdownSuite extends PlanTest {
val correctAnswer = testRelation
.where('a === 2)
- .groupBy('a)('a, Count('b) as 'c)
+ .groupBy('a)('a, count('b) as 'c)
.analyze
comparePlans(optimized, correctAnswer)
}
@@ -672,7 +672,7 @@ class FilterPushdownSuite extends PlanTest {
test("aggregate: don't push down filter when filter not on group by expression") {
val originalQuery = testRelation
.select('a, 'b)
- .groupBy('a)('a, Count('b) as 'c)
+ .groupBy('a)('a, count('b) as 'c)
.where('c === 2L)
val optimized = Optimize.execute(originalQuery.analyze)
@@ -683,7 +683,7 @@ class FilterPushdownSuite extends PlanTest {
test("aggregate: push down filters partially which are subset of group by expressions") {
val originalQuery = testRelation
.select('a, 'b)
- .groupBy('a)('a, Count('b) as 'c)
+ .groupBy('a)('a, count('b) as 'c)
.where('c === 2L && 'a === 3)
val optimized = Optimize.execute(originalQuery.analyze)
@@ -691,7 +691,7 @@ class FilterPushdownSuite extends PlanTest {
val correctAnswer = testRelation
.select('a, 'b)
.where('a === 3)
- .groupBy('a)('a, Count('b) as 'c)
+ .groupBy('a)('a, count('b) as 'c)
.where('c === 2L)
.analyze