aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichael Armbrust <michael@databricks.com>2015-08-24 18:10:51 -0700
committerMichael Armbrust <michael@databricks.com>2015-08-24 18:11:04 -0700
commit228e429ebf1f367de9087f74cf3ff43bbd32f382 (patch)
tree0b254ce2cc0025c9edefa20c255f920af41b90a8
parent8ca8bdd015c53ff0c4705886545fc30eef8b8359 (diff)
downloadspark-228e429ebf1f367de9087f74cf3ff43bbd32f382.tar.gz
spark-228e429ebf1f367de9087f74cf3ff43bbd32f382.tar.bz2
spark-228e429ebf1f367de9087f74cf3ff43bbd32f382.zip
[SPARK-10165] [SQL] Await child resolution in ResolveFunctions
Currently, we eagerly attempt to resolve functions, even before their children are resolved. However, this is not valid in cases where we need to know the types of the input arguments (i.e. when resolving Hive UDFs). As a fix, this PR delays function resolution until the functions children are resolved. This change also necessitates a change to the way we resolve aggregate expressions that are not in aggregate operators (e.g., in `HAVING` or `ORDER BY` clauses). Specifically, we can't assume that these misplaced functions will be resolved, allowing us to differentiate aggregate functions from normal functions. To compensate for this change we now attempt to resolve these unresolved expressions in the context of the aggregate operator, before checking to see if any aggregate expressions are present. Author: Michael Armbrust <michael@databricks.com> Closes #8371 from marmbrus/hiveUDFResolution. (cherry picked from commit 2bf338c626e9d97ccc033cfadae8b36a82c66fd1) Signed-off-by: Michael Armbrust <michael@databricks.com>
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala116
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala5
2 files changed, 77 insertions, 44 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index d0eb9c2c90..1a5de15c61 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -78,7 +78,7 @@ class Analyzer(
ResolveAliases ::
ExtractWindowExpressions ::
GlobalAggregates ::
- UnresolvedHavingClauseAttributes ::
+ ResolveAggregateFunctions ::
HiveTypeCoercion.typeCoercionRules ++
extendedResolutionRules : _*),
Batch("Nondeterministic", Once,
@@ -452,37 +452,6 @@ class Analyzer(
logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}")
s // Nothing we can do here. Return original plan.
}
- case s @ Sort(ordering, global, a @ Aggregate(grouping, aggs, child))
- if !s.resolved && a.resolved =>
- // A small hack to create an object that will allow us to resolve any references that
- // refer to named expressions that are present in the grouping expressions.
- val groupingRelation = LocalRelation(
- grouping.collect { case ne: NamedExpression => ne.toAttribute }
- )
-
- // Find sort attributes that are projected away so we can temporarily add them back in.
- val (newOrdering, missingAttr) = resolveAndFindMissing(ordering, a, groupingRelation)
-
- // Find aggregate expressions and evaluate them early, since they can't be evaluated in a
- // Sort.
- val (withAggsRemoved, aliasedAggregateList) = newOrdering.map {
- case aggOrdering if aggOrdering.collect { case a: AggregateExpression => a }.nonEmpty =>
- val aliased = Alias(aggOrdering.child, "_aggOrdering")()
- (aggOrdering.copy(child = aliased.toAttribute), Some(aliased))
-
- case other => (other, None)
- }.unzip
-
- val missing = missingAttr ++ aliasedAggregateList.flatten
-
- if (missing.nonEmpty) {
- // Add missing grouping exprs and then project them away after the sort.
- Project(a.output,
- Sort(withAggsRemoved, global,
- Aggregate(grouping, aggs ++ missing, child)))
- } else {
- s // Nothing we can do here. Return original plan.
- }
}
/**
@@ -515,6 +484,7 @@ class Analyzer(
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case q: LogicalPlan =>
q transformExpressions {
+ case u if !u.childrenResolved => u // Skip until children are resolved.
case u @ UnresolvedFunction(name, children, isDistinct) =>
withPosition(u) {
registry.lookupFunction(name, children) match {
@@ -559,21 +529,79 @@ class Analyzer(
}
/**
- * This rule finds expressions in HAVING clause filters that depend on
- * unresolved attributes. It pushes these expressions down to the underlying
- * aggregates and then projects them away above the filter.
+ * This rule finds aggregate expressions that are not in an aggregate operator. For example,
+ * those in a HAVING clause or ORDER BY clause. These expressions are pushed down to the
+ * underlying aggregate operator and then projected away after the original operator.
*/
- object UnresolvedHavingClauseAttributes extends Rule[LogicalPlan] {
+ object ResolveAggregateFunctions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
- case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _))
- if aggregate.resolved && containsAggregate(havingCondition) =>
-
- val evaluatedCondition = Alias(havingCondition, "havingCondition")()
- val aggExprsWithHaving = evaluatedCondition +: originalAggExprs
+ case filter @ Filter(havingCondition,
+ aggregate @ Aggregate(grouping, originalAggExprs, child))
+ if aggregate.resolved && !filter.resolved =>
+
+ // Try resolving the condition of the filter as though it is in the aggregate clause
+ val aggregatedCondition =
+ Aggregate(grouping, Alias(havingCondition, "havingCondition")() :: Nil, child)
+ val resolvedOperator = execute(aggregatedCondition)
+ def resolvedAggregateFilter =
+ resolvedOperator
+ .asInstanceOf[Aggregate]
+ .aggregateExpressions.head
+
+ // If resolution was successful and we see the filter has an aggregate in it, add it to
+ // the original aggregate operator.
+ if (resolvedOperator.resolved && containsAggregate(resolvedAggregateFilter)) {
+ val aggExprsWithHaving = resolvedAggregateFilter +: originalAggExprs
+
+ Project(aggregate.output,
+ Filter(resolvedAggregateFilter.toAttribute,
+ aggregate.copy(aggregateExpressions = aggExprsWithHaving)))
+ } else {
+ filter
+ }
- Project(aggregate.output,
- Filter(evaluatedCondition.toAttribute,
- aggregate.copy(aggregateExpressions = aggExprsWithHaving)))
+ case sort @ Sort(sortOrder, global,
+ aggregate @ Aggregate(grouping, originalAggExprs, child))
+ if aggregate.resolved && !sort.resolved =>
+
+ // Try resolving the ordering as though it is in the aggregate clause.
+ try {
+ val aliasedOrder = sortOrder.map(o => Alias(o.child, "aggOrder")())
+ val aggregatedOrdering = Aggregate(grouping, aliasedOrder, child)
+ val resolvedOperator: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate]
+ def resolvedAggregateOrdering = resolvedOperator.aggregateExpressions
+
+ // Expressions that have an aggregate can be pushed down.
+ val needsAggregate = resolvedAggregateOrdering.exists(containsAggregate)
+
+ // Attribute references, that are missing from the order but are present in the grouping
+ // expressions can also be pushed down.
+ val requiredAttributes = resolvedAggregateOrdering.map(_.references).reduce(_ ++ _)
+ val missingAttributes = requiredAttributes -- aggregate.outputSet
+ val validPushdownAttributes =
+ missingAttributes.filter(a => grouping.exists(a.semanticEquals))
+
+ // If resolution was successful and we see the ordering either has an aggregate in it or
+ // it is missing something that is projected away by the aggregate, add the ordering
+ // the original aggregate operator.
+ if (resolvedOperator.resolved && (needsAggregate || validPushdownAttributes.nonEmpty)) {
+ val evaluatedOrderings: Seq[SortOrder] = sortOrder.zip(resolvedAggregateOrdering).map {
+ case (order, evaluated) => order.copy(child = evaluated.toAttribute)
+ }
+ val aggExprsWithOrdering: Seq[NamedExpression] =
+ resolvedAggregateOrdering ++ originalAggExprs
+
+ Project(aggregate.output,
+ Sort(evaluatedOrderings, global,
+ aggregate.copy(aggregateExpressions = aggExprsWithOrdering)))
+ } else {
+ sort
+ }
+ } catch {
+ // Attempting to resolve in the aggregate can result in ambiguity. When this happens,
+ // just return the original plan.
+ case ae: AnalysisException => sort
+ }
}
protected def containsAggregate(condition: Expression): Boolean = {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index 10f2902e5e..b03a351323 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -276,6 +276,11 @@ class HiveUDFSuite extends QueryTest {
checkAnswer(
sql("SELECT testStringStringUDF(\"hello\", s) FROM stringTable"),
Seq(Row("hello world"), Row("hello goodbye")))
+
+ checkAnswer(
+ sql("SELECT testStringStringUDF(\"\", testStringStringUDF(\"hello\", s)) FROM stringTable"),
+ Seq(Row(" hello world"), Row(" hello goodbye")))
+
sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUDF")
TestHive.reset()