aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorSameer Agarwal <sameer@databricks.com>2016-02-19 14:48:34 -0800
committerMichael Armbrust <michael@databricks.com>2016-02-19 14:48:34 -0800
commit091f6a7830bbee01fa580fbb0336b9f4fcac0dfa (patch)
tree818ed942896580321b80015c358ff1bd24796f44 /sql
parent14844118b596a93dbc28b442a7ea2b58fa4df648 (diff)
downloadspark-091f6a7830bbee01fa580fbb0336b9f4fcac0dfa.tar.gz
spark-091f6a7830bbee01fa580fbb0336b9f4fcac0dfa.tar.bz2
spark-091f6a7830bbee01fa580fbb0336b9f4fcac0dfa.zip
[SPARK-13091][SQL] Rewrite/Propagate constraints for Aliases
This PR adds support for rewriting constraints if there are aliases in the query plan. For e.g., if there is a query of form `SELECT a, a AS b`, any constraints on `a` now also apply to `b`. JIRA: https://issues.apache.org/jira/browse/SPARK-13091 cc marmbrus Author: Sameer Agarwal <sameer@databricks.com> Closes #11144 from sameeragarwal/alias.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala20
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala20
2 files changed, 39 insertions, 1 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 502d898fea..7d155ac183 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -50,6 +50,26 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
!expressions.exists(!_.resolved) && childrenResolved && !hasSpecialExpressions
}
+
+ /**
+ * Generates an additional set of aliased constraints by replacing the original constraint
+ * expressions with the corresponding alias
+ */
+ private def getAliasedConstraints: Set[Expression] = {
+ projectList.flatMap {
+ case a @ Alias(e, _) =>
+ child.constraints.map(_ transform {
+ case expr: Expression if expr.semanticEquals(e) =>
+ a.toAttribute
+ }).union(Set(EqualNullSafe(e, a.toAttribute)))
+ case _ =>
+ Set.empty[Expression]
+ }.toSet
+ }
+
+ override def validConstraints: Set[Expression] = {
+ child.constraints.union(getAliasedConstraints)
+ }
}
/**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
index b5cf91394d..373b1ffa83 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
@@ -27,7 +27,10 @@ import org.apache.spark.sql.catalyst.plans.logical._
class ConstraintPropagationSuite extends SparkFunSuite {
private def resolveColumn(tr: LocalRelation, columnName: String): Expression =
- tr.analyze.resolveQuoted(columnName, caseInsensitiveResolution).get
+ resolveColumn(tr.analyze, columnName)
+
+ private def resolveColumn(plan: LogicalPlan, columnName: String): Expression =
+ plan.resolveQuoted(columnName, caseInsensitiveResolution).get
private def verifyConstraints(found: Set[Expression], expected: Set[Expression]): Unit = {
val missing = expected.filterNot(i => found.map(_.semanticEquals(i)).reduce(_ || _))
@@ -69,6 +72,21 @@ class ConstraintPropagationSuite extends SparkFunSuite {
IsNotNull(resolveColumn(tr, "c"))))
}
+ test("propagating constraints in aliases") {
+ val tr = LocalRelation('a.int, 'b.string, 'c.int)
+
+ assert(tr.where('c.attr > 10).select('a.as('x), 'b.as('y)).analyze.constraints.isEmpty)
+
+ val aliasedRelation = tr.where('a.attr > 10).select('a.as('x), 'b, 'b.as('y), 'a.as('z))
+
+ verifyConstraints(aliasedRelation.analyze.constraints,
+ Set(resolveColumn(aliasedRelation.analyze, "x") > 10,
+ IsNotNull(resolveColumn(aliasedRelation.analyze, "x")),
+ resolveColumn(aliasedRelation.analyze, "b") <=> resolveColumn(aliasedRelation.analyze, "y"),
+ resolveColumn(aliasedRelation.analyze, "z") > 10,
+ IsNotNull(resolveColumn(aliasedRelation.analyze, "z"))))
+ }
+
test("propagating constraints in union") {
val tr1 = LocalRelation('a.int, 'b.int, 'c.int)
val tr2 = LocalRelation('d.int, 'e.int, 'f.int)