aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorMichael Armbrust <michael@databricks.com>2016-02-24 19:43:00 -0800
committerYin Huai <yhuai@databricks.com>2016-02-24 19:43:00 -0800
commit2b042577fb077865c3fce69c9d4eda22fde92673 (patch)
tree201ba584e3e4ef81be598e8c4a28a9e4db261e06 /sql
parent5a7af9e7ac85e04aa4a420bc2887207bfa18f792 (diff)
downloadspark-2b042577fb077865c3fce69c9d4eda22fde92673.tar.gz
spark-2b042577fb077865c3fce69c9d4eda22fde92673.tar.bz2
spark-2b042577fb077865c3fce69c9d4eda22fde92673.zip
[SPARK-13092][SQL] Add ExpressionSet for constraint tracking
This PR adds a new abstraction called an `ExpressionSet` which attempts to canonicalize expressions to remove cosmetic differences. Deterministic expressions that are in the set after canonicalization will always return the same answer given the same input (i.e. false positives should not be possible). However, it is possible that two canonical expressions that are not equal will in fact return the same answer given any input (i.e. false negatives are possible). ```scala val set = AttributeSet('a + 1 :: 1 + 'a :: Nil) set.iterator => Iterator('a + 1) set.contains('a + 1) => true set.contains(1 + 'a) => true set.contains('a + 2) => false ``` Other relevant changes include: - Since this concept overlaps with the existing `semanticEquals` and `semanticHash`, those functions are also ported to this new infrastructure. - A memoized `canonicalized` version of the expression is added as a `lazy val` to `Expression` and is used by both `semanticEquals` and `ExpressionSet`. - A set of unit tests for `ExpressionSet` are added - Tests which expect `semanticEquals` to be less intelligent than it now is are updated. As a followup, we should consider auditing the places where we do `O(n)` `semanticEquals` operations and replace them with `ExpressionSet`. We should also consider consolidating `AttributeSet` as a specialized factory for an `ExpressionSet.` Author: Michael Armbrust <michael@databricks.com> Closes #11338 from marmbrus/expressionSet.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala81
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala56
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala87
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala12
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala89
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala3
7 files changed, 285 insertions, 45 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 1a2ec7ed93..a12f7396fe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -235,7 +235,7 @@ package object dsl {
implicit class DslAttribute(a: AttributeReference) {
def notNull: AttributeReference = a.withNullability(false)
- def nullable: AttributeReference = a.withNullability(true)
+ def canBeNull: AttributeReference = a.withNullability(true)
def at(ordinal: Int): BoundReference = BoundReference(ordinal, a.dataType, a.nullable)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
new file mode 100644
index 0000000000..b58a527304
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.rules._
+
+/**
+ * Rewrites an expression using rules that are guaranteed preserve the result while attempting
+ * to remove cosmetic variations. Deterministic expressions that are `equal` after canonicalization
+ * will always return the same answer given the same input (i.e. false positives should not be
+ * possible). However, it is possible that two canonical expressions that are not equal will in fact
+ * return the same answer given any input (i.e. false negatives are possible).
+ *
+ * The following rules are applied:
+ * - Names and nullability hints for [[org.apache.spark.sql.types.DataType]]s are stripped.
+ * - Commutative and associative operations ([[Add]] and [[Multiply]]) have their children ordered
+ * by `hashCode`.
+* - [[EqualTo]] and [[EqualNullSafe]] are reordered by `hashCode`.
+ * - Other comparisons ([[GreaterThan]], [[LessThan]]) are reversed by `hashCode`.
+ */
+object Canonicalize extends RuleExecutor[Expression] {
+ override protected def batches: Seq[Batch] =
+ Batch(
+ "Expression Canonicalization", FixedPoint(100),
+ IgnoreNamesTypes,
+ Reorder) :: Nil
+
+ /** Remove names and nullability from types. */
+ protected object IgnoreNamesTypes extends Rule[Expression] {
+ override def apply(e: Expression): Expression = e transformUp {
+ case a: AttributeReference =>
+ AttributeReference("none", a.dataType.asNullable)(exprId = a.exprId)
+ }
+ }
+
+ /** Collects adjacent commutative operations. */
+ protected def gatherCommutative(
+ e: Expression,
+ f: PartialFunction[Expression, Seq[Expression]]): Seq[Expression] = e match {
+ case c if f.isDefinedAt(c) => f(c).flatMap(gatherCommutative(_, f))
+ case other => other :: Nil
+ }
+
+ /** Orders a set of commutative operations by their hash code. */
+ protected def orderCommutative(
+ e: Expression,
+ f: PartialFunction[Expression, Seq[Expression]]): Seq[Expression] =
+ gatherCommutative(e, f).sortBy(_.hashCode())
+
+ /** Rearrange expressions that are commutative or associative. */
+ protected object Reorder extends Rule[Expression] {
+ override def apply(e: Expression): Expression = e transformUp {
+ case a: Add => orderCommutative(a, { case Add(l, r) => Seq(l, r) }).reduce(Add)
+ case m: Multiply => orderCommutative(m, { case Multiply(l, r) => Seq(l, r) }).reduce(Multiply)
+
+ case EqualTo(l, r) if l.hashCode() > r.hashCode() => EqualTo(r, l)
+ case EqualNullSafe(l, r) if l.hashCode() > r.hashCode() => EqualNullSafe(r, l)
+
+ case GreaterThan(l, r) if l.hashCode() > r.hashCode() => LessThan(r, l)
+ case LessThan(l, r) if l.hashCode() > r.hashCode() => GreaterThan(r, l)
+
+ case GreaterThanOrEqual(l, r) if l.hashCode() > r.hashCode() => LessThanOrEqual(r, l)
+ case LessThanOrEqual(l, r) if l.hashCode() > r.hashCode() => GreaterThanOrEqual(r, l)
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 119496c7ee..692c16092f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -145,48 +145,31 @@ abstract class Expression extends TreeNode[Expression] {
def childrenResolved: Boolean = children.forall(_.resolved)
/**
+ * Returns an expression where a best effort attempt has been made to transform `this` in a way
+ * that preserves the result but removes cosmetic variations (case sensitivity, ordering for
+ * commutative operations, etc.) See [[Canonicalize]] for more details.
+ *
+ * `deterministic` expressions where `this.canonicalized == other.canonicalized` will always
+ * evaluate to the same result.
+ */
+ lazy val canonicalized: Expression = Canonicalize.execute(this)
+
+ /**
* Returns true when two expressions will always compute the same result, even if they differ
* cosmetically (i.e. capitalization of names in attributes may be different).
+ *
+ * See [[Canonicalize]] for more details.
*/
- def semanticEquals(other: Expression): Boolean = this.getClass == other.getClass && {
- def checkSemantic(elements1: Seq[Any], elements2: Seq[Any]): Boolean = {
- elements1.length == elements2.length && elements1.zip(elements2).forall {
- case (e1: Expression, e2: Expression) => e1 semanticEquals e2
- case (Some(e1: Expression), Some(e2: Expression)) => e1 semanticEquals e2
- case (t1: Traversable[_], t2: Traversable[_]) => checkSemantic(t1.toSeq, t2.toSeq)
- case (i1, i2) => i1 == i2
- }
- }
- // Non-deterministic expressions cannot be semantic equal
- if (!deterministic || !other.deterministic) return false
- val elements1 = this.productIterator.toSeq
- val elements2 = other.asInstanceOf[Product].productIterator.toSeq
- checkSemantic(elements1, elements2)
- }
+ def semanticEquals(other: Expression): Boolean =
+ deterministic && other.deterministic && canonicalized == other.canonicalized
/**
- * Returns the hash for this expression. Expressions that compute the same result, even if
- * they differ cosmetically should return the same hash.
+ * Returns a `hashCode` for the calculation performed by this expression. Unlike the standard
+ * `hashCode`, an attempt has been made to eliminate cosmetic differences.
+ *
+ * See [[Canonicalize]] for more details.
*/
- def semanticHash() : Int = {
- def computeHash(e: Seq[Any]): Int = {
- // See http://stackoverflow.com/questions/113511/hash-code-implementation
- var hash: Int = 17
- e.foreach(i => {
- val h: Int = i match {
- case e: Expression => e.semanticHash()
- case Some(e: Expression) => e.semanticHash()
- case t: Traversable[_] => computeHash(t.toSeq)
- case null => 0
- case other => other.hashCode()
- }
- hash = hash * 37 + h
- })
- hash
- }
-
- computeHash(this.productIterator.toSeq)
- }
+ def semanticHash(): Int = canonicalized.hashCode()
/**
* Checks the input data types, returns `TypeCheckResult.success` if it's valid,
@@ -369,7 +352,6 @@ abstract class UnaryExpression extends Expression {
}
}
-
/**
* An expression with two inputs and one output. The output is by default evaluated to null
* if any input is evaluated to null.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala
new file mode 100644
index 0000000000..acea049adc
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+object ExpressionSet {
+ /** Constructs a new [[ExpressionSet]] by applying [[Canonicalize]] to `expressions`. */
+ def apply(expressions: TraversableOnce[Expression]): ExpressionSet = {
+ val set = new ExpressionSet()
+ expressions.foreach(set.add)
+ set
+ }
+}
+
+/**
+ * A [[Set]] where membership is determined based on a canonical representation of an [[Expression]]
+ * (i.e. one that attempts to ignore cosmetic differences). See [[Canonicalize]] for more details.
+ *
+ * Internally this set uses the canonical representation, but keeps also track of the original
+ * expressions to ease debugging. Since different expressions can share the same canonical
+ * representation, this means that operations that extract expressions from this set are only
+ * guranteed to see at least one such expression. For example:
+ *
+ * {{{
+ * val set = AttributeSet(a + 1, 1 + a)
+ *
+ * set.iterator => Iterator(a + 1)
+ * set.contains(a + 1) => true
+ * set.contains(1 + a) => true
+ * set.contains(a + 2) => false
+ * }}}
+ */
+class ExpressionSet protected(
+ protected val baseSet: mutable.Set[Expression] = new mutable.HashSet,
+ protected val originals: mutable.Buffer[Expression] = new ArrayBuffer)
+ extends Set[Expression] {
+
+ protected def add(e: Expression): Unit = {
+ if (!baseSet.contains(e.canonicalized)) {
+ baseSet.add(e.canonicalized)
+ originals.append(e)
+ }
+ }
+
+ override def contains(elem: Expression): Boolean = baseSet.contains(elem.canonicalized)
+
+ override def +(elem: Expression): ExpressionSet = {
+ val newSet = new ExpressionSet(baseSet.clone(), originals.clone())
+ newSet.add(elem)
+ newSet
+ }
+
+ override def -(elem: Expression): ExpressionSet = {
+ val newBaseSet = baseSet.clone().filterNot(_ == elem.canonicalized)
+ val newOriginals = originals.clone().filterNot(_.canonicalized == elem.canonicalized)
+ new ExpressionSet(newBaseSet, newOriginals)
+ }
+
+ override def iterator: Iterator[Expression] = originals.iterator
+
+ /**
+ * Returns a string containing both the post [[Canonicalize]] expressions and the original
+ * expressions in this set.
+ */
+ def toDebugString: String =
+ s"""
+ |baseSet: ${baseSet.mkString(", ")}
+ |originals: ${originals.mkString(", ")}
+ """.stripMargin
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 5e7d144ae4..a74b288cb2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -63,17 +63,19 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
}
/**
- * A sequence of expressions that describes the data property of the output rows of this
- * operator. For example, if the output of this operator is column `a`, an example `constraints`
- * can be `Set(a > 10, a < 20)`.
+ * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For
+ * example, if this set contains the expression `a = 2` then that expression is guaranteed to
+ * evaluate to `true` for all rows produced.
*/
- lazy val constraints: Set[Expression] = getRelevantConstraints(validConstraints)
+ lazy val constraints: ExpressionSet = ExpressionSet(getRelevantConstraints(validConstraints))
/**
* This method can be overridden by any child class of QueryPlan to specify a set of constraints
* based on the given operator's constraint propagation logic. These constraints are then
* canonicalized and filtered automatically to contain only those attributes that appear in the
- * [[outputSet]]
+ * [[outputSet]].
+ *
+ * See [[Canonicalize]] for more details.
*/
protected def validConstraints: Set[Expression] = Set.empty
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala
new file mode 100644
index 0000000000..ce42e5784c
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.types.IntegerType
+
+class ExpressionSetSuite extends SparkFunSuite {
+
+ val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1))
+ val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1))
+ val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3))
+
+ val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2))
+ val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2))
+
+ val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil)
+
+ def setTest(size: Int, exprs: Expression*): Unit = {
+ test(s"expect $size: ${exprs.mkString(", ")}") {
+ val set = ExpressionSet(exprs)
+ if (set.size != size) {
+ fail(set.toDebugString)
+ }
+ }
+ }
+
+ def setTestIgnore(size: Int, exprs: Expression*): Unit =
+ ignore(s"expect $size: ${exprs.mkString(", ")}") {}
+
+ // Commutative
+ setTest(1, aUpper + 1, aLower + 1)
+ setTest(2, aUpper + 1, aLower + 2)
+ setTest(2, aUpper + 1, fakeA + 1)
+ setTest(2, aUpper + 1, bUpper + 1)
+
+ setTest(1, aUpper + aLower, aLower + aUpper)
+ setTest(1, aUpper + bUpper, bUpper + aUpper)
+ setTest(1,
+ aUpper + bUpper + 3,
+ bUpper + 3 + aUpper,
+ bUpper + aUpper + 3,
+ Literal(3) + aUpper + bUpper)
+ setTest(1,
+ aUpper * bUpper * 3,
+ bUpper * 3 * aUpper,
+ bUpper * aUpper * 3,
+ Literal(3) * aUpper * bUpper)
+ setTest(1, aUpper === bUpper, bUpper === aUpper)
+
+ setTest(1, aUpper + 1 === bUpper, bUpper === Literal(1) + aUpper)
+
+
+ // Not commutative
+ setTest(2, aUpper - bUpper, bUpper - aUpper)
+
+ // Reversable
+ setTest(1, aUpper > bUpper, bUpper < aUpper)
+ setTest(1, aUpper >= bUpper, bUpper <= aUpper)
+
+ test("add to / remove from set") {
+ val initialSet = ExpressionSet(aUpper + 1 :: Nil)
+
+ assert((initialSet + (aUpper + 1)).size == 1)
+ assert((initialSet + (aUpper + 2)).size == 2)
+ assert((initialSet - (aUpper + 1)).size == 0)
+ assert((initialSet - (aUpper + 2)).size == 1)
+
+ assert((initialSet + (aLower + 1)).size == 1)
+ assert((initialSet - (aLower + 1)).size == 0)
+
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 1d9db27e09..13ff4a2c41 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1980,9 +1980,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
verifyCallCount(
df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1)
- // Would be nice if semantic equals for `+` understood commutative
verifyCallCount(
- df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2)
+ df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 1)
// Try disabling it via configuration.
sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false")