diff options
2 files changed, 35 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 40c06ed6d4..c222571a34 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -32,6 +32,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
    */
   protected def getRelevantConstraints(constraints: Set[Expression]): Set[Expression] = {
     constraints
+      .union(inferAdditionalConstraints(constraints))
       .union(constructIsNotNullConstraints(constraints))
       .filter(constraint =>
         constraint.references.nonEmpty && constraint.references.subsetOf(outputSet))
@@ -64,6 +65,26 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
   }
 
   /**
+   * Infers an additional set of constraints from a given set of equality constraints.
+   * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
+   * additional constraint of the form `b = 5`
+   */
+  private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
+    var inferredConstraints = Set.empty[Expression]
+    constraints.foreach {
+      case eq @ EqualTo(l: Attribute, r: Attribute) =>
+        inferredConstraints ++= (constraints - eq).map(_ transform {
+          case a: Attribute if a.semanticEquals(l) => r
+        })
+        inferredConstraints ++= (constraints - eq).map(_ transform {
+          case a: Attribute if a.semanticEquals(r) => l
+        })
+      case _ => // No inference
+    }
+    inferredConstraints -- constraints
+  }
+
+  /**
    * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For
    * example, if this set contains the expression `a = 2` then that expression is guaranteed to
    * evaluate to `true` for all rows produced.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
index e70d3794ab..a9375a740d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
@@ -158,6 +158,7 @@ class ConstraintPropagationSuite extends SparkFunSuite {
         tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100,
         tr1.resolveQuoted("a", caseInsensitiveResolution).get ===
           tr2.resolveQuoted("a", caseInsensitiveResolution).get,
+        tr2.resolveQuoted("a", caseInsensitiveResolution).get > 10,
         IsNotNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get),
         IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get),
         IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get))))
@@ -203,4 +204,17 @@ class ConstraintPropagationSuite extends SparkFunSuite {
       .join(tr2.where('d.attr < 100), FullOuter, Some("tr1.a".attr === "tr2.a".attr))
       .analyze.constraints.isEmpty)
   }
+
+  test("infer additional constraints in filters") {
+    val tr = LocalRelation('a.int, 'b.int, 'c.int)
+
+    verifyConstraints(tr
+      .where('a.attr > 10 && 'a.attr === 'b.attr)
+      .analyze.constraints,
+      ExpressionSet(Seq(resolveColumn(tr, "a") > 10,
+        resolveColumn(tr, "b") > 10,
+        resolveColumn(tr, "a") === resolveColumn(tr, "b"),
+        IsNotNull(resolveColumn(tr, "a")),
+        IsNotNull(resolveColumn(tr, "b")))))
+  }
 }