author    Wenchen Fan <cloud0fan@163.com>    2015-09-22 12:14:15 -0700
committer Yin Huai <yhuai@databricks.com>    2015-09-22 12:14:59 -0700
commit    5017c685f484ec256101d1d33bad11d9e0c0f641 (patch)
tree      ebdcc32f0fa8edf7fd55d73d1d46dd0711cea2e3 /sql
parent    1ca5e2e0b8d8d406c02a74c76ae9d7fc5637c8d3 (diff)
[SPARK-10740] [SQL] handle nondeterministic expressions correctly for set operations
https://issues.apache.org/jira/browse/SPARK-10740

Author: Wenchen Fan <cloud0fan@163.com>

Closes #8858 from cloud-fan/non-deter.
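For context: a nondeterministic expression such as rand() produces a different value on every evaluation, so duplicating a filter that contains one into both children of a Union (as the optimizer did before this patch) draws from two independent random streams and can change the query result. A minimal sketch of the hazard, assuming a SQLContext named sqlContext and the 1.5-era DataFrame API (unionAll):

    import org.apache.spark.sql.functions.rand
    import sqlContext.implicits._

    val df1 = (1 to 20).map(Tuple1.apply).toDF("i")
    val df2 = (1 to 10).map(Tuple1.apply).toDF("i")

    // Before this patch the filter below was pushed beneath the Union, so
    // rand(7) was evaluated once per branch rather than once over the
    // unioned rows, producing a different (and incorrect) set of results.
    val suspect = df1.unionAll(df2).filter('i < rand(7) * 10)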
Diffstat (limited to 'sql')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala               | 69
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala |  3
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala                                 | 41
3 files changed, 93 insertions(+), 20 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 324f40a051..63602eaa8c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -95,14 +95,14 @@ object SamplePushDown extends Rule[LogicalPlan] {
* Intersect:
* It is not safe to pushdown Projections through it because we need to get the
* intersect of rows by comparing the entire rows. It is fine to pushdown Filters
- * because we will not have non-deterministic expressions.
+ * with deterministic conditions.
*
* Except:
* It is not safe to pushdown Projections through it because we need to get the
* intersect of rows by comparing the entire rows. It is fine to pushdown Filters
- * because we will not have non-deterministic expressions.
+ * with deterministic conditions.
*/
-object SetOperationPushDown extends Rule[LogicalPlan] {
+object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
/**
* Maps Attributes from the left side to the corresponding Attribute on the right side.
@@ -129,34 +129,65 @@ object SetOperationPushDown extends Rule[LogicalPlan] {
result.asInstanceOf[A]
}
+ /**
+ * Splits the condition expression into individual conjuncts by `And`, partitions them by
+ * determinism, and recombines each group with `And`. It returns an expression containing
+ * all deterministic conjuncts (the first field of the returned Tuple2) and an expression
+ * containing all nondeterministic conjuncts (the second field of the returned Tuple2).
+ */
+ private def partitionByDeterministic(condition: Expression): (Expression, Expression) = {
+ val andConditions = splitConjunctivePredicates(condition)
+ andConditions.partition(_.deterministic) match {
+ case (deterministic, nondeterministic) =>
+ deterministic.reduceOption(And).getOrElse(Literal(true)) ->
+ nondeterministic.reduceOption(And).getOrElse(Literal(true))
+ }
+ }
+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// Push down filter into union
case Filter(condition, u @ Union(left, right)) =>
+ val (deterministic, nondeterministic) = partitionByDeterministic(condition)
val rewrites = buildRewrites(u)
- Union(
- Filter(condition, left),
- Filter(pushToRight(condition, rewrites), right))
-
- // Push down projection through UNION ALL
- case Project(projectList, u @ Union(left, right)) =>
- val rewrites = buildRewrites(u)
- Union(
- Project(projectList, left),
- Project(projectList.map(pushToRight(_, rewrites)), right))
+ Filter(nondeterministic,
+ Union(
+ Filter(deterministic, left),
+ Filter(pushToRight(deterministic, rewrites), right)
+ )
+ )
+
+ // Push down deterministic projection through UNION ALL
+ case p @ Project(projectList, u @ Union(left, right)) =>
+ if (projectList.forall(_.deterministic)) {
+ val rewrites = buildRewrites(u)
+ Union(
+ Project(projectList, left),
+ Project(projectList.map(pushToRight(_, rewrites)), right))
+ } else {
+ p
+ }
// Push down filter through INTERSECT
case Filter(condition, i @ Intersect(left, right)) =>
+ val (deterministic, nondeterministic) = partitionByDeterministic(condition)
val rewrites = buildRewrites(i)
- Intersect(
- Filter(condition, left),
- Filter(pushToRight(condition, rewrites), right))
+ Filter(nondeterministic,
+ Intersect(
+ Filter(deterministic, left),
+ Filter(pushToRight(deterministic, rewrites), right)
+ )
+ )
// Push down filter through EXCEPT
case Filter(condition, e @ Except(left, right)) =>
+ val (deterministic, nondeterministic) = partitionByDeterministic(condition)
val rewrites = buildRewrites(e)
- Except(
- Filter(condition, left),
- Filter(pushToRight(condition, rewrites), right))
+ Filter(nondeterministic,
+ Except(
+ Filter(deterministic, left),
+ Filter(pushToRight(deterministic, rewrites), right)
+ )
+ )
}
}
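The heart of the rule is partitionByDeterministic: split the filter condition on And, push only the deterministic conjuncts into both children, and keep the nondeterministic ones in a Filter above the set operation. A self-contained toy version of that split in plain Scala (Pred and its fields are illustrative stand-ins for Catalyst's Expression and splitConjunctivePredicates, not the real API):

    // Toy stand-in for a Catalyst predicate: its SQL text plus a determinism flag.
    final case class Pred(sql: String, deterministic: Boolean)

    // Partition pre-split conjuncts by determinism and recombine each group
    // with AND; an empty group becomes "true", mirroring Literal(true) above.
    def partitionByDeterministic(conjuncts: Seq[Pred]): (String, String) = {
      val (det, nondet) = conjuncts.partition(_.deterministic)
      def recombine(ps: Seq[Pred]): String =
        ps.map(_.sql).reduceOption(_ + " AND " + _).getOrElse("true")
      (recombine(det), recombine(nondet))
    }

    // partitionByDeterministic(Seq(Pred("i > 1", true), Pred("rand() < 0.5", false)))
    // yields ("i > 1", "rand() < 0.5"): the first part is safe to duplicate into
    // both children; the second must stay above the Union/Intersect/Except.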
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala
index 3fca47a023..1595ad9327 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala
@@ -30,7 +30,8 @@ class SetOperationPushDownSuite extends PlanTest {
Batch("Subqueries", Once,
EliminateSubQueries) ::
Batch("Union Pushdown", Once,
- SetOperationPushDown) :: Nil
+ SetOperationPushDown,
+ SimplifyFilters) :: Nil
}
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
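Why the suite now also runs SimplifyFilters: when the original condition is entirely deterministic, partitionByDeterministic returns Literal(true) for the nondeterministic half, so the rewritten plan is Filter(true, Union(...)); SimplifyFilters strips that trivial filter so the suite's expected plans still compare equal. A toy model of that cleanup (the Plan/Filter/Leaf types are illustrative, not Catalyst's):

    sealed trait Plan
    final case class Filter(cond: String, child: Plan) extends Plan
    final case class Leaf(name: String) extends Plan

    // Recursively drop filters whose condition is literally "true", the
    // toy analogue of what SimplifyFilters does to Filter(Literal(true), _).
    def simplifyFilters(p: Plan): Plan = p match {
      case Filter("true", child) => simplifyFilters(child)
      case Filter(c, child)      => Filter(c, simplifyFilters(child))
      case other                 => other
    }

    // simplifyFilters(Filter("true", Leaf("Union"))) == Leaf("Union")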
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 1370713975..d919877746 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -916,4 +916,45 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
assert(intersect.count() === 30)
assert(except.count() === 70)
}
+
+ test("SPARK-10740: handle nondeterministic expressions correctly for set operations") {
+ val df1 = (1 to 20).map(Tuple1.apply).toDF("i")
+ val df2 = (1 to 10).map(Tuple1.apply).toDF("i")
+
+ // When generating the expected results here, we need to follow the
+ // implementation of the Rand expression.
+ def expected(df: DataFrame): Seq[Row] = {
+ df.rdd.collectPartitions().zipWithIndex.flatMap {
+ case (data, index) =>
+ val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index)
+ data.filter(_.getInt(0) < rng.nextDouble() * 10)
+ }
+ }
+
+ val union = df1.unionAll(df2)
+ checkAnswer(
+ union.filter('i < rand(7) * 10),
+ expected(union)
+ )
+ checkAnswer(
+ union.select(rand(7)),
+ union.rdd.collectPartitions().zipWithIndex.flatMap {
+ case (data, index) =>
+ val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index)
+ data.map(_ => rng.nextDouble()).map(i => Row(i))
+ }
+ )
+
+ val intersect = df1.intersect(df2)
+ checkAnswer(
+ intersect.filter('i < rand(7) * 10),
+ expected(intersect)
+ )
+
+ val except = df1.except(df2)
+ checkAnswer(
+ except.filter('i < rand(7) * 10),
+ expected(except)
+ )
+ }
}
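The expected-result helper in the test works because rand(seed) seeds its per-partition generator with seed + partitionIndex (visible in the test's use of XORShiftRandom(7 + index)), and XORShiftRandom is fully deterministic for a fixed seed, so replaying the same seed reproduces the stream row-for-row. A small sketch of that reproducibility contract (XORShiftRandom is package-private to Spark, so this only compiles inside Spark's own tree; any seeded PRNG shows the same property):

    import org.apache.spark.util.random.XORShiftRandom

    // Two generators seeded identically (seed 7 plus partition index 0)
    // emit identical double streams, which is exactly what lets the test
    // predict the value rand(7) produced for each row of a partition.
    val a = new XORShiftRandom(7 + 0)
    val b = new XORShiftRandom(7 + 0)
    assert((1 to 5).map(_ => a.nextDouble()) == (1 to 5).map(_ => b.nextDouble()))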