aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorCheng Hao <hao.cheng@intel.com>2014-05-07 03:37:12 -0400
committerReynold Xin <rxin@apache.org>2014-05-07 03:37:12 -0400
commit3eb53bd59e828275471d41730e6de601a887416d (patch)
treef728e59cb7eecf5e61e5bfb9d5e4672c6b6f147a /sql/catalyst
parent913a0a9c0a87e164723ebf9616b883b6329bac71 (diff)
downloadspark-3eb53bd59e828275471d41730e6de601a887416d.tar.gz
spark-3eb53bd59e828275471d41730e6de601a887416d.tar.bz2
spark-3eb53bd59e828275471d41730e6de601a887416d.zip
[WIP][Spark-SQL] Optimize the Constant Folding for Expression
Currently, expression does not support the "constant null" well in constant folding. e.g. Sum(a, 0) actually always produces Literal(0, NumericType) in runtime. For example: ``` explain select isnull(key+null) from src; == Logical Plan == Project [HiveGenericUdf#isnull((key#30 + CAST(null, IntegerType))) AS c_0#28] MetastoreRelation default, src, None == Optimized Logical Plan == Project [true AS c_0#28] MetastoreRelation default, src, None == Physical Plan == Project [true AS c_0#28] HiveTableScan [], (MetastoreRelation default, src, None), None ``` I've create a new Optimization rule called NullPropagation for such kind of constant folding. Author: Cheng Hao <hao.cheng@intel.com> Author: Michael Armbrust <michael@databricks.com> Closes #482 from chenghao-intel/optimize_constant_folding and squashes the following commits: 2f14b50 [Cheng Hao] Fix code style issues 68b9fad [Cheng Hao] Remove the Literal pattern matching for NullPropagation 29c8166 [Cheng Hao] Update the code for feedback of code review 50444cc [Cheng Hao] Remove the unnecessary null checking 80f9f18 [Cheng Hao] Update the UnitTest for aggregation constant folding 27ea3d7 [Cheng Hao] Fix Constant Folding Bugs & Add More Unittests b28e03a [Cheng Hao] Merge pull request #1 from marmbrus/pr/482 9ccefdb [Michael Armbrust] Add tests for optimized expression evaluation. 543ef9d [Cheng Hao] fix code style issues 9cf0396 [Cheng Hao] update code according to the code review comment 536c005 [Cheng Hao] Add Exceptional case for constant folding 3c045c7 [Cheng Hao] Optimize the Constant Folding by adding more rules 2645d4f [Cheng Hao] Constant Folding(null propagation)
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala22
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala6
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala34
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala3
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala67
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala115
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala36
8 files changed, 252 insertions, 32 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 987befe8e2..dc83485df1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -114,37 +114,37 @@ package object dsl {
def attr = analysis.UnresolvedAttribute(s)
/** Creates a new AttributeReference of type boolean */
- def boolean = AttributeReference(s, BooleanType, nullable = false)()
+ def boolean = AttributeReference(s, BooleanType, nullable = true)()
/** Creates a new AttributeReference of type byte */
- def byte = AttributeReference(s, ByteType, nullable = false)()
+ def byte = AttributeReference(s, ByteType, nullable = true)()
/** Creates a new AttributeReference of type short */
- def short = AttributeReference(s, ShortType, nullable = false)()
+ def short = AttributeReference(s, ShortType, nullable = true)()
/** Creates a new AttributeReference of type int */
- def int = AttributeReference(s, IntegerType, nullable = false)()
+ def int = AttributeReference(s, IntegerType, nullable = true)()
/** Creates a new AttributeReference of type long */
- def long = AttributeReference(s, LongType, nullable = false)()
+ def long = AttributeReference(s, LongType, nullable = true)()
/** Creates a new AttributeReference of type float */
- def float = AttributeReference(s, FloatType, nullable = false)()
+ def float = AttributeReference(s, FloatType, nullable = true)()
/** Creates a new AttributeReference of type double */
- def double = AttributeReference(s, DoubleType, nullable = false)()
+ def double = AttributeReference(s, DoubleType, nullable = true)()
/** Creates a new AttributeReference of type string */
- def string = AttributeReference(s, StringType, nullable = false)()
+ def string = AttributeReference(s, StringType, nullable = true)()
/** Creates a new AttributeReference of type decimal */
- def decimal = AttributeReference(s, DecimalType, nullable = false)()
+ def decimal = AttributeReference(s, DecimalType, nullable = true)()
/** Creates a new AttributeReference of type timestamp */
- def timestamp = AttributeReference(s, TimestampType, nullable = false)()
+ def timestamp = AttributeReference(s, TimestampType, nullable = true)()
/** Creates a new AttributeReference of type binary */
- def binary = AttributeReference(s, BinaryType, nullable = false)()
+ def binary = AttributeReference(s, BinaryType, nullable = true)()
}
implicit class DslAttribute(a: AttributeReference) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index dd9332ada8..41398ff956 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -44,7 +44,6 @@ abstract class Expression extends TreeNode[Expression] {
* - A [[expressions.Cast Cast]] or [[expressions.UnaryMinus UnaryMinus]] is foldable if its
* child is foldable.
*/
- // TODO: Supporting more foldable expressions. For example, deterministic Hive UDFs.
def foldable: Boolean = false
def nullable: Boolean
def references: Set[Attribute]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index 08b2f11d20..d2b7685e73 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.errors.TreeNodeException
+import org.apache.spark.sql.catalyst.trees
abstract sealed class SortDirection
case object Ascending extends SortDirection
@@ -27,7 +28,10 @@ case object Descending extends SortDirection
* An expression that can be used to sort a tuple. This class extends expression primarily so that
* transformations over expression will descend into its child.
*/
-case class SortOrder(child: Expression, direction: SortDirection) extends UnaryExpression {
+case class SortOrder(child: Expression, direction: SortDirection) extends Expression
+ with trees.UnaryNode[Expression] {
+
+ override def references = child.references
override def dataType = child.dataType
override def nullable = child.nullable
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
index c947155cb7..195ca2eb3d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
@@ -28,6 +28,7 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
val children = child :: ordinal :: Nil
/** `Null` is returned for invalid ordinals. */
override def nullable = true
+ override def foldable = child.foldable && ordinal.foldable
override def references = children.flatMap(_.references).toSet
def dataType = child.dataType match {
case ArrayType(dt) => dt
@@ -40,23 +41,27 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
override def toString = s"$child[$ordinal]"
override def eval(input: Row): Any = {
- if (child.dataType.isInstanceOf[ArrayType]) {
- val baseValue = child.eval(input).asInstanceOf[Seq[_]]
- val o = ordinal.eval(input).asInstanceOf[Int]
- if (baseValue == null) {
- null
- } else if (o >= baseValue.size || o < 0) {
- null
- } else {
- baseValue(o)
- }
+ val value = child.eval(input)
+ if (value == null) {
+ null
} else {
- val baseValue = child.eval(input).asInstanceOf[Map[Any, _]]
val key = ordinal.eval(input)
- if (baseValue == null) {
+ if (key == null) {
null
} else {
- baseValue.get(key).orNull
+ if (child.dataType.isInstanceOf[ArrayType]) {
+ val baseValue = value.asInstanceOf[Seq[_]]
+ val o = key.asInstanceOf[Int]
+ if (o >= baseValue.size || o < 0) {
+ null
+ } else {
+ baseValue(o)
+ }
+ } else {
+ val baseValue = value.asInstanceOf[Map[Any, _]]
+ val key = ordinal.eval(input)
+ baseValue.get(key).orNull
+ }
}
}
}
@@ -69,7 +74,8 @@ case class GetField(child: Expression, fieldName: String) extends UnaryExpressio
type EvaluatedType = Any
def dataType = field.dataType
- def nullable = field.nullable
+ override def nullable = field.nullable
+ override def foldable = child.foldable
protected def structType = child.dataType match {
case s: StructType => s
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 82c7af6844..6ee479939d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -65,8 +65,7 @@ abstract class BinaryPredicate extends BinaryExpression with Predicate {
def nullable = left.nullable || right.nullable
}
-case class Not(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
- def references = child.references
+case class Not(child: Expression) extends UnaryExpression with Predicate {
override def foldable = child.foldable
def nullable = child.nullable
override def toString = s"NOT $child"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index c0a09a16ac..3037d45cc6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.types._
object Optimizer extends RuleExecutor[LogicalPlan] {
val batches =
Batch("ConstantFolding", Once,
+ NullPropagation,
ConstantFolding,
BooleanSimplification,
SimplifyFilters,
@@ -87,6 +88,72 @@ object ColumnPruning extends Rule[LogicalPlan] {
/**
* Replaces [[catalyst.expressions.Expression Expressions]] that can be statically evaluated with
+ * equivalent [[catalyst.expressions.Literal Literal]] values. This rule is more specific with
+ * Null value propagation from bottom to top of the expression tree.
+ */
+object NullPropagation extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case q: LogicalPlan => q transformExpressionsUp {
+ case e @ Count(Literal(null, _)) => Literal(0, e.dataType)
+ case e @ Sum(Literal(c, _)) if c == 0 => Literal(0, e.dataType)
+ case e @ Average(Literal(c, _)) if c == 0 => Literal(0.0, e.dataType)
+ case e @ IsNull(c) if c.nullable == false => Literal(false, BooleanType)
+ case e @ IsNotNull(c) if c.nullable == false => Literal(true, BooleanType)
+ case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType)
+ case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType)
+ case e @ GetField(Literal(null, _), _) => Literal(null, e.dataType)
+ case e @ Coalesce(children) => {
+ val newChildren = children.filter(c => c match {
+ case Literal(null, _) => false
+ case _ => true
+ })
+ if (newChildren.length == 0) {
+ Literal(null, e.dataType)
+ } else if (newChildren.length == 1) {
+ newChildren(0)
+ } else {
+ Coalesce(newChildren)
+ }
+ }
+ case e @ If(Literal(v, _), trueValue, falseValue) => if (v == true) trueValue else falseValue
+ case e @ In(Literal(v, _), list) if (list.exists(c => c match {
+ case Literal(candidate, _) if candidate == v => true
+ case _ => false
+ })) => Literal(true, BooleanType)
+ case e: UnaryMinus => e.child match {
+ case Literal(null, _) => Literal(null, e.dataType)
+ case _ => e
+ }
+ case e: Cast => e.child match {
+ case Literal(null, _) => Literal(null, e.dataType)
+ case _ => e
+ }
+ case e: Not => e.child match {
+ case Literal(null, _) => Literal(null, e.dataType)
+ case _ => e
+ }
+ // Put exceptional cases above if any
+ case e: BinaryArithmetic => e.children match {
+ case Literal(null, _) :: right :: Nil => Literal(null, e.dataType)
+ case left :: Literal(null, _) :: Nil => Literal(null, e.dataType)
+ case _ => e
+ }
+ case e: BinaryComparison => e.children match {
+ case Literal(null, _) :: right :: Nil => Literal(null, e.dataType)
+ case left :: Literal(null, _) :: Nil => Literal(null, e.dataType)
+ case _ => e
+ }
+ case e: StringRegexExpression => e.children match {
+ case Literal(null, _) :: right :: Nil => Literal(null, e.dataType)
+ case left :: Literal(null, _) :: Nil => Literal(null, e.dataType)
+ case _ => e
+ }
+ }
+ }
+}
+
+/**
+ * Replaces [[catalyst.expressions.Expression Expressions]] that can be statically evaluated with
* equivalent [[catalyst.expressions.Literal Literal]] values.
*/
object ConstantFolding extends Rule[LogicalPlan] {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index d287ad73b9..91605d0a26 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -108,9 +108,7 @@ class ExpressionEvaluationSuite extends FunSuite {
truthTable.foreach {
case (l,r,answer) =>
val expr = op(Literal(l, BooleanType), Literal(r, BooleanType))
- val result = expr.eval(null)
- if (result != answer)
- fail(s"$expr should not evaluate to $result, expected: $answer")
+ checkEvaluation(expr, answer)
}
}
}
@@ -131,6 +129,7 @@ class ExpressionEvaluationSuite extends FunSuite {
test("LIKE literal Regular Expression") {
checkEvaluation(Literal(null, StringType).like("a"), null)
+ checkEvaluation(Literal("a", StringType).like(Literal(null, StringType)), null)
checkEvaluation(Literal(null, StringType).like(Literal(null, StringType)), null)
checkEvaluation("abdef" like "abdef", true)
checkEvaluation("a_%b" like "a\\__b", true)
@@ -159,9 +158,14 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation("abc" like regEx, true, new GenericRow(Array[Any]("a%")))
checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("b%")))
checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("bc%")))
+
+ checkEvaluation(Literal(null, StringType) like regEx, null, new GenericRow(Array[Any]("bc%")))
}
test("RLIKE literal Regular Expression") {
+ checkEvaluation(Literal(null, StringType) rlike "abdef", null)
+ checkEvaluation("abdef" rlike Literal(null, StringType), null)
+ checkEvaluation(Literal(null, StringType) rlike Literal(null, StringType), null)
checkEvaluation("abdef" rlike "abdef", true)
checkEvaluation("abbbbc" rlike "a.*c", true)
@@ -257,6 +261,8 @@ class ExpressionEvaluationSuite extends FunSuite {
assert(("abcdef" cast DecimalType).nullable === true)
assert(("abcdef" cast DoubleType).nullable === true)
assert(("abcdef" cast FloatType).nullable === true)
+
+ checkEvaluation(Cast(Literal(null, IntegerType), ShortType), null)
}
test("timestamp") {
@@ -287,5 +293,108 @@ class ExpressionEvaluationSuite extends FunSuite {
// A test for higher precision than millis
checkEvaluation(Cast(Cast(0.00000001, TimestampType), DoubleType), 0.00000001)
}
+
+ test("null checking") {
+ val row = new GenericRow(Array[Any]("^Ba*n", null, true, null))
+ val c1 = 'a.string.at(0)
+ val c2 = 'a.string.at(1)
+ val c3 = 'a.boolean.at(2)
+ val c4 = 'a.boolean.at(3)
+
+ checkEvaluation(IsNull(c1), false, row)
+ checkEvaluation(IsNotNull(c1), true, row)
+
+ checkEvaluation(IsNull(c2), true, row)
+ checkEvaluation(IsNotNull(c2), false, row)
+
+ checkEvaluation(IsNull(Literal(1, ShortType)), false)
+ checkEvaluation(IsNotNull(Literal(1, ShortType)), true)
+
+ checkEvaluation(IsNull(Literal(null, ShortType)), true)
+ checkEvaluation(IsNotNull(Literal(null, ShortType)), false)
+
+ checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row)
+ checkEvaluation(Coalesce(Literal(null, StringType) :: Nil), null, row)
+ checkEvaluation(Coalesce(Literal(null, StringType) :: c1 :: c2 :: Nil), "^Ba*n", row)
+
+ checkEvaluation(If(c3, Literal("a", StringType), Literal("b", StringType)), "a", row)
+ checkEvaluation(If(c3, c1, c2), "^Ba*n", row)
+ checkEvaluation(If(c4, c2, c1), "^Ba*n", row)
+ checkEvaluation(If(Literal(null, BooleanType), c2, c1), "^Ba*n", row)
+ checkEvaluation(If(Literal(true, BooleanType), c1, c2), "^Ba*n", row)
+ checkEvaluation(If(Literal(false, BooleanType), c2, c1), "^Ba*n", row)
+ checkEvaluation(If(Literal(false, BooleanType),
+ Literal("a", StringType), Literal("b", StringType)), "b", row)
+
+ checkEvaluation(In(c1, c1 :: c2 :: Nil), true, row)
+ checkEvaluation(In(Literal("^Ba*n", StringType),
+ Literal("^Ba*n", StringType) :: Nil), true, row)
+ checkEvaluation(In(Literal("^Ba*n", StringType),
+ Literal("^Ba*n", StringType) :: c2 :: Nil), true, row)
+ }
+
+ test("complex type") {
+ val row = new GenericRow(Array[Any](
+ "^Ba*n", // 0
+ null.asInstanceOf[String], // 1
+ new GenericRow(Array[Any]("aa", "bb")), // 2
+ Map("aa"->"bb"), // 3
+ Seq("aa", "bb") // 4
+ ))
+
+ val typeS = StructType(
+ StructField("a", StringType, true) :: StructField("b", StringType, true) :: Nil
+ )
+ val typeMap = MapType(StringType, StringType)
+ val typeArray = ArrayType(StringType)
+
+ checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()),
+ Literal("aa")), "bb", row)
+ checkEvaluation(GetItem(Literal(null, typeMap), Literal("aa")), null, row)
+ checkEvaluation(GetItem(Literal(null, typeMap), Literal(null, StringType)), null, row)
+ checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()),
+ Literal(null, StringType)), null, row)
+
+ checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()),
+ Literal(1)), "bb", row)
+ checkEvaluation(GetItem(Literal(null, typeArray), Literal(1)), null, row)
+ checkEvaluation(GetItem(Literal(null, typeArray), Literal(null, IntegerType)), null, row)
+ checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()),
+ Literal(null, IntegerType)), null, row)
+
+ checkEvaluation(GetField(BoundReference(2, AttributeReference("c", typeS)()), "a"), "aa", row)
+ checkEvaluation(GetField(Literal(null, typeS), "a"), null, row)
+ }
+
+ test("arithmetic") {
+ val row = new GenericRow(Array[Any](1, 2, 3, null))
+ val c1 = 'a.int.at(0)
+ val c2 = 'a.int.at(1)
+ val c3 = 'a.int.at(2)
+ val c4 = 'a.int.at(3)
+
+ checkEvaluation(UnaryMinus(c1), -1, row)
+ checkEvaluation(UnaryMinus(Literal(100, IntegerType)), -100)
+
+ checkEvaluation(Add(c1, c4), null, row)
+ checkEvaluation(Add(c1, c2), 3, row)
+ checkEvaluation(Add(c1, Literal(null, IntegerType)), null, row)
+ checkEvaluation(Add(Literal(null, IntegerType), c2), null, row)
+ checkEvaluation(Add(Literal(null, IntegerType), Literal(null, IntegerType)), null, row)
+ }
+
+ test("BinaryComparison") {
+ val row = new GenericRow(Array[Any](1, 2, 3, null))
+ val c1 = 'a.int.at(0)
+ val c2 = 'a.int.at(1)
+ val c3 = 'a.int.at(2)
+ val c4 = 'a.int.at(3)
+
+ checkEvaluation(LessThan(c1, c4), null, row)
+ checkEvaluation(LessThan(c1, c2), true, row)
+ checkEvaluation(LessThan(c1, Literal(null, IntegerType)), null, row)
+ checkEvaluation(LessThan(Literal(null, IntegerType), c2), null, row)
+ checkEvaluation(LessThan(Literal(null, IntegerType), Literal(null, IntegerType)), null, row)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala
new file mode 100644
index 0000000000..890d6289b9
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+
+/**
+ * Overrides our expression evaluation tests and reruns them after optimization has occured. This
+ * is to ensure that constant folding and other optimizations do not break anything.
+ */
+class ExpressionOptimizationSuite extends ExpressionEvaluationSuite {
+ override def checkEvaluation(
+ expression: Expression,
+ expected: Any,
+ inputRow: Row = EmptyRow): Unit = {
+ val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, NoRelation)
+ val optimizedPlan = Optimizer(plan)
+ super.checkEvaluation(optimizedPlan.expressions.head, expected, inputRow)
+ }
+} \ No newline at end of file