diff options
author | Wenchen Fan <wenchen@databricks.com> | 2017-01-12 20:21:04 +0800 |
---|---|---|
committer | Wenchen Fan <wenchen@databricks.com> | 2017-01-12 20:21:04 +0800 |
commit | 871d266649ddfed38c64dfda7158d8bb58d4b979 (patch) | |
tree | aec91eff39e31040e8a380430e20b4d31fdbc436 /sql | |
parent | c71b25481aa5f7bc27d5c979e66bed54cd46b97e (diff) | |
download | spark-871d266649ddfed38c64dfda7158d8bb58d4b979.tar.gz spark-871d266649ddfed38c64dfda7158d8bb58d4b979.tar.bz2 spark-871d266649ddfed38c64dfda7158d8bb58d4b979.zip |
[SPARK-18969][SQL] Support grouping by nondeterministic expressions
## What changes were proposed in this pull request?
Currently nondeterministic expressions are allowed in `Aggregate`(see the [comment](https://github.com/apache/spark/blob/v2.0.2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala#L249-L251)), but the `PullOutNondeterministic` analyzer rule failed to handle `Aggregate`, this PR fixes it.
close https://github.com/apache/spark/pull/16379
There is still one remaining issue: `SELECT a + rand() FROM t GROUP BY a + rand()` is not allowed, because the 2 `rand()` are different(we generate random seed as the default seed for `rand()`). https://issues.apache.org/jira/browse/SPARK-19035 is tracking this issue.
## How was this patch tested?
a new test suite
Author: Wenchen Fan <wenchen@databricks.com>
Closes #16404 from cloud-fan/groupby.
Diffstat (limited to 'sql')
3 files changed, 86 insertions, 17 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d461531217..3c58832d34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2008,28 +2008,37 @@ class Analyzer( case p: Project => p case f: Filter => f + case a: Aggregate if a.groupingExpressions.exists(!_.deterministic) => + val nondeterToAttr = getNondeterToAttr(a.groupingExpressions) + val newChild = Project(a.child.output ++ nondeterToAttr.values, a.child) + a.transformExpressions { case e => + nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e) + }.copy(child = newChild) + // todo: It's hard to write a general rule to pull out nondeterministic expressions // from LogicalPlan, currently we only do it for UnaryNode which has same output // schema with its child. case p: UnaryNode if p.output == p.child.output && p.expressions.exists(!_.deterministic) => - val nondeterministicExprs = p.expressions.filterNot(_.deterministic).flatMap { expr => - val leafNondeterministic = expr.collect { - case n: Nondeterministic => n - } - leafNondeterministic.map { e => - val ne = e match { - case n: NamedExpression => n - case _ => Alias(e, "_nondeterministic")(isGenerated = true) - } - new TreeNodeRef(e) -> ne - } - }.toMap + val nondeterToAttr = getNondeterToAttr(p.expressions) val newPlan = p.transformExpressions { case e => - nondeterministicExprs.get(new TreeNodeRef(e)).map(_.toAttribute).getOrElse(e) + nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e) } - val newChild = Project(p.child.output ++ nondeterministicExprs.values, p.child) + val newChild = Project(p.child.output ++ nondeterToAttr.values, p.child) Project(p.output, newPlan.withNewChildren(newChild :: Nil)) } + + private def getNondeterToAttr(exprs: Seq[Expression]): Map[Expression, NamedExpression] = { + exprs.filterNot(_.deterministic).flatMap { expr => + val leafNondeterministic = expr.collect { case n: Nondeterministic => n } + leafNondeterministic.distinct.map { e => + val ne = e match { + case n: NamedExpression => n + case _ => Alias(e, "_nondeterministic")(isGenerated = true) + } + e -> ne + } + }.toMap + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministicSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministicSuite.scala new file mode 100644 index 0000000000..72e10eadf7 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministicSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation + +/** + * Test suite for moving non-deterministic expressions into Project. + */ +class PullOutNondeterministicSuite extends AnalysisTest { + + private lazy val a = 'a.int + private lazy val b = 'b.int + private lazy val r = LocalRelation(a, b) + private lazy val rnd = Rand(10).as('_nondeterministic) + private lazy val rndref = rnd.toAttribute + + test("no-op on filter") { + checkAnalysis( + r.where(Rand(10) > Literal(1.0)), + r.where(Rand(10) > Literal(1.0)) + ) + } + + test("sort") { + checkAnalysis( + r.sortBy(SortOrder(Rand(10), Ascending)), + r.select(a, b, rnd).sortBy(SortOrder(rndref, Ascending)).select(a, b) + ) + } + + test("aggregate") { + checkAnalysis( + r.groupBy(Rand(10))(Rand(10).as("rnd")), + r.select(a, b, rnd).groupBy(rndref)(rndref.as("rnd")) + ) + } +} diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out index 9c3a145f3a..c64520ff93 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out @@ -137,10 +137,14 @@ GROUP BY position 3 is an aggregate function, and aggregate functions are not al -- !query 13 select a, rand(0), sum(b) from data group by a, 2 -- !query 13 schema -struct<> +struct<a:int,rand(0):double,sum(b):bigint> -- !query 13 output -org.apache.spark.sql.AnalysisException -nondeterministic expression rand(0) should not appear in grouping expression.; +1 0.4048454303385226 2 +1 0.8446490682263027 1 +2 0.5871875724155838 1 +2 0.8865128837019473 2 +3 0.742083829230211 1 +3 0.9179913208300406 2 -- !query 14 |