aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala16
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala14
2 files changed, 29 insertions, 1 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index d3353beb09..d4fc9e4da9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -236,10 +236,24 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan {
})
}
+ private def merge(a: Set[Expression], b: Set[Expression]): Set[Expression] = {
+ val common = a.intersect(b)
+ // A constraint with only one reference can easily be inferred as a predicate, so only
+ // such constraints are considered. Group them by their single reference so that the
+ // constraints of `a` and `b` on the same reference can be combined.
+ val othera = a.diff(common).filter(_.references.size == 1).groupBy(_.references.head)
+ val otherb = b.diff(common).filter(_.references.size == 1).groupBy(_.references.head)
+ // Loosen the constraints: (A1 && B1) || (A2 && B2) -> (A1 || A2) && (B1 || B2)
+ val others = (othera.keySet intersect otherb.keySet).map { attr =>
+ Or(othera(attr).reduceLeft(And), otherb(attr).reduceLeft(And))
+ }
+ common ++ others
+ }
+
override protected def validConstraints: Set[Expression] = {
children
.map(child => rewriteConstraints(children.head.output, child.output, child.constraints))
- .reduce(_ intersect _)
+ .reduce(merge(_, _))
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
index 49c1353efb..81cc6b123c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
@@ -148,6 +148,20 @@ class ConstraintPropagationSuite extends SparkFunSuite {
.analyze.constraints,
ExpressionSet(Seq(resolveColumn(tr1, "a") > 10,
IsNotNull(resolveColumn(tr1, "a")))))
+
+ val a = resolveColumn(tr1, "a")
+ verifyConstraints(tr1
+ .where('a.attr > 10)
+ .union(tr2.where('d.attr > 11))
+ .analyze.constraints,
+ ExpressionSet(Seq(a > 10 || a > 11, IsNotNull(a))))
+
+ val b = resolveColumn(tr1, "b")
+ verifyConstraints(tr1
+ .where('a.attr > 10 && 'b.attr < 10)
+ .union(tr2.where('d.attr > 11 && 'e.attr < 11))
+ .analyze.constraints,
+ ExpressionSet(Seq(a > 10 || a > 11, b < 10 || b < 11, IsNotNull(a), IsNotNull(b))))
}
test("propagating constraints in intersect") {