author    Yin Huai <yhuai@databricks.com>    2015-11-10 16:25:22 -0800
committer Reynold Xin <rxin@databricks.com>  2015-11-10 16:25:22 -0800
commit    3121e78168808c015fb21da8b0d44bb33649fb81 (patch)
tree      167c44c7754fbb47839563692c24742d1e1b7b9a /sql
parent    e281b87398f1298cc3df8e0409c7040acdddce03 (diff)
[SPARK-9830][SPARK-11641][SQL][FOLLOW-UP] Remove AggregateExpression1 and update toString of Exchange
https://issues.apache.org/jira/browse/SPARK-9830

This is the follow-up PR for https://github.com/apache/spark/pull/9556, addressing davies' comments.

Author: Yin Huai <yhuai@databricks.com>

Closes #9607 from yhuai/removeAgg1-followup.
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala                      |   2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala                 |  58
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala          |   2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala |   2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala           |   2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala              |   2
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala            | 127
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala                                             |   1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala                                  |   8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala                                  |  10
10 files changed, 160 insertions(+), 54 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index b1e14390b7..a9cd9a7703 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -532,7 +532,7 @@ class Analyzer(
case min: Min if isDistinct =>
AggregateExpression(min, Complete, isDistinct = false)
// We get an aggregate function, we need to wrap it in an AggregateExpression.
- case agg2: AggregateFunction => AggregateExpression(agg2, Complete, isDistinct)
+ case agg: AggregateFunction => AggregateExpression(agg, Complete, isDistinct)
// This function is not an aggregate function, just return the resolved one.
case other => other
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 8322e9930c..5a4b0c1e39 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -110,17 +110,21 @@ trait CheckAnalysis {
case Aggregate(groupingExprs, aggregateExprs, child) =>
def checkValidAggregateExpression(expr: Expression): Unit = expr match {
case aggExpr: AggregateExpression =>
- // TODO: Is it possible that the child of a agg function is another
- // agg function?
- aggExpr.aggregateFunction.children.foreach {
- // This is just a sanity check, our analysis rule PullOutNondeterministic should
- // already pull out those nondeterministic expressions and evaluate them in
- // a Project node.
- case child if !child.deterministic =>
+ aggExpr.aggregateFunction.children.foreach { child =>
+ child.foreach {
+ case agg: AggregateExpression =>
+ failAnalysis(
+ s"It is not allowed to use an aggregate function in the argument of " +
+ s"another aggregate function. Please use the inner aggregate function " +
+ s"in a sub-query.")
+ case other => // OK
+ }
+
+ if (!child.deterministic) {
failAnalysis(
s"nondeterministic expression ${expr.prettyString} should not " +
s"appear in the arguments of an aggregate function.")
- case child => // OK
+ }
}
case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) =>
failAnalysis(
@@ -133,19 +137,33 @@ trait CheckAnalysis {
case e => e.children.foreach(checkValidAggregateExpression)
}
+ def checkSupportedGroupingDataType(
+ expressionString: String,
+ dataType: DataType): Unit = dataType match {
+ case BinaryType =>
+ failAnalysis(s"expression $expressionString cannot be used in " +
+ s"grouping expression because it is in binary type or its inner field is " +
+ s"in binary type")
+ case a: ArrayType =>
+ failAnalysis(s"expression $expressionString cannot be used in " +
+ s"grouping expression because it is in array type or its inner field is " +
+ s"in array type")
+ case m: MapType =>
+ failAnalysis(s"expression $expressionString cannot be used in " +
+ s"grouping expression because it is in map type or its inner field is " +
+ s"in map type")
+ case s: StructType =>
+ s.fields.foreach { f =>
+ checkSupportedGroupingDataType(expressionString, f.dataType)
+ }
+ case udt: UserDefinedType[_] =>
+ checkSupportedGroupingDataType(expressionString, udt.sqlType)
+ case _ => // OK
+ }
+
def checkValidGroupingExprs(expr: Expression): Unit = {
- expr.dataType match {
- case BinaryType =>
- failAnalysis(s"binary type expression ${expr.prettyString} cannot be used " +
- "in grouping expression")
- case a: ArrayType =>
- failAnalysis(s"array type expression ${expr.prettyString} cannot be used " +
- "in grouping expression")
- case m: MapType =>
- failAnalysis(s"map type expression ${expr.prettyString} cannot be used " +
- "in grouping expression")
- case _ => // OK
- }
+ checkSupportedGroupingDataType(expr.prettyString, expr.dataType)
+
if (!expr.deterministic) {
// This is just a sanity check, our analysis rule PullOutNondeterministic should
// already pull out those nondeterministic expressions and evaluate them in
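
The two checks added above surface to users as analysis-time errors. A minimal sketch of the user-visible behavior (not part of this patch; it assumes a SQLContext named sqlContext and a registered table `t` with columns a INT and m MAP<STRING, BIGINT>):

// Nested aggregate functions are now rejected explicitly:
sqlContext.sql("SELECT sum(sum(a)) FROM t")
// => AnalysisException: It is not allowed to use an aggregate function in the
//    argument of another aggregate function. Please use the inner aggregate
//    function in a sub-query.

// The suggested rewrite computes the inner aggregate in a sub-query:
sqlContext.sql("SELECT sum(s) FROM (SELECT sum(a) AS s FROM t) tmp")

// Grouping by a map column (or any type whose inner field is binary, array,
// or map, now including struct fields and UDT sql types) is also rejected:
sqlContext.sql("SELECT m, count(*) FROM t GROUP BY m")
// => AnalysisException: expression m cannot be used in grouping expression
//    because it is in map type or its inner field is in map type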
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 7f9e503470..94ac4bf09b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -34,7 +34,7 @@ case class Average(child: Expression) extends DeclarativeAggregate {
// Return data type.
override def dataType: DataType = resultType
- override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType))
+ override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForNumericExpr(child.dataType, "function average")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
index 984ce7f24d..de5872ab11 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
@@ -57,7 +57,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w
override def dataType: DataType = DoubleType
- override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType))
+ override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForNumericExpr(child.dataType, s"function $prettyName")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
index 5b9eb7ae02..2748009623 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
@@ -50,7 +50,7 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate {
override def dataType: DataType = resultType
- override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType))
+ override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForNumericExpr(child.dataType, "function stddev")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index c005ec9657..cfb042e0aa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -32,7 +32,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate {
override def dataType: DataType = resultType
override def inputTypes: Seq[AbstractDataType] =
- Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType))
+ Seq(TypeCollection(LongType, DoubleType, DecimalType))
override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForNumericExpr(child.dataType, "function sum")
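
The four inputTypes changes above are the same cleanup: TypeCollection exists to union several acceptable types, so wrapping a single AbstractDataType in it added nothing (Sum keeps a real collection because it genuinely accepts several types; it merely drops NullType). An illustrative toy model of the equivalence, not the Spark source:

// Toy re-derivation: a one-member collection accepts exactly what its member
// accepts, so Seq(TypeCollection(NumericType)) was just a noisier Seq(NumericType).
sealed trait AbstractType { def acceptsType(t: String): Boolean }
case object Numeric extends AbstractType {
  def acceptsType(t: String): Boolean =
    Set("int", "long", "float", "double", "decimal").contains(t)
}
// A collection accepts a type iff any member accepts it.
case class Collection(types: AbstractType*) extends AbstractType {
  def acceptsType(t: String): Boolean = types.exists(_.acceptsType(t))
}
assert(Collection(Numeric).acceptsType("int") == Numeric.acceptsType("int"))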
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 5a2368e329..2e7c3bd67b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -23,8 +23,59 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData}
import org.apache.spark.sql.types._
+import scala.beans.{BeanProperty, BeanInfo}
+
+@BeanInfo
+private[sql] case class GroupableData(@BeanProperty data: Int)
+
+private[sql] class GroupableUDT extends UserDefinedType[GroupableData] {
+
+ override def sqlType: DataType = IntegerType
+
+ override def serialize(obj: Any): Int = {
+ obj match {
+ case groupableData: GroupableData => groupableData.data
+ }
+ }
+
+ override def deserialize(datum: Any): GroupableData = {
+ datum match {
+ case data: Int => GroupableData(data)
+ }
+ }
+
+ override def userClass: Class[GroupableData] = classOf[GroupableData]
+
+ private[spark] override def asNullable: GroupableUDT = this
+}
+
+@BeanInfo
+private[sql] case class UngroupableData(@BeanProperty data: Array[Int])
+
+private[sql] class UngroupableUDT extends UserDefinedType[UngroupableData] {
+
+ override def sqlType: DataType = ArrayType(IntegerType)
+
+ override def serialize(obj: Any): ArrayData = {
+ obj match {
+ case groupableData: UngroupableData => new GenericArrayData(groupableData.data)
+ }
+ }
+
+ override def deserialize(datum: Any): UngroupableData = {
+ datum match {
+ case data: Array[Int] => UngroupableData(data)
+ }
+ }
+
+ override def userClass: Class[UngroupableData] = classOf[UngroupableData]
+
+ private[spark] override def asNullable: UngroupableUDT = this
+}
+
case class TestFunction(
children: Seq[Expression],
inputTypes: Seq[AbstractDataType])
@@ -194,39 +245,65 @@ class AnalysisErrorSuite extends AnalysisTest {
assert(error.message.contains("Conflicting attributes"))
}
- test("aggregation can't work on binary and map types") {
- val plan =
- Aggregate(
- AttributeReference("a", BinaryType)(exprId = ExprId(2)) :: Nil,
- Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
- LocalRelation(
- AttributeReference("a", BinaryType)(exprId = ExprId(2)),
- AttributeReference("b", IntegerType)(exprId = ExprId(1))))
+ test("check grouping expression data types") {
+ def checkDataType(dataType: DataType, shouldSuccess: Boolean): Unit = {
+ val plan =
+ Aggregate(
+ AttributeReference("a", dataType)(exprId = ExprId(2)) :: Nil,
+ Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
+ LocalRelation(
+ AttributeReference("a", dataType)(exprId = ExprId(2)),
+ AttributeReference("b", IntegerType)(exprId = ExprId(1))))
+
+ shouldSuccess match {
+ case true =>
+ assertAnalysisSuccess(plan, true)
+ case false =>
+ assertAnalysisError(plan, "expression a cannot be used in grouping expression" :: Nil)
+ }
- assertAnalysisError(plan,
- "binary type expression a cannot be used in grouping expression" :: Nil)
+ }
- val plan2 =
- Aggregate(
- AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)) :: Nil,
- Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
- LocalRelation(
- AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)),
- AttributeReference("b", IntegerType)(exprId = ExprId(1))))
+ val supportedDataTypes = Seq(
+ StringType,
+ NullType, BooleanType,
+ ByteType, ShortType, IntegerType, LongType,
+ FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5),
+ DateType, TimestampType,
+ new StructType()
+ .add("f1", FloatType, nullable = true)
+ .add("f2", StringType, nullable = true),
+ new GroupableUDT())
+ supportedDataTypes.foreach { dataType =>
+ checkDataType(dataType, shouldSuccess = true)
+ }
- assertAnalysisError(plan2,
- "map type expression a cannot be used in grouping expression" :: Nil)
+ val unsupportedDataTypes = Seq(
+ BinaryType,
+ ArrayType(IntegerType),
+ MapType(StringType, LongType),
+ new StructType()
+ .add("f1", FloatType, nullable = true)
+ .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true),
+ new UngroupableUDT())
+ unsupportedDataTypes.foreach { dataType =>
+ checkDataType(dataType, shouldSuccess = false)
+ }
+ }
- val plan3 =
+ test("we should fail analysis when we find nested aggregate functions") {
+ val plan =
Aggregate(
- AttributeReference("a", ArrayType(IntegerType))(exprId = ExprId(2)) :: Nil,
- Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
+ AttributeReference("a", IntegerType)(exprId = ExprId(2)) :: Nil,
+ Alias(sum(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1)))), "c")() :: Nil,
LocalRelation(
- AttributeReference("a", ArrayType(IntegerType))(exprId = ExprId(2)),
+ AttributeReference("a", IntegerType)(exprId = ExprId(2)),
AttributeReference("b", IntegerType)(exprId = ExprId(1))))
- assertAnalysisError(plan3,
- "array type expression a cannot be used in grouping expression" :: Nil)
+ assertAnalysisError(
+ plan,
+ "It is not allowed to use an aggregate function in the argument of " +
+ "another aggregate function." :: Nil)
}
test("Join can't work on binary and map types") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 89e196c066..57d7d30e0e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -474,6 +474,7 @@ private[spark] object SQLConf {
object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
val EXTERNAL_SORT = "spark.sql.planner.externalSort"
+ val USE_SQL_AGGREGATE2 = "spark.sql.useAggregate2"
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index a4ce328c1a..b733b26987 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -44,14 +44,14 @@ case class Exchange(
override def nodeName: String = {
val extraInfo = coordinator match {
case Some(exchangeCoordinator) if exchangeCoordinator.isEstimated =>
- "Shuffle"
+ s"(coordinator id: ${System.identityHashCode(coordinator)})"
case Some(exchangeCoordinator) if !exchangeCoordinator.isEstimated =>
- "May shuffle"
- case None => "Shuffle without coordinator"
+ s"(coordinator id: ${System.identityHashCode(coordinator)})"
+ case None => ""
}
val simpleNodeName = if (tungstenMode) "TungstenExchange" else "Exchange"
- s"$simpleNodeName($extraInfo)"
+ s"${simpleNodeName}${extraInfo}"
}
/**
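
The net effect on plan strings: the prose descriptions ("Shuffle", "May shuffle", "Shuffle without coordinator") are replaced by the coordinator's identity, or nothing. A minimal re-derivation of the new logic, not the Spark source; note that, like the patch, it takes identityHashCode of the Option itself, and that both Some branches in the patch produce the same string, so they collapse here:

def nodeName(coordinator: Option[AnyRef], tungstenMode: Boolean): String = {
  val extraInfo = coordinator match {
    case Some(_) => s"(coordinator id: ${System.identityHashCode(coordinator)})"
    case None => ""
  }
  val simpleNodeName = if (tungstenMode) "TungstenExchange" else "Exchange"
  s"$simpleNodeName$extraInfo"
}

nodeName(Some(new Object), tungstenMode = true)
// => e.g. "TungstenExchange(coordinator id: 1829164700)"
nodeName(None, tungstenMode = false)
// => "Exchange"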
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index e5f60b15e7..8b2755a587 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -111,6 +111,16 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm
}
(keyValueOutput, runFunc)
+ case Some((SQLConf.Deprecated.USE_SQL_AGGREGATE2, Some(value))) =>
+ val runFunc = (sqlContext: SQLContext) => {
+ logWarning(
+ s"Property ${SQLConf.Deprecated.USE_SQL_AGGREGATE2} is deprecated and " +
+ s"will be ignored. ${SQLConf.Deprecated.USE_SQL_AGGREGATE2} will " +
+ s"continue to be true.")
+ Seq(Row(SQLConf.Deprecated.USE_SQL_AGGREGATE2, "true"))
+ }
+ (keyValueOutput, runFunc)
+
// Configures a single property.
case Some((key, Some(value))) =>
val runFunc = (sqlContext: SQLContext) => {
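
With the handler above, setting the removed flag warns instead of failing. A hedged sketch of what a user sees (assumes, beyond this patch, a SQLContext named sqlContext):

sqlContext.sql("SET spark.sql.useAggregate2=false").collect()
// Logs: WARN ... Property spark.sql.useAggregate2 is deprecated and will be
//       ignored. spark.sql.useAggregate2 will continue to be true.
// Returns the keyValueOutput row reporting the fixed value:
// Array(Row("spark.sql.useAggregate2", "true"))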