From 85e68b4bea3e4ad2e4063334dbf5b11af197d2ce Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Tue, 12 Apr 2016 12:29:54 -0700
Subject: [SPARK-14562] [SQL] improve constraints propagation in Union

## What changes were proposed in this pull request?

Currently, Union only takes the intersection of the constraints from its children; all the others are dropped. We should try to merge them together.

This PR tries to merge constraints that have the same reference but come from different children; for example, `a > 10` and `a < 100` can be merged into `a > 10 || a < 100`.

## How was this patch tested?

Added more cases to the existing test.

Author: Davies Liu

Closes #12328 from davies/union_const.
---
 .../sql/catalyst/plans/logical/basicOperators.scala      | 16 +++++++++++++++-
 .../sql/catalyst/plans/ConstraintPropagationSuite.scala  | 14 ++++++++++++++
 2 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index d3353beb09..d4fc9e4da9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -236,10 +236,24 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan {
     })
   }
 
+  private def merge(a: Set[Expression], b: Set[Expression]): Set[Expression] = {
+    val common = a.intersect(b)
+    // A constraint with only one reference can easily be inferred as a predicate.
+    // Group the constraints by their references so we can combine constraints with the
+    // same reference together.
+    val othera = a.diff(common).filter(_.references.size == 1).groupBy(_.references.head)
+    val otherb = b.diff(common).filter(_.references.size == 1).groupBy(_.references.head)
+    // Loosen the constraints: A1 && B1 || A2 && B2 -> (A1 || A2) && (B1 || B2)
+    val others = (othera.keySet intersect otherb.keySet).map { attr =>
+      Or(othera(attr).reduceLeft(And), otherb(attr).reduceLeft(And))
+    }
+    common ++ others
+  }
+
   override protected def validConstraints: Set[Expression] = {
     children
       .map(child => rewriteConstraints(children.head.output, child.output, child.constraints))
-      .reduce(_ intersect _)
+      .reduce(merge(_, _))
   }
 }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
index 49c1353efb..81cc6b123c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
@@ -148,6 +148,20 @@ class ConstraintPropagationSuite extends SparkFunSuite {
         .analyze.constraints,
       ExpressionSet(Seq(resolveColumn(tr1, "a") > 10,
         IsNotNull(resolveColumn(tr1, "a")))))
+
+    val a = resolveColumn(tr1, "a")
+    verifyConstraints(tr1
+      .where('a.attr > 10)
+      .union(tr2.where('d.attr > 11))
+      .analyze.constraints,
+      ExpressionSet(Seq(a > 10 || a > 11, IsNotNull(a))))
+
+    val b = resolveColumn(tr1, "b")
+    verifyConstraints(tr1
+      .where('a.attr > 10 && 'b.attr < 10)
+      .union(tr2.where('d.attr > 11 && 'e.attr < 11))
+      .analyze.constraints,
+      ExpressionSet(Seq(a > 10 || a > 11, b < 10 || b < 11, IsNotNull(a), IsNotNull(b))))
   }
 
   test("propagating constraints in intersect") {
--
cgit v1.2.3
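
To see the new `merge` rule in isolation from Catalyst, here is a minimal standalone Scala sketch of the same grouping-and-OR idea. `Pred`, `Cmp`, and the sample constraints below are made-up stand-ins, not Spark's `Expression` API, and the sketch skips the `rewriteConstraints` step that maps each child's attributes onto the first child's output.

```scala
// Toy model of single-table predicates, standing in for Catalyst expressions.
object UnionConstraintMergeSketch {

  sealed trait Pred { def references: Set[String] }
  case class Cmp(attr: String, op: String, value: Int) extends Pred {
    def references: Set[String] = Set(attr)
    override def toString: String = s"$attr $op $value"
  }
  case class And(l: Pred, r: Pred) extends Pred {
    def references: Set[String] = l.references ++ r.references
    override def toString: String = s"($l && $r)"
  }
  case class Or(l: Pred, r: Pred) extends Pred {
    def references: Set[String] = l.references ++ r.references
    override def toString: String = s"($l || $r)"
  }

  // Mirrors Union.merge from the patch: keep the constraints common to both
  // children, and for single-attribute constraints that differ, OR together the
  // per-attribute conjunctions coming from each child.
  def merge(a: Set[Pred], b: Set[Pred]): Set[Pred] = {
    val common = a.intersect(b)
    val othera = a.diff(common).filter(_.references.size == 1).groupBy(_.references.head)
    val otherb = b.diff(common).filter(_.references.size == 1).groupBy(_.references.head)
    val others: Set[Pred] = (othera.keySet intersect otherb.keySet).map { attr =>
      Or(othera(attr).reduceLeft(And(_, _)), otherb(attr).reduceLeft(And(_, _)))
    }
    common ++ others
  }

  def main(args: Array[String]): Unit = {
    // Child 1 is filtered by a > 10 && b < 10; child 2 (already rewritten to
    // child 1's output attributes) is filtered by a > 11 && b < 11.
    val child1 = Set[Pred](Cmp("a", ">", 10), Cmp("b", "<", 10))
    val child2 = Set[Pred](Cmp("a", ">", 11), Cmp("b", "<", 11))
    // Plain intersection (the old behavior) would drop everything here; the
    // merge keeps a loosened constraint per shared attribute.
    merge(child1, child2).foreach(println)
  }
}
```

Running the sketch prints `(a > 10 || a > 11)` and `(b < 10 || b < 11)`, which match the constraints checked by the test cases added above.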