aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala 21
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala 14
2 files changed, 35 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 40c06ed6d4..c222571a34 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -32,6 +32,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
*/
protected def getRelevantConstraints(constraints: Set[Expression]): Set[Expression] = {
constraints
+ .union(inferAdditionalConstraints(constraints))
.union(constructIsNotNullConstraints(constraints))
.filter(constraint =>
constraint.references.nonEmpty && constraint.references.subsetOf(outputSet))
@@ -64,6 +65,26 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
}
/**
+ * Infers an additional set of constraints from a given set of equality constraints.
+ * For example, if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
+ * additional constraint of the form `b = 5`.
+ */
+  private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
+    // Rewrites every attribute semantically equal to `from` into `to`.
+    def substitute(e: Expression, from: Attribute, to: Attribute): Expression = e transform {
+      case attr: Attribute if attr.semanticEquals(from) => to
+    }
+    val derived = constraints.flatMap {
+      case equality @ EqualTo(left: Attribute, right: Attribute) =>
+        val others = constraints - equality
+        others.map(substitute(_, left, right)) ++ others.map(substitute(_, right, left))
+      case _ => Set.empty[Expression] // No inference
+    }
+    // Rewrites that changed nothing reproduce an input constraint; drop those.
+    derived -- constraints
+  }
+
+ /**
* An [[ExpressionSet]] that contains invariants about the rows output by this operator. For
* example, if this set contains the expression `a = 2` then that expression is guaranteed to
* evaluate to `true` for all rows produced.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
index e70d3794ab..a9375a740d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
@@ -158,6 +158,7 @@ class ConstraintPropagationSuite extends SparkFunSuite {
tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100,
tr1.resolveQuoted("a", caseInsensitiveResolution).get ===
tr2.resolveQuoted("a", caseInsensitiveResolution).get,
+ tr2.resolveQuoted("a", caseInsensitiveResolution).get > 10,
IsNotNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get),
IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get),
IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get))))
@@ -203,4 +204,17 @@ class ConstraintPropagationSuite extends SparkFunSuite {
.join(tr2.where('d.attr < 100), FullOuter, Some("tr1.a".attr === "tr2.a".attr))
.analyze.constraints.isEmpty)
}
+
+  test("infer additional constraints in filters") {
+    val tr = LocalRelation('a.int, 'b.int, 'c.int)
+
+    val colA = resolveColumn(tr, "a")
+    val colB = resolveColumn(tr, "b")
+
+    // The filter `a > 10 AND a = b` must additionally yield `b > 10` through the equality.
+    val filtered = tr.where('a.attr > 10 && 'a.attr === 'b.attr).analyze
+    verifyConstraints(
+      filtered.constraints,
+      ExpressionSet(Seq(colA > 10, colB > 10, colA === colB, IsNotNull(colA), IsNotNull(colB))))
+  }
}