 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala | 49
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala   |  7
 2 files changed, 27 insertions(+), 29 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
index b58a527304..ae1f600613 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.rules._
-
/**
* Rewrites an expression using rules that are guaranteed to preserve the result while attempting
* to remove cosmetic variations. Deterministic expressions that are `equal` after canonicalization
@@ -30,26 +28,23 @@ import org.apache.spark.sql.catalyst.rules._
* - Names and nullability hints for [[org.apache.spark.sql.types.DataType]]s are stripped.
* - Commutative and associative operations ([[Add]] and [[Multiply]]) have their children ordered
* by `hashCode`.
-* - [[EqualTo]] and [[EqualNullSafe]] are reordered by `hashCode`.
+ * - [[EqualTo]] and [[EqualNullSafe]] are reordered by `hashCode`.
* - Other comparisons ([[GreaterThan]], [[LessThan]]) are reversed by `hashCode`.
*/
-object Canonicalize extends RuleExecutor[Expression] {
- override protected def batches: Seq[Batch] =
- Batch(
- "Expression Canonicalization", FixedPoint(100),
- IgnoreNamesTypes,
- Reorder) :: Nil
+object Canonicalize {
+ def execute(e: Expression): Expression = {
+ expressionReorder(ignoreNamesTypes(e))
+ }
/** Remove names and nullability from types. */
- protected object IgnoreNamesTypes extends Rule[Expression] {
- override def apply(e: Expression): Expression = e transformUp {
- case a: AttributeReference =>
- AttributeReference("none", a.dataType.asNullable)(exprId = a.exprId)
- }
+ private def ignoreNamesTypes(e: Expression): Expression = e match {
+ case a: AttributeReference =>
+ AttributeReference("none", a.dataType.asNullable)(exprId = a.exprId)
+ case _ => e
}
/** Collects adjacent commutative operations. */
- protected def gatherCommutative(
+ private def gatherCommutative(
e: Expression,
f: PartialFunction[Expression, Seq[Expression]]): Seq[Expression] = e match {
case c if f.isDefinedAt(c) => f(c).flatMap(gatherCommutative(_, f))
@@ -57,25 +52,25 @@ object Canonicalize extends RuleExecutor[Expression] {
}
/** Orders a set of commutative operations by their hash code. */
- protected def orderCommutative(
+ private def orderCommutative(
e: Expression,
f: PartialFunction[Expression, Seq[Expression]]): Seq[Expression] =
gatherCommutative(e, f).sortBy(_.hashCode())
/** Rearrange expressions that are commutative or associative. */
- protected object Reorder extends Rule[Expression] {
- override def apply(e: Expression): Expression = e transformUp {
- case a: Add => orderCommutative(a, { case Add(l, r) => Seq(l, r) }).reduce(Add)
- case m: Multiply => orderCommutative(m, { case Multiply(l, r) => Seq(l, r) }).reduce(Multiply)
- case EqualTo(l, r) if l.hashCode() > r.hashCode() => EqualTo(r, l)
- case EqualNullSafe(l, r) if l.hashCode() > r.hashCode() => EqualNullSafe(r, l)
- case GreaterThan(l, r) if l.hashCode() > r.hashCode() => LessThan(r, l)
- case LessThan(l, r) if l.hashCode() > r.hashCode() => GreaterThan(r, l)
- case GreaterThanOrEqual(l, r) if l.hashCode() > r.hashCode() => LessThanOrEqual(r, l)
- case LessThanOrEqual(l, r) if l.hashCode() > r.hashCode() => GreaterThanOrEqual(r, l)
- }
+ private def expressionReorder(e: Expression): Expression = e match {
+ case a: Add => orderCommutative(a, { case Add(l, r) => Seq(l, r) }).reduce(Add)
+ case m: Multiply => orderCommutative(m, { case Multiply(l, r) => Seq(l, r) }).reduce(Multiply)
+
+ case EqualTo(l, r) if l.hashCode() > r.hashCode() => EqualTo(r, l)
+ case EqualNullSafe(l, r) if l.hashCode() > r.hashCode() => EqualNullSafe(r, l)
+ case GreaterThan(l, r) if l.hashCode() > r.hashCode() => LessThan(r, l)
+ case LessThan(l, r) if l.hashCode() > r.hashCode() => GreaterThan(r, l)
+ case GreaterThanOrEqual(l, r) if l.hashCode() > r.hashCode() => LessThanOrEqual(r, l)
+ case LessThanOrEqual(l, r) if l.hashCode() > r.hashCode() => GreaterThanOrEqual(r, l)
+ case _ => e
}
}
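The scaladoc above spells out what canonicalization buys: once names, nullability, and operand order are normalized, semanticEquals reduces to comparing canonical forms. A minimal sketch of that behaviour, as a hypothetical test snippet that is not part of this commit (it only assumes the catalyst classes shown in the diff are on the classpath):

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types.IntegerType

    object CanonicalizeSketch {
      def main(args: Array[String]): Unit = {
        val a = AttributeReference("a", IntegerType)()
        val b = AttributeReference("b", IntegerType)()

        // Add is commutative: both orderings canonicalize to the same tree,
        // so the two expressions compare equal despite the cosmetic difference.
        println(Add(a, b).semanticEquals(Add(b, a)))               // expected: true

        // Comparisons are flipped by hashCode, so a > b matches b < a.
        println(GreaterThan(a, b).semanticEquals(LessThan(b, a)))  // expected: true
      }
    }

The semanticEquals used here is the definition touched in Expression.scala below.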
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 692c16092f..16a1b2aee2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -152,7 +152,10 @@ abstract class Expression extends TreeNode[Expression] {
* `deterministic` expressions where `this.canonicalized == other.canonicalized` will always
* evaluate to the same result.
*/
- lazy val canonicalized: Expression = Canonicalize.execute(this)
+ lazy val canonicalized: Expression = {
+ val canonicalizedChildren = children.map(_.canonicalized)
+ Canonicalize.execute(withNewChildren(canonicalizedChildren))
+ }
/**
* Returns true when two expressions will always compute the same result, even if they differ
@@ -161,7 +164,7 @@ abstract class Expression extends TreeNode[Expression] {
* See [[Canonicalize]] for more details.
*/
def semanticEquals(other: Expression): Boolean =
- deterministic && other.deterministic && canonicalized == other.canonicalized
+ deterministic && other.deterministic && canonicalized == other.canonicalized
/**
* Returns a `hashCode` for the calculation performed by this expression. Unlike the standard
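The Expression.scala change is what makes the simpler Canonicalize sufficient: canonicalized is now memoized per node and built bottom-up from the children's canonical forms, so a single local rewrite per node replaces the old FixedPoint(100) rule executor. A standalone sketch of that scheme, with illustrative names (Node, Leaf, Sum) that are not Spark's API:

    // Hypothetical model of the bottom-up, memoized canonicalization introduced above.
    sealed trait Node {
      def withCanonicalChildren: Node
      def localRewrite(n: Node): Node = n
      // Mirrors the new Expression.canonicalized: canonicalize children first,
      // then apply one local rewrite to this node -- no fixed-point loop needed.
      final lazy val canonical: Node = localRewrite(withCanonicalChildren)
    }

    case class Leaf(id: Int) extends Node {
      def withCanonicalChildren: Node = this
    }

    case class Sum(l: Node, r: Node) extends Node {
      def withCanonicalChildren: Node = Sum(l.canonical, r.canonical)
      // Commutative, like Add in Canonicalize: order children by hashCode.
      override def localRewrite(n: Node): Node = n match {
        case Sum(a, b) if a.hashCode() > b.hashCode() => Sum(b, a)
        case other => other
      }
    }

    object Demo extends App {
      val (x, y, z) = (Leaf(1), Leaf(2), Leaf(3))
      // Differently ordered but equivalent trees reach the same canonical form.
      println(Sum(Sum(x, y), z).canonical == Sum(z, Sum(y, x)).canonical) // expected: true
    }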