author     Wenchen Fan <wenchen@databricks.com>    2017-01-12 20:21:04 +0800
committer  Wenchen Fan <wenchen@databricks.com>    2017-01-12 20:21:04 +0800
commit     871d266649ddfed38c64dfda7158d8bb58d4b979 (patch)
tree       aec91eff39e31040e8a380430e20b4d31fdbc436 /sql
parent     c71b25481aa5f7bc27d5c979e66bed54cd46b97e (diff)
[SPARK-18969][SQL] Support grouping by nondeterministic expressions
## What changes were proposed in this pull request?

Currently, nondeterministic expressions are allowed in `Aggregate` (see the [comment](https://github.com/apache/spark/blob/v2.0.2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala#L249-L251)), but the `PullOutNondeterministic` analyzer rule failed to handle `Aggregate`. This PR fixes it.

Close https://github.com/apache/spark/pull/16379

One issue remains: `SELECT a + rand() FROM t GROUP BY a + rand()` is still not allowed, because the two `rand()` calls are different expressions (a random seed is generated as the default seed for each `rand()`). https://issues.apache.org/jira/browse/SPARK-19035 tracks this issue.

## How was this patch tested?

A new test suite.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #16404 from cloud-fan/groupby.
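For concreteness, a minimal sketch of what this change does and does not allow; the table name `t`, its rows, and the local session setup are assumptions for illustration, not part of the patch:

```scala
import org.apache.spark.sql.SparkSession

// Minimal sketch, assuming a local session and a hypothetical view `t`.
val spark = SparkSession.builder().master("local[*]").appName("groupby-rand").getOrCreate()
import spark.implicits._

Seq((1, 1), (1, 2), (2, 1)).toDF("a", "b").createOrReplaceTempView("t")

// Allowed after this patch: the ordinal 2 resolves to the same rand(0)
// expression in both the select list and the grouping expressions, and
// PullOutNondeterministic materializes it in a Project below the Aggregate.
spark.sql("SELECT a, rand(0), sum(b) FROM t GROUP BY a, 2").show()

// Still rejected (SPARK-19035): the two rand() calls receive different
// default seeds, so they are different expressions and the select list
// does not match the grouping expressions.
// spark.sql("SELECT a + rand() FROM t GROUP BY a + rand()").show()
```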
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 37
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministicSuite.scala | 56
-rw-r--r--  sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out | 10
3 files changed, 86 insertions, 17 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index d461531217..3c58832d34 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2008,28 +2008,37 @@ class Analyzer(
case p: Project => p
case f: Filter => f
+ case a: Aggregate if a.groupingExpressions.exists(!_.deterministic) =>
+ val nondeterToAttr = getNondeterToAttr(a.groupingExpressions)
+ val newChild = Project(a.child.output ++ nondeterToAttr.values, a.child)
+ a.transformExpressions { case e =>
+ nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e)
+ }.copy(child = newChild)
+
// todo: It's hard to write a general rule to pull out nondeterministic expressions
// from LogicalPlan, currently we only do it for UnaryNode which has same output
// schema with its child.
case p: UnaryNode if p.output == p.child.output && p.expressions.exists(!_.deterministic) =>
- val nondeterministicExprs = p.expressions.filterNot(_.deterministic).flatMap { expr =>
- val leafNondeterministic = expr.collect {
- case n: Nondeterministic => n
- }
- leafNondeterministic.map { e =>
- val ne = e match {
- case n: NamedExpression => n
- case _ => Alias(e, "_nondeterministic")(isGenerated = true)
- }
- new TreeNodeRef(e) -> ne
- }
- }.toMap
+ val nondeterToAttr = getNondeterToAttr(p.expressions)
val newPlan = p.transformExpressions { case e =>
- nondeterministicExprs.get(new TreeNodeRef(e)).map(_.toAttribute).getOrElse(e)
+ nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e)
}
- val newChild = Project(p.child.output ++ nondeterministicExprs.values, p.child)
+ val newChild = Project(p.child.output ++ nondeterToAttr.values, p.child)
Project(p.output, newPlan.withNewChildren(newChild :: Nil))
}
+
+ private def getNondeterToAttr(exprs: Seq[Expression]): Map[Expression, NamedExpression] = {
+ exprs.filterNot(_.deterministic).flatMap { expr =>
+ val leafNondeterministic = expr.collect { case n: Nondeterministic => n }
+ leafNondeterministic.distinct.map { e =>
+ val ne = e match {
+ case n: NamedExpression => n
+ case _ => Alias(e, "_nondeterministic")(isGenerated = true)
+ }
+ e -> ne
+ }
+ }.toMap
+ }
}
/**
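In plan terms, the new `Aggregate` case rewrites grouping on a nondeterministic expression into grouping on an attribute produced by a `Project` below the aggregate, so the expression is evaluated exactly once per input row. A minimal sketch using the Catalyst test DSL, mirroring the suite added below:

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Rand
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val a = 'a.int
val b = 'b.int
val r = LocalRelation(a, b)

// Input plan: an Aggregate with a nondeterministic grouping expression.
val before = r.groupBy(Rand(10))(Rand(10).as("rnd"))

// After PullOutNondeterministic: rand(10) is aliased in a Project under
// the Aggregate, and both the grouping and aggregate expressions are
// rewritten to reference the resulting attribute.
val rnd = Rand(10).as('_nondeterministic)
val after = r.select(a, b, rnd).groupBy(rnd.toAttribute)(rnd.toAttribute.as("rnd"))
```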
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministicSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministicSuite.scala
new file mode 100644
index 0000000000..72e10eadf7
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministicSuite.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+
+/**
+ * Test suite for moving non-deterministic expressions into Project.
+ */
+class PullOutNondeterministicSuite extends AnalysisTest {
+
+ private lazy val a = 'a.int
+ private lazy val b = 'b.int
+ private lazy val r = LocalRelation(a, b)
+ private lazy val rnd = Rand(10).as('_nondeterministic)
+ private lazy val rndref = rnd.toAttribute
+
+ test("no-op on filter") {
+ checkAnalysis(
+ r.where(Rand(10) > Literal(1.0)),
+ r.where(Rand(10) > Literal(1.0))
+ )
+ }
+
+ test("sort") {
+ checkAnalysis(
+ r.sortBy(SortOrder(Rand(10), Ascending)),
+ r.select(a, b, rnd).sortBy(SortOrder(rndref, Ascending)).select(a, b)
+ )
+ }
+
+ test("aggregate") {
+ checkAnalysis(
+ r.groupBy(Rand(10))(Rand(10).as("rnd")),
+ r.select(a, b, rnd).groupBy(rndref)(rndref.as("rnd"))
+ )
+ }
+}
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out
index 9c3a145f3a..c64520ff93 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out
@@ -137,10 +137,14 @@ GROUP BY position 3 is an aggregate function, and aggregate functions are not al
-- !query 13
select a, rand(0), sum(b) from data group by a, 2
-- !query 13 schema
-struct<>
+struct<a:int,rand(0):double,sum(b):bigint>
-- !query 13 output
-org.apache.spark.sql.AnalysisException
-nondeterministic expression rand(0) should not appear in grouping expression.;
+1 0.4048454303385226 2
+1 0.8446490682263027 1
+2 0.5871875724155838 1
+2 0.8865128837019473 2
+3 0.742083829230211 1
+3 0.9179913208300406 2
-- !query 14
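The new golden output is stable because `rand(0)` is explicitly seeded. A sketch of reproducing it; the rows of the `data` view are inferred here from the golden output above, and the authoritative definition lives in the corresponding `sql-tests/inputs/group-by-ordinal.sql`:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// Rows assumed from the golden output above.
Seq((1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2))
  .toDF("a", "b")
  .createOrReplaceTempView("data")

// Each row lands in its own group, since rand(0) differs per row; the fixed
// seed makes the per-row values, and thus this output, reproducible.
spark.sql("select a, rand(0), sum(b) from data group by a, 2").show()
```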