From 3121e78168808c015fb21da8b0d44bb33649fb81 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 10 Nov 2015 16:25:22 -0800 Subject: [SPARK-9830][SPARK-11641][SQL][FOLLOW-UP] Remove AggregateExpression1 and update toString of Exchange https://issues.apache.org/jira/browse/SPARK-9830 This is the follow-up pr for https://github.com/apache/spark/pull/9556 to address davies' comments. Author: Yin Huai Closes #9607 from yhuai/removeAgg1-followup. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 2 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 58 ++++++---- .../catalyst/expressions/aggregate/Average.scala | 2 +- .../expressions/aggregate/CentralMomentAgg.scala | 2 +- .../catalyst/expressions/aggregate/Stddev.scala | 2 +- .../sql/catalyst/expressions/aggregate/Sum.scala | 2 +- .../sql/catalyst/analysis/AnalysisErrorSuite.scala | 127 +++++++++++++++++---- .../main/scala/org/apache/spark/sql/SQLConf.scala | 1 + .../org/apache/spark/sql/execution/Exchange.scala | 8 +- .../org/apache/spark/sql/execution/commands.scala | 10 ++ 10 files changed, 160 insertions(+), 54 deletions(-) (limited to 'sql') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b1e14390b7..a9cd9a7703 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -532,7 +532,7 @@ class Analyzer( case min: Min if isDistinct => AggregateExpression(min, Complete, isDistinct = false) // We get an aggregate function, we need to wrap it in an AggregateExpression. - case agg2: AggregateFunction => AggregateExpression(agg2, Complete, isDistinct) + case agg: AggregateFunction => AggregateExpression(agg, Complete, isDistinct) // This function is not an aggregate function, just return the resolved one. case other => other } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 8322e9930c..5a4b0c1e39 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -110,17 +110,21 @@ trait CheckAnalysis { case Aggregate(groupingExprs, aggregateExprs, child) => def checkValidAggregateExpression(expr: Expression): Unit = expr match { case aggExpr: AggregateExpression => - // TODO: Is it possible that the child of a agg function is another - // agg function? - aggExpr.aggregateFunction.children.foreach { - // This is just a sanity check, our analysis rule PullOutNondeterministic should - // already pull out those nondeterministic expressions and evaluate them in - // a Project node. - case child if !child.deterministic => + aggExpr.aggregateFunction.children.foreach { child => + child.foreach { + case agg: AggregateExpression => + failAnalysis( + s"It is not allowed to use an aggregate function in the argument of " + + s"another aggregate function. Please use the inner aggregate function " + + s"in a sub-query.") + case other => // OK + } + + if (!child.deterministic) { failAnalysis( s"nondeterministic expression ${expr.prettyString} should not " + s"appear in the arguments of an aggregate function.") - case child => // OK + } } case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) => failAnalysis( @@ -133,19 +137,33 @@ trait CheckAnalysis { case e => e.children.foreach(checkValidAggregateExpression) } + def checkSupportedGroupingDataType( + expressionString: String, + dataType: DataType): Unit = dataType match { + case BinaryType => + failAnalysis(s"expression $expressionString cannot be used in " + + s"grouping expression because it is in binary type or its inner field is " + + s"in binary type") + case a: ArrayType => + failAnalysis(s"expression $expressionString cannot be used in " + + s"grouping expression because it is in array type or its inner field is " + + s"in array type") + case m: MapType => + failAnalysis(s"expression $expressionString cannot be used in " + + s"grouping expression because it is in map type or its inner field is " + + s"in map type") + case s: StructType => + s.fields.foreach { f => + checkSupportedGroupingDataType(expressionString, f.dataType) + } + case udt: UserDefinedType[_] => + checkSupportedGroupingDataType(expressionString, udt.sqlType) + case _ => // OK + } + def checkValidGroupingExprs(expr: Expression): Unit = { - expr.dataType match { - case BinaryType => - failAnalysis(s"binary type expression ${expr.prettyString} cannot be used " + - "in grouping expression") - case a: ArrayType => - failAnalysis(s"array type expression ${expr.prettyString} cannot be used " + - "in grouping expression") - case m: MapType => - failAnalysis(s"map type expression ${expr.prettyString} cannot be used " + - "in grouping expression") - case _ => // OK - } + checkSupportedGroupingDataType(expr.prettyString, expr.dataType) + if (!expr.deterministic) { // This is just a sanity check, our analysis rule PullOutNondeterministic should // already pull out those nondeterministic expressions and evaluate them in diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 7f9e503470..94ac4bf09b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -34,7 +34,7 @@ case class Average(child: Expression) extends DeclarativeAggregate { // Return data type. override def dataType: DataType = resultType - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType)) + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) override def checkInputDataTypes(): TypeCheckResult = TypeUtils.checkForNumericExpr(child.dataType, "function average") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index 984ce7f24d..de5872ab11 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -57,7 +57,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w override def dataType: DataType = DoubleType - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType)) + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) override def checkInputDataTypes(): TypeCheckResult = TypeUtils.checkForNumericExpr(child.dataType, s"function $prettyName") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala index 5b9eb7ae02..2748009623 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala @@ -50,7 +50,7 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate { override def dataType: DataType = resultType - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType)) + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) override def checkInputDataTypes(): TypeCheckResult = TypeUtils.checkForNumericExpr(child.dataType, "function stddev") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index c005ec9657..cfb042e0aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -32,7 +32,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate { override def dataType: DataType = resultType override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType)) + Seq(TypeCollection(LongType, DoubleType, DecimalType)) override def checkInputDataTypes(): TypeCheckResult = TypeUtils.checkForNumericExpr(child.dataType, "function sum") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 5a2368e329..2e7c3bd67b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -23,8 +23,59 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData} import org.apache.spark.sql.types._ +import scala.beans.{BeanProperty, BeanInfo} + +@BeanInfo +private[sql] case class GroupableData(@BeanProperty data: Int) + +private[sql] class GroupableUDT extends UserDefinedType[GroupableData] { + + override def sqlType: DataType = IntegerType + + override def serialize(obj: Any): Int = { + obj match { + case groupableData: GroupableData => groupableData.data + } + } + + override def deserialize(datum: Any): GroupableData = { + datum match { + case data: Int => GroupableData(data) + } + } + + override def userClass: Class[GroupableData] = classOf[GroupableData] + + private[spark] override def asNullable: GroupableUDT = this +} + +@BeanInfo +private[sql] case class UngroupableData(@BeanProperty data: Array[Int]) + +private[sql] class UngroupableUDT extends UserDefinedType[UngroupableData] { + + override def sqlType: DataType = ArrayType(IntegerType) + + override def serialize(obj: Any): ArrayData = { + obj match { + case groupableData: UngroupableData => new GenericArrayData(groupableData.data) + } + } + + override def deserialize(datum: Any): UngroupableData = { + datum match { + case data: Array[Int] => UngroupableData(data) + } + } + + override def userClass: Class[UngroupableData] = classOf[UngroupableData] + + private[spark] override def asNullable: UngroupableUDT = this +} + case class TestFunction( children: Seq[Expression], inputTypes: Seq[AbstractDataType]) @@ -194,39 +245,65 @@ class AnalysisErrorSuite extends AnalysisTest { assert(error.message.contains("Conflicting attributes")) } - test("aggregation can't work on binary and map types") { - val plan = - Aggregate( - AttributeReference("a", BinaryType)(exprId = ExprId(2)) :: Nil, - Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil, - LocalRelation( - AttributeReference("a", BinaryType)(exprId = ExprId(2)), - AttributeReference("b", IntegerType)(exprId = ExprId(1)))) + test("check grouping expression data types") { + def checkDataType(dataType: DataType, shouldSuccess: Boolean): Unit = { + val plan = + Aggregate( + AttributeReference("a", dataType)(exprId = ExprId(2)) :: Nil, + Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil, + LocalRelation( + AttributeReference("a", dataType)(exprId = ExprId(2)), + AttributeReference("b", IntegerType)(exprId = ExprId(1)))) + + shouldSuccess match { + case true => + assertAnalysisSuccess(plan, true) + case false => + assertAnalysisError(plan, "expression a cannot be used in grouping expression" :: Nil) + } - assertAnalysisError(plan, - "binary type expression a cannot be used in grouping expression" :: Nil) + } - val plan2 = - Aggregate( - AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)) :: Nil, - Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil, - LocalRelation( - AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)), - AttributeReference("b", IntegerType)(exprId = ExprId(1)))) + val supportedDataTypes = Seq( + StringType, + NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), + DateType, TimestampType, + new StructType() + .add("f1", FloatType, nullable = true) + .add("f2", StringType, nullable = true), + new GroupableUDT()) + supportedDataTypes.foreach { dataType => + checkDataType(dataType, shouldSuccess = true) + } - assertAnalysisError(plan2, - "map type expression a cannot be used in grouping expression" :: Nil) + val unsupportedDataTypes = Seq( + BinaryType, + ArrayType(IntegerType), + MapType(StringType, LongType), + new StructType() + .add("f1", FloatType, nullable = true) + .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true), + new UngroupableUDT()) + unsupportedDataTypes.foreach { dataType => + checkDataType(dataType, shouldSuccess = false) + } + } - val plan3 = + test("we should fail analysis when we find nested aggregate functions") { + val plan = Aggregate( - AttributeReference("a", ArrayType(IntegerType))(exprId = ExprId(2)) :: Nil, - Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil, + AttributeReference("a", IntegerType)(exprId = ExprId(2)) :: Nil, + Alias(sum(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1)))), "c")() :: Nil, LocalRelation( - AttributeReference("a", ArrayType(IntegerType))(exprId = ExprId(2)), + AttributeReference("a", IntegerType)(exprId = ExprId(2)), AttributeReference("b", IntegerType)(exprId = ExprId(1)))) - assertAnalysisError(plan3, - "array type expression a cannot be used in grouping expression" :: Nil) + assertAnalysisError( + plan, + "It is not allowed to use an aggregate function in the argument of " + + "another aggregate function." :: Nil) } test("Join can't work on binary and map types") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 89e196c066..57d7d30e0e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -474,6 +474,7 @@ private[spark] object SQLConf { object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" val EXTERNAL_SORT = "spark.sql.planner.externalSort" + val USE_SQL_AGGREGATE2 = "spark.sql.useAggregate2" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index a4ce328c1a..b733b26987 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -44,14 +44,14 @@ case class Exchange( override def nodeName: String = { val extraInfo = coordinator match { case Some(exchangeCoordinator) if exchangeCoordinator.isEstimated => - "Shuffle" + s"(coordinator id: ${System.identityHashCode(coordinator)})" case Some(exchangeCoordinator) if !exchangeCoordinator.isEstimated => - "May shuffle" - case None => "Shuffle without coordinator" + s"(coordinator id: ${System.identityHashCode(coordinator)})" + case None => "" } val simpleNodeName = if (tungstenMode) "TungstenExchange" else "Exchange" - s"$simpleNodeName($extraInfo)" + s"${simpleNodeName}${extraInfo}" } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index e5f60b15e7..8b2755a587 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -111,6 +111,16 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm } (keyValueOutput, runFunc) + case Some((SQLConf.Deprecated.USE_SQL_AGGREGATE2, Some(value))) => + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.USE_SQL_AGGREGATE2} is deprecated and " + + s"will be ignored. ${SQLConf.Deprecated.USE_SQL_AGGREGATE2} will " + + s"continue to be true.") + Seq(Row(SQLConf.Deprecated.USE_SQL_AGGREGATE2, "true")) + } + (keyValueOutput, runFunc) + // Configures a single property. case Some((key, Some(value))) => val runFunc = (sqlContext: SQLContext) => { -- cgit v1.2.3