author     Liang-Chi Hsieh <viirya@appier.com>  2015-07-30 10:30:37 -0700
committer  Yin Huai <yhuai@databricks.com>      2015-07-30 10:32:12 -0700
commit  5363ed71568c3e7c082146d654a9c669d692d894 (patch)
tree    e9a4f33ce2e045bb02c6e6b774bea96a7ee27bb8 /sql/core
parent  7bbf02f0bddefd19985372af79e906a38bc528b6 (diff)
[SPARK-9361] [SQL] Refactor new aggregation code to reduce the times of checking compatibility
JIRA: https://issues.apache.org/jira/browse/SPARK-9361

Currently, we call `aggregate.Utils.tryConvert` in many places to check if a logical.Aggregate can be run with the new aggregation code path. But it looks like `aggregate.Utils.tryConvert` costs considerable time to run. We should call `tryConvert` only once, keep its result in `logical.Aggregate`, and reuse it. In `org.apache.spark.sql.execution.aggregate.Utils`, the code involving `tryConvert` should be moved to catalyst because it does not actually deal with execution details.

Author: Liang-Chi Hsieh <viirya@appier.com>

Closes #7677 from viirya/refactor_aggregate and squashes the following commits:

babea30 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into refactor_aggregate
9a589d7 [Liang-Chi Hsieh] Fix scala style.
0a91329 [Liang-Chi Hsieh] Refactor new aggregation code to reduce the times to call tryConvert.
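For illustration, the caching this commit introduces follows the standard lazy-val memoization pattern. Below is a minimal, self-contained sketch (hypothetical stand-in types, not Spark's actual `logical.Aggregate`) showing why the conversion now runs at most once per plan node, no matter how many strategies consult it:

```scala
// Minimal sketch of the memoization pattern used here, with hypothetical
// stand-in types: the expensive conversion is computed at most once per
// plan node via a lazy val, and every later compatibility check reuses it.
object LazyConversionSketch {

  final case class Aggregate(expressions: Seq[String]) {
    // Computed on first access only; mirrors the `newAggregation` member
    // that this commit adds to logical.Aggregate.
    lazy val newAggregation: Option[Aggregate] = tryConvert(this)
  }

  private var conversionRuns = 0 // instrumentation for the demo only

  // Stand-in for aggregate.Utils.tryConvert: pretend plans containing a
  // multi-distinct marker cannot be converted.
  def tryConvert(plan: Aggregate): Option[Aggregate] = {
    conversionRuns += 1
    if (plan.expressions.contains("MULTI_DISTINCT")) None else Some(plan)
  }

  def main(args: Array[String]): Unit = {
    val agg = Aggregate(Seq("SUM(a)", "MAX(b)"))
    // Three checks, mirroring the places SparkStrategies previously
    // re-invoked tryConvert; the conversion itself now runs only once.
    val checks = Seq.fill(3)(agg.newAggregation.isDefined)
    println(s"checks = $checks, conversions run = $conversionRuns")
    // prints: checks = List(true, true, true), conversions run = 1
  }
}
```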
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala |  34
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala | 144
2 files changed, 16 insertions(+), 162 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index f3ef066528..52a9b02d37 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, Utils}
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan}
@@ -193,11 +193,15 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case _ => Nil
}
- def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = {
- aggregate.Utils.tryConvert(
- plan,
- sqlContext.conf.useSqlAggregate2,
- sqlContext.conf.codegenEnabled).isDefined
+ def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = plan match {
+ case a: logical.Aggregate =>
+ if (sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled) {
+ a.newAggregation.isDefined
+ } else {
+ Utils.checkInvalidAggregateFunction2(a)
+ false
+ }
+ case _ => false
}
def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = aggs.forall {
@@ -217,12 +221,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
*/
object Aggregation extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case p: logical.Aggregate =>
- val converted =
- aggregate.Utils.tryConvert(
- p,
- sqlContext.conf.useSqlAggregate2,
- sqlContext.conf.codegenEnabled)
+ case p: logical.Aggregate if sqlContext.conf.useSqlAggregate2 &&
+ sqlContext.conf.codegenEnabled =>
+ val converted = p.newAggregation
converted match {
case None => Nil // Cannot convert to new aggregation code path.
case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) =>
@@ -377,17 +378,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case e @ logical.Expand(_, _, _, child) =>
execution.Expand(e.projections, e.output, planLater(child)) :: Nil
case a @ logical.Aggregate(group, agg, child) => {
- val useNewAggregation =
- aggregate.Utils.tryConvert(
- a,
- sqlContext.conf.useSqlAggregate2,
- sqlContext.conf.codegenEnabled).isDefined
- if (useNewAggregation) {
+ val useNewAggregation = sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled
+ if (useNewAggregation && a.newAggregation.isDefined) {
// If this logical.Aggregate can be planned to use new aggregation code path
// (i.e. it can be planned by the Strategy Aggregation), we will not use the old
// aggregation code path.
Nil
} else {
+ Utils.checkInvalidAggregateFunction2(a)
execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
}
}
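The guard applied at all three call sites above can be read as the following standalone sketch (simplified, hypothetical types; not the Spark source itself): when both `spark.sql.useAggregate2` and codegen are enabled, the cached `newAggregation` is consulted; otherwise the plan is only validated for illegally mixed old/new aggregate functions and the old code path is kept.

```scala
// Standalone sketch (hypothetical types) of the guard logic shared by the
// three call sites in SparkStrategies after this refactoring.
object GuardSketch {
  final case class Aggregate(newAggregation: Option[String])

  // Stand-in for Utils.checkInvalidAggregateFunction2, which in Spark
  // throws an AnalysisException when a plan mixes functions implemented
  // against the old and new aggregate interfaces; a no-op stub here.
  def checkInvalidAggregateFunction2(a: Aggregate): Unit = ()

  def canBeConvertedToNewAggregation(a: Aggregate,
                                     useSqlAggregate2: Boolean,
                                     codegenEnabled: Boolean): Boolean =
    if (useSqlAggregate2 && codegenEnabled) {
      a.newAggregation.isDefined // cheap: reuses the cached conversion
    } else {
      checkInvalidAggregateFunction2(a) // validate, then fall back
      false
    }
}
```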
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index 6549c87752..03635baae4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -29,150 +29,6 @@ import org.apache.spark.sql.types.{StructType, MapType, ArrayType}
* Utility functions used by the query planner to convert our plan to new aggregation code path.
*/
object Utils {
- // Right now, we do not support complex types in the grouping key schema.
- private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = {
- val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists {
- case array: ArrayType => true
- case map: MapType => true
- case struct: StructType => true
- case _ => false
- }
-
- !hasComplexTypes
- }
-
- private def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
- case p: Aggregate if supportsGroupingKeySchema(p) =>
- val converted = p.transformExpressionsDown {
- case expressions.Average(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Average(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.Count(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Count(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- // We do not support multiple COUNT DISTINCT columns for now.
- case expressions.CountDistinct(children) if children.length == 1 =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Count(children.head),
- mode = aggregate.Complete,
- isDistinct = true)
-
- case expressions.First(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.First(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.Last(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Last(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.Max(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Max(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.Min(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Min(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.Sum(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Sum(child),
- mode = aggregate.Complete,
- isDistinct = false)
-
- case expressions.SumDistinct(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Sum(child),
- mode = aggregate.Complete,
- isDistinct = true)
- }
- // Check if there is any expressions.AggregateExpression1 left.
- // If so, we cannot convert this plan.
- val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr =>
- // For every expressions, check if it contains AggregateExpression1.
- expr.find {
- case agg: expressions.AggregateExpression1 => true
- case other => false
- }.isDefined
- }
-
- // Check if there are multiple distinct columns.
- val aggregateExpressions = converted.aggregateExpressions.flatMap { expr =>
- expr.collect {
- case agg: AggregateExpression2 => agg
- }
- }.toSet.toSeq
- val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct)
- val hasMultipleDistinctColumnSets =
- if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
- true
- } else {
- false
- }
-
- if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None
-
- case other => None
- }
-
- private def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = {
- // If the plan cannot be converted, we will do a final round check to if the original
- // logical.Aggregate contains both AggregateExpression1 and AggregateExpression2. If so,
- // we need to throw an exception.
- val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr =>
- expr.collect {
- case agg: AggregateExpression2 => agg.aggregateFunction
- }
- }.distinct
- if (aggregateFunction2s.nonEmpty) {
- // For functions implemented based on the new interface, prepare a list of function names.
- val invalidFunctions = {
- if (aggregateFunction2s.length > 1) {
- s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " +
- s"and ${aggregateFunction2s.head.nodeName} are"
- } else {
- s"${aggregateFunction2s.head.nodeName} is"
- }
- }
- val errorMessage =
- s"${invalidFunctions} implemented based on the new Aggregate Function " +
- s"interface and it cannot be used with functions implemented based on " +
- s"the old Aggregate Function interface."
- throw new AnalysisException(errorMessage)
- }
- }
-
- def tryConvert(
- plan: LogicalPlan,
- useNewAggregation: Boolean,
- codeGenEnabled: Boolean): Option[Aggregate] = plan match {
- case p: Aggregate if useNewAggregation && codeGenEnabled =>
- val converted = tryConvert(p)
- if (converted.isDefined) {
- converted
- } else {
- checkInvalidAggregateFunction2(p)
- None
- }
- case p: Aggregate =>
- checkInvalidAggregateFunction2(p)
- None
- case other => None
- }
-
def planAggregateWithoutDistinct(
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[AggregateExpression2],