Diffstat (limited to 'sql')
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala                  |  55 ++++++
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala        |   2 +
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala     |  79 +++++++-
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala | 173 ++++++++++++
 4 files changed, 302 insertions(+), 7 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index b43b7ee71e..05f5bdbfc0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.plans
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, VirtualColumn}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types.{DataType, StructType}
@@ -27,6 +27,56 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
def output: Seq[Attribute]
/**
+ * Extracts the relevant constraints from a given set of constraints based on the attributes that
+ * appear in the [[outputSet]].
+ */
+ protected def getRelevantConstraints(constraints: Set[Expression]): Set[Expression] = {
+ constraints
+ .union(constructIsNotNullConstraints(constraints))
+ .filter(constraint =>
+ constraint.references.nonEmpty && constraint.references.subsetOf(outputSet))
+ }
+
+ /**
+ * Infers a set of `isNotNull` constraints from a given set of equality/comparison expressions.
+ * For example, if an expression is of the form `a > 5`, this returns a constraint of the
+ * form `isNotNull(a)`.
+ */
+ private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = {
+ // Currently we only infer `isNotNull` constraints from equality and range comparisons;
+ // all other predicates contribute an empty set of constraints.
+ constraints.map {
+ case EqualTo(l, r) =>
+ Set(IsNotNull(l), IsNotNull(r))
+ case GreaterThan(l, r) =>
+ Set(IsNotNull(l), IsNotNull(r))
+ case GreaterThanOrEqual(l, r) =>
+ Set(IsNotNull(l), IsNotNull(r))
+ case LessThan(l, r) =>
+ Set(IsNotNull(l), IsNotNull(r))
+ case LessThanOrEqual(l, r) =>
+ Set(IsNotNull(l), IsNotNull(r))
+ case _ =>
+ Set.empty[Expression]
+ }.foldLeft(Set.empty[Expression])(_ union _.toSet)
+ }
+
+ /**
+ * A set of expressions that describes invariants satisfied by the output rows of this
+ * operator. For example, if this operator outputs column `a`, a possible value of
+ * `constraints` is `Set(a > 10, a < 20)`.
+ */
+ lazy val constraints: Set[Expression] = getRelevantConstraints(validConstraints)
+
+ /**
+ * This method can be overridden by any child class of QueryPlan to specify a set of constraints
+ * based on the given operator's constraint propagation logic. These constraints are then
+ * augmented with inferred `isNotNull` constraints and automatically filtered to retain only
+ * those that reference attributes in the [[outputSet]].
+ */
+ protected def validConstraints: Set[Expression] = Set.empty
+
+ /**
* Returns the set of attributes that are output by this node.
*/
def outputSet: AttributeSet = AttributeSet(output)
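
To make the two methods above concrete, here is a REPL-style sketch of what they produce for a simple filter, using the same Catalyst test DSL as the new suite at the end of this patch (printed forms are illustrative; expression ids are elided):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    val tr = LocalRelation('a.int, 'b.int)
    // The filter condition contributes `a > 5`; constructIsNotNullConstraints adds
    // IsNotNull for both sides of the comparison. IsNotNull(5) is then dropped by
    // getRelevantConstraints, because a literal has no references.
    val plan = tr.where('a.attr > 5).analyze
    // plan.constraints == Set(a > 5, isnotnull(a))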
@@ -59,6 +109,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
* Runs [[transform]] with `rule` on all expressions present in this query operator.
* Users should not expect a specific directionality. If a specific directionality is needed,
* transformExpressionsDown or transformExpressionsUp should be used.
+ *
* @param rule the rule to be applied to every expression in this operator.
*/
def transformExpressions(rule: PartialFunction[Expression, Expression]): this.type = {
@@ -67,6 +118,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
/**
* Runs [[transformDown]] with `rule` on all expressions present in this query operator.
+ *
* @param rule the rule to be applied to every expression in this operator.
*/
def transformExpressionsDown(rule: PartialFunction[Expression, Expression]): this.type = {
@@ -99,6 +151,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
/**
* Runs [[transformUp]] with `rule` on all expressions present in this query operator.
+ *
* @param rule the rule to be applied to every expression in this operator.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 6d859551f8..d8944a4241 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -305,6 +305,8 @@ abstract class UnaryNode extends LogicalPlan {
def child: LogicalPlan
override def children: Seq[LogicalPlan] = child :: Nil
+
+ override protected def validConstraints: Set[Expression] = child.constraints
}
/**
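
This default means every unary operator (e.g. Project) passes its child's constraints through unchanged, and getRelevantConstraints then prunes whatever references attributes the operator no longer outputs. A sketch, assuming the DSL imports from the earlier example:

    val tr = LocalRelation('a.int, 'b.int, 'c.int)
    // Projecting `a` away also discards `a > 10` and isnotnull(a), since their
    // references are no longer a subset of the operator's outputSet.
    val pruned = tr.where('a.attr > 10).select('b.attr, 'c.attr).analyze
    // pruned.constraints == Set.empty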
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 16f4b355b1..8150ff8434 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -87,11 +87,27 @@ case class Generate(
}
}
-case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
+case class Filter(condition: Expression, child: LogicalPlan)
+ extends UnaryNode with PredicateHelper {
override def output: Seq[Attribute] = child.output
+
+ override protected def validConstraints: Set[Expression] = {
+ child.constraints.union(splitConjunctivePredicates(condition).toSet)
+ }
}
-abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode
+abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
+
+ protected def leftConstraints: Set[Expression] = left.constraints
+
+ protected def rightConstraints: Set[Expression] = {
+ require(left.output.size == right.output.size)
+ val attributeRewrites = AttributeMap(right.output.zip(left.output))
+ right.constraints.map(_ transform {
+ case a: Attribute => attributeRewrites(a)
+ })
+ }
+}
private[sql] object SetOperation {
def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right))
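
Filter relies on splitConjunctivePredicates (from PredicateHelper) so that a compound condition yields one constraint per top-level conjunct. For reference, a standalone sketch of that splitting behavior (the name splitConjuncts is illustrative, not part of this patch):

    import org.apache.spark.sql.catalyst.expressions.{And, Expression}

    // Behaviorally equivalent to PredicateHelper.splitConjunctivePredicates:
    // recursively flatten nested Ands into their top-level conjuncts.
    def splitConjuncts(condition: Expression): Seq[Expression] = condition match {
      case And(left, right) => splitConjuncts(left) ++ splitConjuncts(right)
      case other => other :: Nil
    }
    // splitConjuncts(('a.attr > 1) && ('b.attr < 2)) == Seq('a.attr > 1, 'b.attr < 2)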
@@ -106,6 +122,10 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable)
}
+ override protected def validConstraints: Set[Expression] = {
+ leftConstraints.union(rightConstraints)
+ }
+
// Intersect is only resolved if it doesn't introduce ambiguous expression ids,
// since the Optimizer will convert Intersect to Join.
override lazy val resolved: Boolean =
@@ -119,6 +139,8 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le
/** We don't use right.output because those rows get excluded from the set. */
override def output: Seq[Attribute] = left.output
+ override protected def validConstraints: Set[Expression] = leftConstraints
+
override lazy val resolved: Boolean =
childrenResolved &&
left.output.length == right.output.length &&
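
Except, by contrast, emits a subset of the left child's rows: the right child can only remove rows, so its `b < 100` predicate below tells us nothing about the rows that survive, and only leftConstraints carry over. A sketch:

    val tr1 = LocalRelation('a.int, 'b.int)
    val tr2 = LocalRelation('a.int, 'b.int)
    val diff = tr1.where('a.attr > 10).except(tr2.where('b.attr < 100)).analyze
    // diff.constraints == Set(a > 10, isnotnull(a))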
@@ -157,13 +179,36 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan {
val sizeInBytes = children.map(_.statistics.sizeInBytes).sum
Statistics(sizeInBytes = sizeInBytes)
}
+
+ /**
+ * Rewrites constraints expressed over a given (original) sequence of attributes in terms of
+ * a given (reference) sequence of attributes. Since all children of a union share the same
+ * schema, the original and reference sequences correspond one-to-one.
+ */
+ private def rewriteConstraints(
+ reference: Seq[Attribute],
+ original: Seq[Attribute],
+ constraints: Set[Expression]): Set[Expression] = {
+ require(reference.size == original.size)
+ val attributeRewrites = AttributeMap(original.zip(reference))
+ constraints.map(_ transform {
+ case a: Attribute => attributeRewrites(a)
+ })
+ }
+
+ override protected def validConstraints: Set[Expression] = {
+ children
+ .map(child => rewriteConstraints(children.head.output, child.output, child.constraints))
+ .reduce(_ intersect _)
+ }
}
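
Union keeps only those constraints that hold for every child: each child's constraints are rewritten into the first child's attribute namespace and the resulting sets are intersected. A sketch of the consequence (mirroring the union tests below):

    val tr1 = LocalRelation('a.int, 'b.int)
    val tr2 = LocalRelation('d.int, 'e.int)
    // Both children constrain column 1, so the rewritten sets coincide:
    tr1.where('a.attr > 10).unionAll(tr2.where('d.attr > 10)).analyze.constraints
    //   == Set(a > 10, isnotnull(a))          (over tr1's attributes)
    // The children constrain different column positions, so nothing is common:
    tr1.where('a.attr > 10).unionAll(tr2.where('e.attr > 10)).analyze.constraints
    //   == Set.empty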
case class Join(
- left: LogicalPlan,
- right: LogicalPlan,
- joinType: JoinType,
- condition: Option[Expression]) extends BinaryNode {
+ left: LogicalPlan,
+ right: LogicalPlan,
+ joinType: JoinType,
+ condition: Option[Expression])
+ extends BinaryNode with PredicateHelper {
override def output: Seq[Attribute] = {
joinType match {
@@ -180,6 +225,28 @@ case class Join(
}
}
+ override protected def validConstraints: Set[Expression] = {
+ joinType match {
+ case Inner if condition.isDefined =>
+ left.constraints
+ .union(right.constraints)
+ .union(splitConjunctivePredicates(condition.get).toSet)
+ case LeftSemi if condition.isDefined =>
+ left.constraints
+ .union(splitConjunctivePredicates(condition.get).toSet)
+ case Inner =>
+ left.constraints.union(right.constraints)
+ case LeftSemi =>
+ left.constraints
+ case LeftOuter =>
+ left.constraints
+ case RightOuter =>
+ right.constraints
+ case FullOuter =>
+ Set.empty[Expression]
+ }
+ }
+
def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty
// Joins are only resolved if they don't introduce ambiguous expression ids.
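
The joinType match above records which side of a join can still be trusted after null-padding: inner joins keep both sides plus the condition's conjuncts, outer joins keep only the non-null-producing side, and full outer joins keep nothing. A sketch for LeftOuter (DSL imports as before, plus the JoinType object):

    import org.apache.spark.sql.catalyst.plans.LeftOuter

    val tr1 = LocalRelation('a.int, 'b.int).subquery('tr1)
    val tr2 = LocalRelation('a.int, 'd.int).subquery('tr2)
    val joined = tr1.where('a.attr > 10)
      .join(tr2.where('d.attr < 100), LeftOuter, Some("tr1.a".attr === "tr2.a".attr))
      .analyze
    // Only the left side's constraints survive: tr2's columns may be null-padded,
    // so `d < 100` no longer holds for every output row.
    // joined.constraints == Set(tr1.a > 10, isnotnull(tr1.a))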
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
new file mode 100644
index 0000000000..b5cf91394d
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+
+class ConstraintPropagationSuite extends SparkFunSuite {
+
+ private def resolveColumn(tr: LocalRelation, columnName: String): Expression =
+ tr.analyze.resolveQuoted(columnName, caseInsensitiveResolution).get
+
+ private def verifyConstraints(found: Set[Expression], expected: Set[Expression]): Unit = {
+ val missing = expected.filterNot(i => found.exists(_.semanticEquals(i)))
+ val extra = found.filterNot(i => expected.exists(_.semanticEquals(i)))
+ if (missing.nonEmpty || extra.nonEmpty) {
+ fail(
+ s"""
|=== FAIL: Constraints do not match ===
+ |Found: ${found.mkString(",")}
+ |Expected: ${expected.mkString(",")}
+ |== Result ==
+ |Missing: ${if (missing.isEmpty) "N/A" else missing.mkString(",")}
+ |Found but not expected: ${if (extra.isEmpty) "N/A" else extra.mkString(",")}
+ """.stripMargin)
+ }
+ }
+
+ test("propagating constraints in filters") {
+ val tr = LocalRelation('a.int, 'b.string, 'c.int)
+
+ assert(tr.analyze.constraints.isEmpty)
+
+ assert(tr.where('a.attr > 10).select('c.attr, 'b.attr).analyze.constraints.isEmpty)
+
+ verifyConstraints(tr
+ .where('a.attr > 10)
+ .analyze.constraints,
+ Set(resolveColumn(tr, "a") > 10,
+ IsNotNull(resolveColumn(tr, "a"))))
+
+ verifyConstraints(tr
+ .where('a.attr > 10)
+ .select('c.attr, 'a.attr)
+ .where('c.attr < 100)
+ .analyze.constraints,
+ Set(resolveColumn(tr, "a") > 10,
+ resolveColumn(tr, "c") < 100,
+ IsNotNull(resolveColumn(tr, "a")),
+ IsNotNull(resolveColumn(tr, "c"))))
+ }
+
+ test("propagating constraints in union") {
+ val tr1 = LocalRelation('a.int, 'b.int, 'c.int)
+ val tr2 = LocalRelation('d.int, 'e.int, 'f.int)
+ val tr3 = LocalRelation('g.int, 'h.int, 'i.int)
+
+ assert(tr1
+ .where('a.attr > 10)
+ .unionAll(tr2.where('e.attr > 10)
+ .unionAll(tr3.where('i.attr > 10)))
+ .analyze.constraints.isEmpty)
+
+ verifyConstraints(tr1
+ .where('a.attr > 10)
+ .unionAll(tr2.where('d.attr > 10)
+ .unionAll(tr3.where('g.attr > 10)))
+ .analyze.constraints,
+ Set(resolveColumn(tr1, "a") > 10,
+ IsNotNull(resolveColumn(tr1, "a"))))
+ }
+
+ test("propagating constraints in intersect") {
+ val tr1 = LocalRelation('a.int, 'b.int, 'c.int)
+ val tr2 = LocalRelation('a.int, 'b.int, 'c.int)
+
+ verifyConstraints(tr1
+ .where('a.attr > 10)
+ .intersect(tr2.where('b.attr < 100))
+ .analyze.constraints,
+ Set(resolveColumn(tr1, "a") > 10,
+ resolveColumn(tr1, "b") < 100,
+ IsNotNull(resolveColumn(tr1, "a")),
+ IsNotNull(resolveColumn(tr1, "b"))))
+ }
+
+ test("propagating constraints in except") {
+ val tr1 = LocalRelation('a.int, 'b.int, 'c.int)
+ val tr2 = LocalRelation('a.int, 'b.int, 'c.int)
+ verifyConstraints(tr1
+ .where('a.attr > 10)
+ .except(tr2.where('b.attr < 100))
+ .analyze.constraints,
+ Set(resolveColumn(tr1, "a") > 10,
+ IsNotNull(resolveColumn(tr1, "a"))))
+ }
+
+ test("propagating constraints in inner join") {
+ val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
+ val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
+ verifyConstraints(tr1
+ .where('a.attr > 10)
+ .join(tr2.where('d.attr < 100), Inner, Some("tr1.a".attr === "tr2.a".attr))
+ .analyze.constraints,
+ Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10,
+ tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100,
+ tr1.resolveQuoted("a", caseInsensitiveResolution).get ===
+ tr2.resolveQuoted("a", caseInsensitiveResolution).get,
+ IsNotNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get),
+ IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get),
+ IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get)))
+ }
+
+ test("propagating constraints in left-semi join") {
+ val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
+ val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
+ verifyConstraints(tr1
+ .where('a.attr > 10)
+ .join(tr2.where('d.attr < 100), LeftSemi, Some("tr1.a".attr === "tr2.a".attr))
+ .analyze.constraints,
+ Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10,
+ IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get)))
+ }
+
+ test("propagating constraints in left-outer join") {
+ val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
+ val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
+ verifyConstraints(tr1
+ .where('a.attr > 10)
+ .join(tr2.where('d.attr < 100), LeftOuter, Some("tr1.a".attr === "tr2.a".attr))
+ .analyze.constraints,
+ Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10,
+ IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get)))
+ }
+
+ test("propagating constraints in right-outer join") {
+ val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
+ val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
+ verifyConstraints(tr1
+ .where('a.attr > 10)
+ .join(tr2.where('d.attr < 100), RightOuter, Some("tr1.a".attr === "tr2.a".attr))
+ .analyze.constraints,
+ Set(tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100,
+ IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get)))
+ }
+
+ test("propagating constraints in full-outer join") {
+ val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
+ val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
+ assert(tr1.where('a.attr > 10)
+ .join(tr2.where('d.attr < 100), FullOuter, Some("tr1.a".attr === "tr2.a".attr))
+ .analyze.constraints.isEmpty)
+ }
+}