author      Liang-Chi Hsieh <viirya@appier.com>    2015-07-30 10:30:37 -0700
committer   Yin Huai <yhuai@databricks.com>        2015-07-30 10:32:12 -0700
commit      5363ed71568c3e7c082146d654a9c669d692d894 (patch)
tree        e9a4f33ce2e045bb02c6e6b774bea96a7ee27bb8 /sql/catalyst
parent      7bbf02f0bddefd19985372af79e906a38bc528b6 (diff)
[SPARK-9361] [SQL] Refactor new aggregation code to reduce the times of checking compatibility
JIRA: https://issues.apache.org/jira/browse/SPARK-9361

Currently, we call `aggregate.Utils.tryConvert` in many places to check whether a logical.Aggregate can be run with the new aggregation code path. But `aggregate.Utils.tryConvert` takes considerable time to run, so we should call `tryConvert` only once, keep its value in `logical.Aggregate`, and reuse it.

The code in `org.apache.spark.sql.execution.aggregate.Utils` that involves `tryConvert` should also be moved to catalyst, because it does not actually deal with execution details.

Author: Liang-Chi Hsieh <viirya@appier.com>

Closes #7677 from viirya/refactor_aggregate and squashes the following commits:

babea30 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into refactor_aggregate
9a589d7 [Liang-Chi Hsieh] Fix scala style.
0a91329 [Liang-Chi Hsieh] Refactor new aggregation code to reduce the times to call tryConvert.
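The fix is, in essence, compute-once memoization via a Scala `lazy val`. A minimal sketch of the idea, using invented names (`CachingSketch`, `Plan`) rather than Spark's own:

    // Minimal sketch of the memoization pattern; illustrative names only.
    object CachingSketch {
      // Stand-in for the costly aggregate.Utils.tryConvert compatibility check.
      def expensiveConvert(plan: Plan): Option[Plan] = Some(plan)
    }

    // Stand-in for logical.Aggregate: `converted` is computed on first access
    // only; every later reader reuses the cached Option.
    case class Plan(name: String) {
      lazy val converted: Option[Plan] = CachingSketch.expensiveConvert(this)
    }

This is the same shape as the `lazy val newAggregation` added to `logical.Aggregate` below.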
Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala      4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala         167
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala          3
3 files changed, 172 insertions(+), 2 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 9fb7623172..d08f553cef 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -42,7 +42,7 @@ private[sql] case object Partial extends AggregateMode
private[sql] case object PartialMerge extends AggregateMode
/**
- * An [[AggregateFunction2]] with [[PartialMerge]] mode is used to merge aggregation buffers
+ * An [[AggregateFunction2]] with [[Final]] mode is used to merge aggregation buffers
* containing intermediate results for this function and then generate final result.
* This function updates the given aggregation buffer by merging multiple aggregation buffers.
* When it has processed all input rows, the final result of this function is returned.
@@ -50,7 +50,7 @@ private[sql] case object PartialMerge extends AggregateMode
private[sql] case object Final extends AggregateMode
/**
- * An [[AggregateFunction2]] with [[Partial]] mode is used to evaluate this function directly
+ * An [[AggregateFunction2]] with [[Complete]] mode is used to evaluate this function directly
* from original input rows without any partial aggregation.
* This function updates the given aggregation buffer with the original input of this
* function. When it has processed all input rows, the final result of this function is returned.
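The corrected comments pin down the four modes: Partial updates buffers from raw input rows, PartialMerge merges buffers into another buffer, Final merges buffers and emits the result, and Complete computes the result directly from the input with no partial step. As a hedged sketch only of how a two-phase plan might pair these modes (the object `ModeSketch` and attribute `x` are invented; the sketch sits in the sql package because the modes are private[sql]):

    package org.apache.spark.sql

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.catalyst.expressions.aggregate._
    import org.apache.spark.sql.types.DoubleType

    object ModeSketch {
      // Placeholder input column for the sketch.
      val x = AttributeReference("x", DoubleType)()

      // Map side: update aggregation buffers from original input rows.
      val partialAvg = AggregateExpression2(
        aggregateFunction = Average(x),
        mode = Partial,
        isDistinct = false)

      // Reduce side: merge the partial buffers and emit the final average.
      val finalAvg = AggregateExpression2(
        aggregateFunction = Average(x),
        mode = Final,
        isDistinct = false)
    }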
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
new file mode 100644
index 0000000000..4a43318a95
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
+import org.apache.spark.sql.types.{StructType, MapType, ArrayType}
+
+/**
+ * Utility functions used by the query planner to convert our plan to new aggregation code path.
+ */
+object Utils {
+  // Right now, we do not support complex types in the grouping key schema.
+  private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = {
+    val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists {
+      case array: ArrayType => true
+      case map: MapType => true
+      case struct: StructType => true
+      case _ => false
+    }
+
+    !hasComplexTypes
+  }
+
+  private def doConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
+    case p: Aggregate if supportsGroupingKeySchema(p) =>
+      val converted = p.transformExpressionsDown {
+        case expressions.Average(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Average(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.Count(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Count(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        // We do not support multiple COUNT DISTINCT columns for now.
+        case expressions.CountDistinct(children) if children.length == 1 =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Count(children.head),
+            mode = aggregate.Complete,
+            isDistinct = true)
+
+        case expressions.First(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.First(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.Last(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Last(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.Max(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Max(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.Min(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Min(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.Sum(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Sum(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.SumDistinct(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Sum(child),
+            mode = aggregate.Complete,
+            isDistinct = true)
+      }
+      // Check if there is any expressions.AggregateExpression1 left.
+      // If so, we cannot convert this plan.
+      val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr =>
+        // For each expression, check if it contains AggregateExpression1.
+        expr.find {
+          case agg: expressions.AggregateExpression1 => true
+          case other => false
+        }.isDefined
+      }
+
+      // Check if there are multiple distinct columns.
+      val aggregateExpressions = converted.aggregateExpressions.flatMap { expr =>
+        expr.collect {
+          case agg: AggregateExpression2 => agg
+        }
+      }.toSet.toSeq
+      val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct)
+      val hasMultipleDistinctColumnSets =
+        if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
+          true
+        } else {
+          false
+        }
+
+      if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None
+
+    case other => None
+  }
+
+  def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = {
+    // If the plan cannot be converted, we will do a final round check to see if the original
+    // logical.Aggregate contains both AggregateExpression1 and AggregateExpression2. If so,
+    // we need to throw an exception.
+    val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr =>
+      expr.collect {
+        case agg: AggregateExpression2 => agg.aggregateFunction
+      }
+    }.distinct
+    if (aggregateFunction2s.nonEmpty) {
+      // For functions implemented based on the new interface, prepare a list of function names.
+      val invalidFunctions = {
+        if (aggregateFunction2s.length > 1) {
+          s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " +
+            s"and ${aggregateFunction2s.head.nodeName} are"
+        } else {
+          s"${aggregateFunction2s.head.nodeName} is"
+        }
+      }
+      val errorMessage =
+        s"${invalidFunctions} implemented based on the new Aggregate Function " +
+          s"interface and it cannot be used with functions implemented based on " +
+          s"the old Aggregate Function interface."
+      throw new AnalysisException(errorMessage)
+    }
+  }
+
+  def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
+    case p: Aggregate =>
+      val converted = doConvert(p)
+      if (converted.isDefined) {
+        converted
+      } else {
+        checkInvalidAggregateFunction2(p)
+        None
+      }
+    case other => None
+  }
+}
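As a hedged usage sketch (the function `chooseAggregatePath` is invented, not Spark planner code), a caller hands `tryConvert` a resolved plan and branches on the result; when old and new aggregate functions are mixed, the call throws an AnalysisException instead of returning None:

    import org.apache.spark.sql.catalyst.expressions.aggregate.Utils
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

    // Invented helper: decide which aggregation code path a plan can use.
    def chooseAggregatePath(plan: LogicalPlan): String =
      Utils.tryConvert(plan) match {
        case Some(converted) =>
          // All functions were rewritten to AggregateExpression2.
          s"new aggregation path: ${converted.simpleString}"
        case None =>
          "old aggregation path"
      }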
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index ad5af19578..a67f8de6b7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.Utils
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashSet
@@ -219,6 +220,8 @@ case class Aggregate(
expressions.forall(_.resolved) && childrenResolved && !hasWindowExpressions
}
+  lazy val newAggregation: Option[Aggregate] = Utils.tryConvert(this)
+
override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
}
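With the cached value in place, downstream checks read `newAggregation` instead of re-running the conversion: the first access runs `Utils.tryConvert` once, and every later access returns the memoized Option. A minimal hedged sketch (the helper name is invented):

    import org.apache.spark.sql.catalyst.plans.logical.Aggregate

    // Invented helper: repeated querying is cheap thanks to the lazy val.
    def canUseNewAggregation(agg: Aggregate): Boolean =
      agg.newAggregation.isDefined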