aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCheng Lian <lian@databricks.com>2015-06-05 23:06:19 +0800
committerCheng Lian <lian@databricks.com>2015-06-05 23:06:19 +0800
commitbc0d76a246cc534234b96a661d70feb94b26538c (patch)
tree2f8efcd3f4b773ee64c23ab4fafbaba7878101f0
parent700312e12f9588f01a592d6eac7bff7eb366ac8f (diff)
downloadspark-bc0d76a246cc534234b96a661d70feb94b26538c.tar.gz
spark-bc0d76a246cc534234b96a661d70feb94b26538c.tar.bz2
spark-bc0d76a246cc534234b96a661d70feb94b26538c.zip
[SQL] Simplifies binary node pattern matching
This PR is a simpler version of #2764, and adds `unapply` methods to the following binary nodes for simpler pattern matching: - `BinaryExpression` - `BinaryComparison` - `BinaryArithmetics` This enables nested pattern matching for binary nodes. For example, the following pattern matching ```scala case p: BinaryComparison if p.left.dataType == StringType && p.right.dataType == DateType => p.makeCopy(Array(p.left, Cast(p.right, StringType))) ``` can be simplified to ```scala case p BinaryComparison(l StringType(), r DateType()) => p.makeCopy(Array(l, Cast(r, StringType))) ``` Author: Cheng Lian <lian@databricks.com> Closes #6537 from liancheng/binary-node-patmat and squashes the following commits: a3bf5fe [Cheng Lian] Fixes compilation error introduced while rebasing b738986 [Cheng Lian] Renames `l`/`r` to `left`/`right` or `lhs`/`rhs` 14900ae [Cheng Lian] Simplifies binary node pattern matching
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala215
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala5
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala19
5 files changed, 119 insertions, 128 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index b064600e94..9b8a08a88d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -130,7 +130,7 @@ trait HiveTypeCoercion {
* the appropriate numeric equivalent.
*/
object ConvertNaNs extends Rule[LogicalPlan] {
- private val stringNaN = Literal("NaN")
+ private val StringNaN = Literal("NaN")
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressions {
@@ -138,20 +138,20 @@ trait HiveTypeCoercion {
case e if !e.childrenResolved => e
/* Double Conversions */
- case b: BinaryExpression if b.left == stringNaN && b.right.dataType == DoubleType =>
- b.makeCopy(Array(b.right, Literal(Double.NaN)))
- case b: BinaryExpression if b.left.dataType == DoubleType && b.right == stringNaN =>
- b.makeCopy(Array(Literal(Double.NaN), b.left))
- case b: BinaryExpression if b.left == stringNaN && b.right == stringNaN =>
- b.makeCopy(Array(Literal(Double.NaN), b.left))
+ case b @ BinaryExpression(StringNaN, right @ DoubleType()) =>
+ b.makeCopy(Array(Literal(Double.NaN), right))
+ case b @ BinaryExpression(left @ DoubleType(), StringNaN) =>
+ b.makeCopy(Array(left, Literal(Double.NaN)))
/* Float Conversions */
- case b: BinaryExpression if b.left == stringNaN && b.right.dataType == FloatType =>
- b.makeCopy(Array(b.right, Literal(Float.NaN)))
- case b: BinaryExpression if b.left.dataType == FloatType && b.right == stringNaN =>
- b.makeCopy(Array(Literal(Float.NaN), b.left))
- case b: BinaryExpression if b.left == stringNaN && b.right == stringNaN =>
- b.makeCopy(Array(Literal(Float.NaN), b.left))
+ case b @ BinaryExpression(StringNaN, right @ FloatType()) =>
+ b.makeCopy(Array(Literal(Float.NaN), right))
+ case b @ BinaryExpression(left @ FloatType(), StringNaN) =>
+ b.makeCopy(Array(left, Literal(Float.NaN)))
+
+ /* Use float NaN by default to avoid unnecessary type widening */
+ case b @ BinaryExpression(left @ StringNaN, StringNaN) =>
+ b.makeCopy(Array(left, Literal(Float.NaN)))
}
}
}
@@ -184,21 +184,25 @@ trait HiveTypeCoercion {
case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
val castedInput = left.output.zip(right.output).map {
// When a string is found on one side, make the other side a string too.
- case (l, r) if l.dataType == StringType && r.dataType != StringType =>
- (l, Alias(Cast(r, StringType), r.name)())
- case (l, r) if l.dataType != StringType && r.dataType == StringType =>
- (Alias(Cast(l, StringType), l.name)(), r)
-
- case (l, r) if l.dataType != r.dataType =>
- logDebug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}")
- findTightestCommonTypeOfTwo(l.dataType, r.dataType).map { widestType =>
+ case (lhs, rhs) if lhs.dataType == StringType && rhs.dataType != StringType =>
+ (lhs, Alias(Cast(rhs, StringType), rhs.name)())
+ case (lhs, rhs) if lhs.dataType != StringType && rhs.dataType == StringType =>
+ (Alias(Cast(lhs, StringType), lhs.name)(), rhs)
+
+ case (lhs, rhs) if lhs.dataType != rhs.dataType =>
+ logDebug(s"Resolving mismatched union input ${lhs.dataType}, ${rhs.dataType}")
+ findTightestCommonTypeOfTwo(lhs.dataType, rhs.dataType).map { widestType =>
val newLeft =
- if (l.dataType == widestType) l else Alias(Cast(l, widestType), l.name)()
+ if (lhs.dataType == widestType) lhs else Alias(Cast(lhs, widestType), lhs.name)()
val newRight =
- if (r.dataType == widestType) r else Alias(Cast(r, widestType), r.name)()
+ if (rhs.dataType == widestType) rhs else Alias(Cast(rhs, widestType), rhs.name)()
(newLeft, newRight)
- }.getOrElse((l, r)) // If there is no applicable conversion, leave expression unchanged.
+ }.getOrElse {
+ // If there is no applicable conversion, leave expression unchanged.
+ (lhs, rhs)
+ }
+
case other => other
}
@@ -227,12 +231,10 @@ trait HiveTypeCoercion {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
- case b: BinaryExpression if b.left.dataType != b.right.dataType =>
- findTightestCommonTypeOfTwo(b.left.dataType, b.right.dataType).map { widestType =>
- val newLeft =
- if (b.left.dataType == widestType) b.left else Cast(b.left, widestType)
- val newRight =
- if (b.right.dataType == widestType) b.right else Cast(b.right, widestType)
+ case b @ BinaryExpression(left, right) if left.dataType != right.dataType =>
+ findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { widestType =>
+ val newLeft = if (left.dataType == widestType) left else Cast(left, widestType)
+ val newRight = if (right.dataType == widestType) right else Cast(right, widestType)
b.makeCopy(Array(newLeft, newRight))
}.getOrElse(b) // If there is no applicable conversion, leave expression unchanged.
}
@@ -247,57 +249,42 @@ trait HiveTypeCoercion {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
- case a: BinaryArithmetic if a.left.dataType == StringType =>
- a.makeCopy(Array(Cast(a.left, DoubleType), a.right))
- case a: BinaryArithmetic if a.right.dataType == StringType =>
- a.makeCopy(Array(a.left, Cast(a.right, DoubleType)))
+ case a @ BinaryArithmetic(left @ StringType(), r) =>
+ a.makeCopy(Array(Cast(left, DoubleType), r))
+ case a @ BinaryArithmetic(left, right @ StringType()) =>
+ a.makeCopy(Array(left, Cast(right, DoubleType)))
// we should cast all timestamp/date/string compare into string compare
- case p: BinaryComparison if p.left.dataType == StringType &&
- p.right.dataType == DateType =>
- p.makeCopy(Array(p.left, Cast(p.right, StringType)))
- case p: BinaryComparison if p.left.dataType == DateType &&
- p.right.dataType == StringType =>
- p.makeCopy(Array(Cast(p.left, StringType), p.right))
- case p: BinaryComparison if p.left.dataType == StringType &&
- p.right.dataType == TimestampType =>
- p.makeCopy(Array(Cast(p.left, TimestampType), p.right))
- case p: BinaryComparison if p.left.dataType == TimestampType &&
- p.right.dataType == StringType =>
- p.makeCopy(Array(p.left, Cast(p.right, TimestampType)))
- case p: BinaryComparison if p.left.dataType == TimestampType &&
- p.right.dataType == DateType =>
- p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType)))
- case p: BinaryComparison if p.left.dataType == DateType &&
- p.right.dataType == TimestampType =>
- p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType)))
-
- case p: BinaryComparison if p.left.dataType == StringType &&
- p.right.dataType != StringType =>
- p.makeCopy(Array(Cast(p.left, DoubleType), p.right))
- case p: BinaryComparison if p.left.dataType != StringType &&
- p.right.dataType == StringType =>
- p.makeCopy(Array(p.left, Cast(p.right, DoubleType)))
-
- case i @ In(a, b) if a.dataType == DateType &&
- b.forall(_.dataType == StringType) =>
+ case p @ BinaryComparison(left @ StringType(), right @ DateType()) =>
+ p.makeCopy(Array(left, Cast(right, StringType)))
+ case p @ BinaryComparison(left @ DateType(), right @ StringType()) =>
+ p.makeCopy(Array(Cast(left, StringType), right))
+ case p @ BinaryComparison(left @ StringType(), right @ TimestampType()) =>
+ p.makeCopy(Array(Cast(left, TimestampType), right))
+ case p @ BinaryComparison(left @ TimestampType(), right @ StringType()) =>
+ p.makeCopy(Array(left, Cast(right, TimestampType)))
+ case p @ BinaryComparison(left @ TimestampType(), right @ DateType()) =>
+ p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType)))
+ case p @ BinaryComparison(left @ DateType(), right @ TimestampType()) =>
+ p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType)))
+
+ case p @ BinaryComparison(left @ StringType(), right) if right.dataType != StringType =>
+ p.makeCopy(Array(Cast(left, DoubleType), right))
+ case p @ BinaryComparison(left, right @ StringType()) if left.dataType != StringType =>
+ p.makeCopy(Array(left, Cast(right, DoubleType)))
+
+ case i @ In(a @ DateType(), b) if b.forall(_.dataType == StringType) =>
i.makeCopy(Array(Cast(a, StringType), b))
- case i @ In(a, b) if a.dataType == TimestampType &&
- b.forall(_.dataType == StringType) =>
+ case i @ In(a @ TimestampType(), b) if b.forall(_.dataType == StringType) =>
i.makeCopy(Array(a, b.map(Cast(_, TimestampType))))
- case i @ In(a, b) if a.dataType == DateType &&
- b.forall(_.dataType == TimestampType) =>
+ case i @ In(a @ DateType(), b) if b.forall(_.dataType == TimestampType) =>
i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
- case i @ In(a, b) if a.dataType == TimestampType &&
- b.forall(_.dataType == DateType) =>
+ case i @ In(a @ TimestampType(), b) if b.forall(_.dataType == DateType) =>
i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
- case Sum(e) if e.dataType == StringType =>
- Sum(Cast(e, DoubleType))
- case Average(e) if e.dataType == StringType =>
- Average(Cast(e, DoubleType))
- case Sqrt(e) if e.dataType == StringType =>
- Sqrt(Cast(e, DoubleType))
+ case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
+ case Average(e @ StringType()) => Average(Cast(e, DoubleType))
+ case Sqrt(e @ StringType()) => Sqrt(Cast(e, DoubleType))
}
}
@@ -379,22 +366,22 @@ trait HiveTypeCoercion {
// fix decimal precision for union
case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
val castedInput = left.output.zip(right.output).map {
- case (l, r) if l.dataType != r.dataType =>
- (l.dataType, r.dataType) match {
+ case (lhs, rhs) if lhs.dataType != rhs.dataType =>
+ (lhs.dataType, rhs.dataType) match {
case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) =>
// Union decimals with precision/scale p1/s2 and p2/s2 will be promoted to
// DecimalType(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2))
val fixedType = DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2))
- (Alias(Cast(l, fixedType), l.name)(), Alias(Cast(r, fixedType), r.name)())
+ (Alias(Cast(lhs, fixedType), lhs.name)(), Alias(Cast(rhs, fixedType), rhs.name)())
case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
- (Alias(Cast(l, intTypeToFixed(t)), l.name)(), r)
+ (Alias(Cast(lhs, intTypeToFixed(t)), lhs.name)(), rhs)
case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) =>
- (l, Alias(Cast(r, intTypeToFixed(t)), r.name)())
+ (lhs, Alias(Cast(rhs, intTypeToFixed(t)), rhs.name)())
case (t, DecimalType.Fixed(p, s)) if floatTypeToFixed.contains(t) =>
- (Alias(Cast(l, floatTypeToFixed(t)), l.name)(), r)
+ (Alias(Cast(lhs, floatTypeToFixed(t)), lhs.name)(), rhs)
case (DecimalType.Fixed(p, s), t) if floatTypeToFixed.contains(t) =>
- (l, Alias(Cast(r, floatTypeToFixed(t)), r.name)())
- case _ => (l, r)
+ (lhs, Alias(Cast(rhs, floatTypeToFixed(t)), rhs.name)())
+ case _ => (lhs, rhs)
}
case other => other
}
@@ -467,16 +454,16 @@ trait HiveTypeCoercion {
// Promote integers inside a binary expression with fixed-precision decimals to decimals,
// and fixed-precision decimals in an expression with floats / doubles to doubles
- case b: BinaryExpression if b.left.dataType != b.right.dataType =>
- (b.left.dataType, b.right.dataType) match {
+ case b @ BinaryExpression(left, right) if left.dataType != right.dataType =>
+ (left.dataType, right.dataType) match {
case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
- b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right))
+ b.makeCopy(Array(Cast(left, intTypeToFixed(t)), right))
case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) =>
- b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t))))
+ b.makeCopy(Array(left, Cast(right, intTypeToFixed(t))))
case (t, DecimalType.Fixed(p, s)) if isFloat(t) =>
- b.makeCopy(Array(b.left, Cast(b.right, DoubleType)))
+ b.makeCopy(Array(left, Cast(right, DoubleType)))
case (DecimalType.Fixed(p, s), t) if isFloat(t) =>
- b.makeCopy(Array(Cast(b.left, DoubleType), b.right))
+ b.makeCopy(Array(Cast(left, DoubleType), right))
case _ =>
b
}
@@ -525,31 +512,31 @@ trait HiveTypeCoercion {
// all other cases are considered as false.
// We may simplify the expression if one side is literal numeric values
- case EqualTo(l @ BooleanType(), Literal(value, _: NumericType))
- if trueValues.contains(value) => l
- case EqualTo(l @ BooleanType(), Literal(value, _: NumericType))
- if falseValues.contains(value) => Not(l)
- case EqualTo(Literal(value, _: NumericType), r @ BooleanType())
- if trueValues.contains(value) => r
- case EqualTo(Literal(value, _: NumericType), r @ BooleanType())
- if falseValues.contains(value) => Not(r)
- case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType))
- if trueValues.contains(value) => And(IsNotNull(l), l)
- case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType))
- if falseValues.contains(value) => And(IsNotNull(l), Not(l))
- case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType())
- if trueValues.contains(value) => And(IsNotNull(r), r)
- case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType())
- if falseValues.contains(value) => And(IsNotNull(r), Not(r))
-
- case EqualTo(l @ BooleanType(), r @ NumericType()) =>
- transform(l , r)
- case EqualTo(l @ NumericType(), r @ BooleanType()) =>
- transform(r, l)
- case EqualNullSafe(l @ BooleanType(), r @ NumericType()) =>
- transformNullSafe(l, r)
- case EqualNullSafe(l @ NumericType(), r @ BooleanType()) =>
- transformNullSafe(r, l)
+ case EqualTo(left @ BooleanType(), Literal(value, _: NumericType))
+ if trueValues.contains(value) => left
+ case EqualTo(left @ BooleanType(), Literal(value, _: NumericType))
+ if falseValues.contains(value) => Not(left)
+ case EqualTo(Literal(value, _: NumericType), right @ BooleanType())
+ if trueValues.contains(value) => right
+ case EqualTo(Literal(value, _: NumericType), right @ BooleanType())
+ if falseValues.contains(value) => Not(right)
+ case EqualNullSafe(left @ BooleanType(), Literal(value, _: NumericType))
+ if trueValues.contains(value) => And(IsNotNull(left), left)
+ case EqualNullSafe(left @ BooleanType(), Literal(value, _: NumericType))
+ if falseValues.contains(value) => And(IsNotNull(left), Not(left))
+ case EqualNullSafe(Literal(value, _: NumericType), right @ BooleanType())
+ if trueValues.contains(value) => And(IsNotNull(right), right)
+ case EqualNullSafe(Literal(value, _: NumericType), right @ BooleanType())
+ if falseValues.contains(value) => And(IsNotNull(right), Not(right))
+
+ case EqualTo(left @ BooleanType(), right @ NumericType()) =>
+ transform(left , right)
+ case EqualTo(left @ NumericType(), right @ BooleanType()) =>
+ transform(right, left)
+ case EqualNullSafe(left @ BooleanType(), right @ NumericType()) =>
+ transformNullSafe(left, right)
+ case EqualNullSafe(left @ NumericType(), right @ BooleanType()) =>
+ transformNullSafe(right, left)
}
}
@@ -630,7 +617,7 @@ trait HiveTypeCoercion {
case d: Divide if d.dataType == DoubleType => d
case d: Divide if d.dataType.isInstanceOf[DecimalType] => d
- case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType))
+ case Divide(left, right) => Divide(Cast(left, DoubleType), Cast(right, DoubleType))
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 3cf851aec1..b2b9d1a5e1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -118,6 +118,10 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
override def toString: String = s"($left $symbol $right)"
}
+private[sql] object BinaryExpression {
+ def unapply(e: BinaryExpression): Option[(Expression, Expression)] = Some((e.left, e.right))
+}
+
abstract class LeafExpression extends Expression with trees.LeafNode[Expression] {
self: Product =>
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 2ac53f8f66..a3770f998d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -118,6 +118,10 @@ abstract class BinaryArithmetic extends BinaryExpression {
sys.error(s"BinaryArithmetics must override either eval or evalInternal")
}
+private[sql] object BinaryArithmetic {
+ def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] = Some((e.left, e.right))
+}
+
case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "+"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 807021d50e..58273b166f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -202,9 +202,8 @@ abstract class BinaryComparison extends BinaryExpression with Predicate {
sys.error(s"BinaryComparisons must override either eval or evalInternal")
}
-object BinaryComparison {
- def unapply(b: BinaryComparison): Option[(Expression, Expression)] =
- Some((b.left, b.right))
+private[sql] object BinaryComparison {
+ def unapply(e: BinaryComparison): Option[(Expression, Expression)] = Some((e.left, e.right))
}
case class EqualTo(left: Expression, right: Expression) extends BinaryComparison {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 0a17b10c52..c16f08d389 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -266,7 +266,7 @@ object NullPropagation extends Rule[LogicalPlan] {
if (newChildren.length == 0) {
Literal.create(null, e.dataType)
} else if (newChildren.length == 1) {
- newChildren(0)
+ newChildren.head
} else {
Coalesce(newChildren)
}
@@ -280,21 +280,18 @@ object NullPropagation extends Rule[LogicalPlan] {
case e: MinOf => e
// Put exceptional cases above if any
- case e: BinaryArithmetic => e.children match {
- case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType)
- case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType)
- case _ => e
- }
- case e: BinaryComparison => e.children match {
- case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType)
- case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType)
- case _ => e
- }
+ case e @ BinaryArithmetic(Literal(null, _), _) => Literal.create(null, e.dataType)
+ case e @ BinaryArithmetic(_, Literal(null, _)) => Literal.create(null, e.dataType)
+
+ case e @ BinaryComparison(Literal(null, _), _) => Literal.create(null, e.dataType)
+ case e @ BinaryComparison(_, Literal(null, _)) => Literal.create(null, e.dataType)
+
case e: StringRegexExpression => e.children match {
case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType)
case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType)
case _ => e
}
+
case e: StringComparison => e.children match {
case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType)
case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType)