diff options
author | Wenchen Fan <cloud0fan@outlook.com> | 2015-09-08 12:05:41 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2015-09-08 12:05:41 -0700 |
commit | 5fd57955ef477347408f68eb1cb6ad1881fdb6e0 (patch) | |
tree | 5680c9e6446883bd2cdb90ff5a0e67be1d3b4681 | |
parent | 5b2192e846b843d8a0cb9427d19bb677431194a0 (diff) | |
download | spark-5fd57955ef477347408f68eb1cb6ad1881fdb6e0.tar.gz spark-5fd57955ef477347408f68eb1cb6ad1881fdb6e0.tar.bz2 spark-5fd57955ef477347408f68eb1cb6ad1881fdb6e0.zip |
[SPARK-10316] [SQL] respect nondeterministic expressions in PhysicalOperation
We do a lot of special handling for non-deterministic expressions in `Optimizer`. However, `PhysicalOperation` just collects all Projects and Filters and messes up the operator order that the optimizer preserved. We should respect the operator order imposed by non-deterministic expressions in `PhysicalOperation`.
Author: Wenchen Fan <cloud0fan@outlook.com>
Closes #8486 from cloud-fan/fix.
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala | 38 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 12 |
2 files changed, 20 insertions, 30 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index e8abcd63f7..5353779951 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.planning -import scala.annotation.tailrec - import org.apache.spark.Logging import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.catalyst.expressions._ @@ -26,27 +24,6 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ /** - * A pattern that matches any number of filter operations on top of another relational operator. - * Adjacent filter operators are collected and their conditions are broken up and returned as a - * sequence of conjunctive predicates. - * - * @return A tuple containing a sequence of conjunctive predicates that should be used to filter the - * output and a relational operator. - */ -object FilteredOperation extends PredicateHelper { - type ReturnType = (Seq[Expression], LogicalPlan) - - def unapply(plan: LogicalPlan): Option[ReturnType] = Some(collectFilters(Nil, plan)) - - @tailrec - private def collectFilters(filters: Seq[Expression], plan: LogicalPlan): ReturnType = plan match { - case Filter(condition, child) => - collectFilters(filters ++ splitConjunctivePredicates(condition), child) - case other => (filters, other) - } -} - -/** * A pattern that matches any number of project or filter operations on top of another relational * operator. All filter operators are collected and their conditions are broken up and returned * together with the top project operator. @@ -62,8 +39,9 @@ object PhysicalOperation extends PredicateHelper { } /** - * Collects projects and filters, in-lining/substituting aliases if necessary. 
Here are two - * examples for alias in-lining/substitution. Before: + * Collects all deterministic projects and filters, in-lining/substituting aliases if necessary. + * Here are two examples for alias in-lining/substitution. + * Before: * {{{ * SELECT c1 FROM (SELECT key AS c1 FROM t1) t2 WHERE c1 > 10 * SELECT c1 AS c2 FROM (SELECT key AS c1 FROM t1) t2 WHERE c1 > 10 @@ -74,15 +52,15 @@ object PhysicalOperation extends PredicateHelper { * SELECT key AS c2 FROM t1 WHERE key > 10 * }}} */ - def collectProjectsAndFilters(plan: LogicalPlan): + private def collectProjectsAndFilters(plan: LogicalPlan): (Option[Seq[NamedExpression]], Seq[Expression], LogicalPlan, Map[Attribute, Expression]) = plan match { - case Project(fields, child) => + case Project(fields, child) if fields.forall(_.deterministic) => val (_, filters, other, aliases) = collectProjectsAndFilters(child) val substitutedFields = fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]] (Some(substitutedFields), filters, other, collectAliases(substitutedFields)) - case Filter(condition, child) => + case Filter(condition, child) if condition.deterministic => val (fields, filters, other, aliases) = collectProjectsAndFilters(child) val substitutedCondition = substitute(aliases)(condition) (fields, filters ++ splitConjunctivePredicates(substitutedCondition), other, aliases) @@ -91,11 +69,11 @@ object PhysicalOperation extends PredicateHelper { (None, Nil, other, Map.empty) } - def collectAliases(fields: Seq[Expression]): Map[Attribute, Expression] = fields.collect { + private def collectAliases(fields: Seq[Expression]): Map[Attribute, Expression] = fields.collect { case a @ Alias(child, _) => a.toAttribute -> child }.toMap - def substitute(aliases: Map[Attribute, Expression])(expr: Expression): Expression = { + private def substitute(aliases: Map[Attribute, Expression])(expr: Expression): Expression = { expr.transform { case a @ Alias(ref: AttributeReference, name) => aliases.get(ref).map(Alias(_, 
name)(a.exprId, a.qualifiers)).getOrElse(a) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index b5b9f11785..dbed4fc247 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -22,6 +22,8 @@ import java.io.File import scala.language.postfixOps import scala.util.Random +import org.scalatest.Matchers._ + import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -895,4 +897,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { .orderBy(sum('j)) checkAnswer(query, Row(1, 2)) } + + test("SPARK-10316: respect non-deterministic expressions in PhysicalOperation") { + val input = sqlContext.read.json(sqlContext.sparkContext.makeRDD( + (1 to 10).map(i => s"""{"id": $i}"""))) + + val df = input.select($"id", rand(0).as('r)) + df.as("a").join(df.filter($"r" < 0.5).as("b"), $"a.id" === $"b.id").collect().foreach { row => + assert(row.getDouble(1) - row.getDouble(3) === 0.0 +- 0.001) + } + } } |