aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorgatorsmile <gatorsmile@gmail.com>2016-02-29 10:10:04 -0800
committerMichael Armbrust <michael@databricks.com>2016-02-29 10:10:04 -0800
commitbc65f60ef7c920db756bfe643f7edbdf3593a989 (patch)
tree8658cdea72665ecc10f3e50104529eb155b2726c /sql
parent02aa499dfb71bc9571bebb79e6383842e4f48143 (diff)
downloadspark-bc65f60ef7c920db756bfe643f7edbdf3593a989.tar.gz
spark-bc65f60ef7c920db756bfe643f7edbdf3593a989.tar.bz2
spark-bc65f60ef7c920db756bfe643f7edbdf3593a989.zip
[SPARK-13544][SQL] Rewrite/Propagate Constraints for Aliases in Aggregate
#### What changes were proposed in this pull request? After analysis by Analyzer, two operators could have aliases. They are `Project` and `Aggregate`. So far, we only rewrite and propagate constraints if an `Alias` is defined in `Project`. This PR resolves this issue in `Aggregate`. #### How was this patch tested? Added a test case for `Aggregate` in `ConstraintPropagationSuite`. marmbrus sameeragarwal Author: gatorsmile <gatorsmile@gmail.com> Closes #11422 from gatorsmile/validConstraintsInUnaryNodes.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala16
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala30
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala15
3 files changed, 38 insertions, 23 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 8095083f33..31e775d60f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -315,6 +315,22 @@ abstract class UnaryNode extends LogicalPlan {
override def children: Seq[LogicalPlan] = child :: Nil
+ /**
+ * Generates an additional set of aliased constraints by replacing the original constraint
+ * expressions with the corresponding alias
+ */
+ protected def getAliasedConstraints(projectList: Seq[NamedExpression]): Set[Expression] = {
+ projectList.flatMap {
+ case a @ Alias(e, _) =>
+ child.constraints.map(_ transform {
+ case expr: Expression if expr.semanticEquals(e) =>
+ a.toAttribute
+ }).union(Set(EqualNullSafe(e, a.toAttribute)))
+ case _ =>
+ Set.empty[Expression]
+ }.toSet
+ }
+
override protected def validConstraints: Set[Expression] = child.constraints
override def statistics: Statistics = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 5d2a65b716..e81a0f9487 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -51,25 +51,8 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
!expressions.exists(!_.resolved) && childrenResolved && !hasSpecialExpressions
}
- /**
- * Generates an additional set of aliased constraints by replacing the original constraint
- * expressions with the corresponding alias
- */
- private def getAliasedConstraints: Set[Expression] = {
- projectList.flatMap {
- case a @ Alias(e, _) =>
- child.constraints.map(_ transform {
- case expr: Expression if expr.semanticEquals(e) =>
- a.toAttribute
- }).union(Set(EqualNullSafe(e, a.toAttribute)))
- case _ =>
- Set.empty[Expression]
- }.toSet
- }
-
- override def validConstraints: Set[Expression] = {
- child.constraints.union(getAliasedConstraints)
- }
+ override def validConstraints: Set[Expression] =
+ child.constraints.union(getAliasedConstraints(projectList))
}
/**
@@ -126,9 +109,8 @@ case class Filter(condition: Expression, child: LogicalPlan)
override def maxRows: Option[Long] = child.maxRows
- override protected def validConstraints: Set[Expression] = {
+ override protected def validConstraints: Set[Expression] =
child.constraints.union(splitConjunctivePredicates(condition).toSet)
- }
}
abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
@@ -157,9 +139,8 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable)
}
- override protected def validConstraints: Set[Expression] = {
+ override protected def validConstraints: Set[Expression] =
leftConstraints.union(rightConstraints)
- }
// Intersect are only resolved if they don't introduce ambiguous expression ids,
// since the Optimizer will convert Intersect to Join.
@@ -442,6 +423,9 @@ case class Aggregate(
override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
override def maxRows: Option[Long] = child.maxRows
+ override def validConstraints: Set[Expression] =
+ child.constraints.union(getAliasedConstraints(aggregateExpressions))
+
override def statistics: Statistics = {
if (groupingExpressions.isEmpty) {
Statistics(sizeInBytes = 1)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
index 373b1ffa83..b68432b1a1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
@@ -72,6 +72,21 @@ class ConstraintPropagationSuite extends SparkFunSuite {
IsNotNull(resolveColumn(tr, "c"))))
}
+ test("propagating constraints in aggregate") {
+ val tr = LocalRelation('a.int, 'b.string, 'c.int)
+
+ assert(tr.analyze.constraints.isEmpty)
+
+ val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5)
+ .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a).analyze
+
+ verifyConstraints(aliasedRelation.analyze.constraints,
+ Set(resolveColumn(aliasedRelation.analyze, "c1") > 10,
+ IsNotNull(resolveColumn(aliasedRelation.analyze, "c1")),
+ resolveColumn(aliasedRelation.analyze, "a") < 5,
+ IsNotNull(resolveColumn(aliasedRelation.analyze, "a"))))
+ }
+
test("propagating constraints in aliases") {
val tr = LocalRelation('a.int, 'b.string, 'c.int)