author    Yin Huai <yhuai@databricks.com>    2015-06-02 00:20:52 -0700
committer Reynold Xin <rxin@databricks.com>  2015-06-02 00:20:52 -0700
commit    0f80990bfac1e9969644952d1d8edaf7d26fb436 (patch)
tree      726c7e11b17ecfbb059a75de1bd6d207007300d0 /sql
parent    7b7f7b6c6fd903e2ecfc886d29eaa9df58adcfc3 (diff)
[SPARK-8023][SQL] Add "deterministic" attribute to Expression to avoid collapsing nondeterministic projects.
This closes #6570.

Author: Yin Huai <yhuai@databricks.com>
Author: Reynold Xin <rxin@databricks.com>

Closes #6573 from rxin/deterministic and squashes the following commits:

356cd22 [Reynold Xin] Added unit test for the optimizer.
da3fde1 [Reynold Xin] Merge pull request #6570 from yhuai/SPARK-8023
da56200 [Yin Huai] Comments.
e38f264 [Yin Huai] Comment.
f9d6a73 [Yin Huai] Add a deterministic method to Expression.
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala            8
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala                2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala              11
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala  73
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala                          41
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala                                   4
6 files changed, 137 insertions(+), 2 deletions(-)
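Why the guard matters: collapsing two Projects inlines the child's aliased expressions into the parent, duplicating each aliased expression once per reference. That is safe for deterministic expressions, but for something like Rand every inlined copy draws a fresh value, silently changing query results. A plain-Scala illustration of the hazard (standalone sketch, not Spark code):

    import scala.util.Random

    object CollapseHazard {
      def main(args: Array[String]): Unit = {
        val rng = new Random(5L)

        // Uncollapsed plan: the child Project evaluates `rand` once,
        // and the parent's two expressions reuse that single value.
        val rand = rng.nextDouble()
        val (rand1, rand2) = (rand + 1, rand - 1)
        println(rand1 - rand2) // always exactly 2.0

        // A (wrong) collapse inlines the generator into each use site,
        // so the two references draw different random values.
        val (bad1, bad2) = (rng.nextDouble() + 1, rng.nextDouble() - 1)
        println(bad1 - bad2) // almost never 2.0
      }
    }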
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index d199287844..adc6505d69 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -37,7 +37,15 @@ abstract class Expression extends TreeNode[Expression] {
* - A [[Cast]] or [[UnaryMinus]] is foldable if its child is foldable
*/
def foldable: Boolean = false
+
+ /**
+ * Returns true when the current expression always returns the same result for fixed input values.
+ */
+ // TODO: Need to define explicit input values vs implicit input values.
+ def deterministic: Boolean = true
+
def nullable: Boolean
+
def references: AttributeSet = AttributeSet(children.flatMap(_.references.iterator))
/** Returns the result of evaluating this expression on a given input Row */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
index 4f4f67a6e4..b2647124c4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
@@ -38,6 +38,8 @@ abstract class RDG(seed: Long) extends LeafExpression with Serializable {
*/
@transient protected lazy val rng = new XORShiftRandom(seed + TaskContext.get().partitionId())
+ override def deterministic: Boolean = false
+
override def nullable: Boolean = false
override def dataType: DataType = DoubleType
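RDG (the parent of Rand and Randn) is the first expression to opt out. A third-party expression with hidden state would do the same; a minimal sketch against the 1.4-era catalyst API shown above (hypothetical class; eval's signature changed in later Spark releases):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.expressions.LeafExpression
    import org.apache.spark.sql.types.{DataType, LongType}

    // The value depends on mutable state rather than the input row, so the
    // expression must report deterministic = false or ProjectCollapsing
    // could duplicate it and change results.
    case class EvalCounter() extends LeafExpression {
      private var count = -1L

      override def deterministic: Boolean = false
      override def nullable: Boolean = false
      override def dataType: DataType = LongType

      override def eval(input: Row): Any = {
        count += 1
        count
      }
    }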
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index c2818d957c..b25fb48f55 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -179,8 +179,17 @@ object ColumnPruning extends Rule[LogicalPlan] {
* expressions into one single expression.
*/
object ProjectCollapsing extends Rule[LogicalPlan] {
+
+ /** Returns true if any expression in projectList is non-deterministic. */
+ private def hasNondeterministic(projectList: Seq[NamedExpression]): Boolean = {
+ projectList.exists(expr => expr.find(!_.deterministic).isDefined)
+ }
+
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
- case Project(projectList1, Project(projectList2, child)) =>
+ // We only collapse these two Projects if the child Project's expressions are all
+ // deterministic.
+ case Project(projectList1, Project(projectList2, child))
+ if !hasNondeterministic(projectList2) =>
// Create a map of Aliases to their values from the child projection.
// e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)).
val aliasMap = AttributeMap(projectList2.collect {
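Note that the guard relies on TreeNode.find, which walks the entire subtree: a nondeterministic leaf buried under deterministic parents still blocks the collapse. A hypothetical REPL illustration:

    import org.apache.spark.sql.catalyst.expressions.{Add, Alias, Literal, Rand}

    // Alias and Add are deterministic, but the Rand leaf below them is not,
    // so the projection as a whole counts as nondeterministic.
    val e = Alias(Add(Rand(10), Literal(1)), "r")()
    e.find(!_.deterministic) // Some(Rand(10)) => hasNondeterministic == true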
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala
new file mode 100644
index 0000000000..151654bffb
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.Rand
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+
+
+class ProjectCollapsingSuite extends PlanTest {
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("Subqueries", FixedPoint(10), EliminateSubQueries) ::
+ Batch("ProjectCollapsing", Once, ProjectCollapsing) :: Nil
+ }
+
+ val testRelation = LocalRelation('a.int, 'b.int)
+
+ test("collapse two deterministic, independent projects into one") {
+ val query = testRelation
+ .select(('a + 1).as('a_plus_1), 'b)
+ .select('a_plus_1, ('b + 1).as('b_plus_1))
+
+ val optimized = Optimize.execute(query.analyze)
+ val correctAnswer = testRelation.select(('a + 1).as('a_plus_1), ('b + 1).as('b_plus_1)).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("collapse two deterministic, dependent projects into one") {
+ val query = testRelation
+ .select(('a + 1).as('a_plus_1), 'b)
+ .select(('a_plus_1 + 1).as('a_plus_2), 'b)
+
+ val optimized = Optimize.execute(query.analyze)
+
+ val correctAnswer = testRelation.select(
+ (('a + 1).as('a_plus_1) + 1).as('a_plus_2),
+ 'b).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("do not collapse nondeterministic projects") {
+ val query = testRelation
+ .select(Rand(10).as('rand))
+ .select(('rand + 1).as('rand1), ('rand + 2).as('rand2))
+
+ val optimized = Optimize.execute(query.analyze)
+ val correctAnswer = query.analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index b8bb1bff9e..bfba379d9a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
import org.scalatest.Matchers._
+import org.apache.spark.sql.execution.Project
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.implicits._
@@ -452,13 +453,51 @@ class ColumnExpressionSuite extends QueryTest {
}
test("rand") {
- val randCol = testData.select('key, rand(5L).as("rand"))
+ val randCol = testData.select($"key", rand(5L).as("rand"))
randCol.columns.length should be (2)
val rows = randCol.collect()
rows.foreach { row =>
assert(row.getDouble(1) <= 1.0)
assert(row.getDouble(1) >= 0.0)
}
+
+ def checkNumProjects(df: DataFrame, expectedNumProjects: Int): Unit = {
+ val projects = df.queryExecution.executedPlan.collect {
+ case project: Project => project
+ }
+ assert(projects.size === expectedNumProjects)
+ }
+
+ // We first create a plan with two Projects.
+ // Project [rand + 1 AS rand1, rand - 1 AS rand2]
+ // Project [key, (Rand 5 + 1) AS rand]
+ // LogicalRDD [key, value]
+ // Because the Rand function is not deterministic, the column rand is not deterministic.
+ // So, in the optimizer, we will not collapse Project [rand + 1 AS rand1, rand - 1 AS rand2]
+ // and Project [key, (Rand 5 + 1) AS rand]. The final plan still has two Projects.
+ val dfWithTwoProjects =
+ testData
+ .select($"key", (rand(5L) + 1).as("rand"))
+ .select(($"rand" + 1).as("rand1"), ($"rand" - 1).as("rand2"))
+ checkNumProjects(dfWithTwoProjects, 2)
+
+ // Now, we add one more Project, rand1 - rand2, on top of the query plan.
+ // Since the expressions defining rand1 and rand2 are deterministic (they just
+ // apply +/- 1 to the already-generated rand column), the new Project can be
+ // collapsed into the Project generating rand1 and rand2.
+ // So, the plan will be optimized from ...
+ // Project [(rand1 - rand2) AS (rand1 - rand2)]
+ // Project [rand + 1 AS rand1, rand - 1 AS rand2]
+ // Project [key, (Rand 5 + 1) AS rand]
+ // LogicalRDD [key, value]
+ // to ...
+ // Project [((rand + 1 AS rand1) - (rand - 1 AS rand2)) AS (rand1 - rand2)]
+ // Project [key, (Rand 5 + 1) AS rand]
+ // LogicalRDD [key, value]
+ val dfWithThreeProjects = dfWithTwoProjects.select($"rand1" - $"rand2")
+ checkNumProjects(dfWithThreeProjects, 2)
+ dfWithThreeProjects.collect().foreach { row =>
+ assert(row.getDouble(0) === 2.0 +- 0.0001)
+ }
}
test("randn") {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index 64a49c83cb..1658bb93b0 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -78,6 +78,8 @@ private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, childre
type UDFType = UDF
+ override def deterministic: Boolean = isUDFDeterministic
+
override def nullable: Boolean = true
@transient
@@ -140,6 +142,8 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr
extends Expression with HiveInspectors with Logging {
type UDFType = GenericUDF
+ override def deterministic: Boolean = isUDFDeterministic
+
override def nullable: Boolean = true
@transient
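isUDFDeterministic (defined elsewhere in this file) is derived from Hive's @UDFType annotation on the UDF class, so Hive UDFs that already declare themselves nondeterministic now get the same protection. A sketch of such a UDF (hypothetical class):

    import org.apache.hadoop.hive.ql.exec.UDF
    import org.apache.hadoop.hive.ql.udf.UDFType

    // Hive marks nondeterministic functions with @UDFType(deterministic = false).
    // With this patch, HiveSimpleUdf surfaces that flag through `deterministic`,
    // so ProjectCollapsing will not duplicate calls to the UDF.
    @UDFType(deterministic = false)
    class CurrentEpochMillis extends UDF {
      def evaluate(): Long = System.currentTimeMillis()
    }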