diff options
author | jiangxingbo <jiangxb1987@gmail.com> | 2016-10-26 20:12:20 +0200 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-10-26 20:12:20 +0200 |
commit | fa7d9d70825a6816495d239da925d0087f7cb94f (patch) | |
tree | 2b0cbdef51c320e69c4c55507ae15724a3773a22 /sql/catalyst/src | |
parent | 7ac70e7ba8d610a45c21a70dc28e4c989c19451b (diff) | |
download | spark-fa7d9d70825a6816495d239da925d0087f7cb94f.tar.gz spark-fa7d9d70825a6816495d239da925d0087f7cb94f.tar.bz2 spark-fa7d9d70825a6816495d239da925d0087f7cb94f.zip |
[SPARK-18063][SQL] Failed to infer constraints over multiple aliases
## What changes were proposed in this pull request?
The `UnaryNode.getAliasedConstraints` function fails to replace all expressions by their alias where constraints contains more than one expression to be replaced.
For example:
```
val tr = LocalRelation('a.int, 'b.string, 'c.int)
val multiAlias = tr.where('a === 'c + 10).select('a.as('x), 'c.as('y))
multiAlias.analyze.constraints
```
currently outputs:
```
ExpressionSet(Seq(
IsNotNull(resolveColumn(multiAlias.analyze, "x")),
IsNotNull(resolveColumn(multiAlias.analyze, "y"))
)
```
The constraint `resolveColumn(multiAlias.analyze, "x") === resolveColumn(multiAlias.analyze, "y") + 10)` is missing.
## How was this patch tested?
Add new test cases in `ConstraintPropagationSuite`.
Author: jiangxingbo <jiangxb1987@gmail.com>
Closes #15597 from jiangxb1987/alias-constraints.
Diffstat (limited to 'sql/catalyst/src')
2 files changed, 18 insertions, 6 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 09725473a3..b0a4145f37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -293,15 +293,19 @@ abstract class UnaryNode extends LogicalPlan { * expressions with the corresponding alias */ protected def getAliasedConstraints(projectList: Seq[NamedExpression]): Set[Expression] = { - projectList.flatMap { + var allConstraints = child.constraints.asInstanceOf[Set[Expression]] + projectList.foreach { case a @ Alias(e, _) => - child.constraints.map(_ transform { + // For every alias in `projectList`, replace the reference in constraints by its attribute. + allConstraints ++= allConstraints.map(_ transform { case expr: Expression if expr.semanticEquals(e) => a.toAttribute - }).union(Set(EqualNullSafe(e, a.toAttribute))) - case _ => - Set.empty[Expression] - }.toSet + }) + allConstraints += EqualNullSafe(e, a.toAttribute) + case _ => // Don't change. + } + + allConstraints -- child.constraints } override protected def validConstraints: Set[Expression] = child.constraints diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 8d6a49a8a3..8068ce922e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -128,8 +128,16 @@ class ConstraintPropagationSuite extends SparkFunSuite { ExpressionSet(Seq(resolveColumn(aliasedRelation.analyze, "x") > 10, IsNotNull(resolveColumn(aliasedRelation.analyze, "x")), resolveColumn(aliasedRelation.analyze, "b") <=> resolveColumn(aliasedRelation.analyze, "y"), + resolveColumn(aliasedRelation.analyze, "z") <=> resolveColumn(aliasedRelation.analyze, "x"), resolveColumn(aliasedRelation.analyze, "z") > 10, IsNotNull(resolveColumn(aliasedRelation.analyze, "z"))))) + + val multiAlias = tr.where('a === 'c + 10).select('a.as('x), 'c.as('y)) + verifyConstraints(multiAlias.analyze.constraints, + ExpressionSet(Seq(IsNotNull(resolveColumn(multiAlias.analyze, "x")), + IsNotNull(resolveColumn(multiAlias.analyze, "y")), + resolveColumn(multiAlias.analyze, "x") === resolveColumn(multiAlias.analyze, "y") + 10)) + ) } test("propagating constraints in union") { |