aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala16
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala8
2 files changed, 18 insertions, 6 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 09725473a3..b0a4145f37 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -293,15 +293,19 @@ abstract class UnaryNode extends LogicalPlan {
* expressions with the corresponding alias
*/
protected def getAliasedConstraints(projectList: Seq[NamedExpression]): Set[Expression] = {
- projectList.flatMap {
+ var allConstraints = child.constraints.asInstanceOf[Set[Expression]]
+ projectList.foreach {
case a @ Alias(e, _) =>
- child.constraints.map(_ transform {
+ // For every alias in `projectList`, replace the reference in constraints by its attribute.
+ allConstraints ++= allConstraints.map(_ transform {
case expr: Expression if expr.semanticEquals(e) =>
a.toAttribute
- }).union(Set(EqualNullSafe(e, a.toAttribute)))
- case _ =>
- Set.empty[Expression]
- }.toSet
+ })
+ allConstraints += EqualNullSafe(e, a.toAttribute)
+ case _ => // Don't change.
+ }
+
+ allConstraints -- child.constraints
}
override protected def validConstraints: Set[Expression] = child.constraints
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
index 8d6a49a8a3..8068ce922e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
@@ -128,8 +128,16 @@ class ConstraintPropagationSuite extends SparkFunSuite {
ExpressionSet(Seq(resolveColumn(aliasedRelation.analyze, "x") > 10,
IsNotNull(resolveColumn(aliasedRelation.analyze, "x")),
resolveColumn(aliasedRelation.analyze, "b") <=> resolveColumn(aliasedRelation.analyze, "y"),
+ resolveColumn(aliasedRelation.analyze, "z") <=> resolveColumn(aliasedRelation.analyze, "x"),
resolveColumn(aliasedRelation.analyze, "z") > 10,
IsNotNull(resolveColumn(aliasedRelation.analyze, "z")))))
+
+ val multiAlias = tr.where('a === 'c + 10).select('a.as('x), 'c.as('y))
+ verifyConstraints(multiAlias.analyze.constraints,
+ ExpressionSet(Seq(IsNotNull(resolveColumn(multiAlias.analyze, "x")),
+ IsNotNull(resolveColumn(multiAlias.analyze, "y")),
+ resolveColumn(multiAlias.analyze, "x") === resolveColumn(multiAlias.analyze, "y") + 10))
+ )
}
test("propagating constraints in union") {