diff options
14 files changed, 141 insertions, 129 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java index 6584882a62..e91daf17f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java @@ -154,6 +154,27 @@ public abstract class BaseRow implements Row { throw new UnsupportedOperationException(); } + /** + * A generic version of Row.equals(Row), which is used for tests. + */ + @Override + public boolean equals(Object other) { + if (other instanceof Row) { + Row row = (Row) other; + int n = size(); + if (n != row.size()) { + return false; + } + for (int i = 0; i < n; i ++) { + if (isNullAt(i) != row.isNullAt(i) || (!isNullAt(i) && !get(i).equals(row.get(i)))) { + return false; + } + } + return true; + } + return false; + } + @Override public Row copy() { final int n = size(); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 8d93957fea..037efd7558 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -141,7 +141,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case ByteType => buildCast[Byte](_, _ != 0) case DecimalType() => - buildCast[Decimal](_, _ != 0) + buildCast[Decimal](_, _ != Decimal(0)) case DoubleType => buildCast[Double](_, _ != 0) case FloatType => @@ -454,7 +454,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case (BooleanType, dt: NumericType) => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c ? 1 : 0)") case (dt: DecimalType, BooleanType) => - defineCodeGen(ctx, ev, c => s"$c.isZero()") + defineCodeGen(ctx, ev, c => s"!$c.isZero()") case (dt: NumericType, BooleanType) => defineCodeGen(ctx, ev, c => s"$c != 0") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 80aa8fa056..ecf8e0d1a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -161,15 +161,23 @@ class CodeGenContext { } /** - * Returns a function to generate equal expression in Java + * Generate code for equal expression in Java */ - def equalFunc(dataType: DataType): ((String, String) => String) = dataType match { - case BinaryType => { case (eval1, eval2) => - s"java.util.Arrays.equals($eval1, $eval2)" } - case IntegerType | BooleanType | LongType | DoubleType | FloatType | ShortType | ByteType => - { case (eval1, eval2) => s"$eval1 == $eval2" } - case other => - { case (eval1, eval2) => s"$eval1.equals($eval2)" } + def genEqual(dataType: DataType, c1: String, c2: String): String = dataType match { + case BinaryType => s"java.util.Arrays.equals($c1, $c2)" + case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2" + case other => s"$c1.equals($c2)" + } + + /** + * Generate code for compare expression in Java + */ + def genComp(dataType: DataType, c1: String, c2: String): String = dataType match { + // Use signum() to keep any small difference bwteen float/double + case FloatType | DoubleType => s"(int)java.lang.Math.signum($c1 - $c2)" + case dt: DataType if isPrimitiveType(dt) => s"(int)($c1 - $c2)" + case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)" + case other => s"$c1.compare($c2)" } /** @@ -182,6 +190,16 @@ class CodeGenContext { * Returns true if the data type has a special accessor and setter in [[Row]]. */ def isNativeType(dt: DataType): Boolean = nativeTypes.contains(dt) + + /** + * List of data types who's Java type is primitive type + */ + val primitiveTypes = nativeTypes ++ Seq(DateType, TimestampType) + + /** + * Returns true if the Java type is primitive type + */ + def isPrimitiveType(dt: DataType): Boolean = primitiveTypes.contains(dt) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index e5ee2accd8..ed3df547d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -82,7 +82,6 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu } """ - logDebug(s"code for ${expressions.mkString(",")}:\n$code") val c = compile(code) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 36e155d164..56ecc5fc06 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -21,7 +21,6 @@ import org.apache.spark.Logging import org.apache.spark.annotation.Private import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{BinaryType, NumericType} /** * Inherits some default implementation for Java from `Ordering[Row]` @@ -55,39 +54,6 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit val evalA = order.child.gen(ctx) val evalB = order.child.gen(ctx) val asc = order.direction == Ascending - val compare = order.child.dataType match { - case BinaryType => - s""" - { - byte[] x = ${if (asc) evalA.primitive else evalB.primitive}; - byte[] y = ${if (!asc) evalB.primitive else evalA.primitive}; - int j = 0; - while (j < x.length && j < y.length) { - if (x[j] != y[j]) return x[j] - y[j]; - j = j + 1; - } - int d = x.length - y.length; - if (d != 0) { - return d; - } - }""" - case _: NumericType => - s""" - if (${evalA.primitive} != ${evalB.primitive}) { - if (${evalA.primitive} > ${evalB.primitive}) { - return ${if (asc) "1" else "-1"}; - } else { - return ${if (asc) "-1" else "1"}; - } - }""" - case _ => - s""" - int comp = ${evalA.primitive}.compare(${evalB.primitive}); - if (comp != 0) { - return ${if (asc) "comp" else "-comp"}; - }""" - } - s""" i = $a; ${evalA.code} @@ -100,7 +66,10 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit } else if (${evalB.isNull}) { return ${if (order.direction == Ascending) "1" else "-1"}; } else { - $compare + int comp = ${ctx.genComp(order.child.dataType, evalA.primitive, evalB.primitive)}; + if (comp != 0) { + return ${if (asc) "comp" else "-comp"}; + } } """ }.mkString("\n") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 274429cd1c..9b906c3ff5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -72,14 +72,12 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { }.mkString("\n ") val specificAccessorFunctions = ctx.nativeTypes.map { dataType => - val cases = expressions.zipWithIndex.map { - case (e, i) if e.dataType == dataType - || dataType == IntegerType && e.dataType == DateType - || dataType == LongType && e.dataType == TimestampType => - s"case $i: return c$i;" - case _ => "" + val cases = expressions.zipWithIndex.flatMap { + case (e, i) if ctx.javaType(e.dataType) == ctx.javaType(dataType) => + List(s"case $i: return c$i;") + case _ => Nil }.mkString("\n ") - if (cases.count(_ != '\n') > 0) { + if (cases.length > 0) { s""" @Override public ${ctx.javaType(dataType)} ${ctx.accessorForType(dataType)}(int i) { @@ -89,7 +87,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { switch (i) { $cases } - return ${ctx.defaultValue(dataType)}; + throw new IllegalArgumentException("Invalid index: " + i + + " in ${ctx.accessorForType(dataType)}"); }""" } else { "" @@ -97,14 +96,12 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { }.mkString("\n") val specificMutatorFunctions = ctx.nativeTypes.map { dataType => - val cases = expressions.zipWithIndex.map { - case (e, i) if e.dataType == dataType - || dataType == IntegerType && e.dataType == DateType - || dataType == LongType && e.dataType == TimestampType => - s"case $i: { c$i = value; return; }" - case _ => "" - }.mkString("\n") - if (cases.count(_ != '\n') > 0) { + val cases = expressions.zipWithIndex.flatMap { + case (e, i) if ctx.javaType(e.dataType) == ctx.javaType(dataType) => + List(s"case $i: { c$i = value; return; }") + case _ => Nil + }.mkString("\n ") + if (cases.length > 0) { s""" @Override public void ${ctx.mutatorForType(dataType)}(int i, ${ctx.javaType(dataType)} value) { @@ -112,6 +109,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { switch (i) { $cases } + throw new IllegalArgumentException("Invalid index: " + i + + " in ${ctx.mutatorForType(dataType)}"); }""" } else { "" @@ -139,9 +138,10 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val columnChecks = expressions.zipWithIndex.map { case (e, i) => s""" - if (isNullAt($i) != row.isNullAt($i) || !isNullAt($i) && !get($i).equals(row.get($i))) { - return false; - } + if (nullBits[$i] != row.nullBits[$i] || + (!nullBits[$i] && !(${ctx.genEqual(e.dataType, s"c$i", s"row.c$i")}))) { + return false; + } """ }.mkString("\n") @@ -174,7 +174,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { } public int size() { return ${expressions.length};} - private boolean[] nullBits = new boolean[${expressions.length}]; + protected boolean[] nullBits = new boolean[${expressions.length}]; public void setNullAt(int i) { nullBits[i] = true; } public boolean isNullAt(int i) { return nullBits[i]; } @@ -207,9 +207,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { @Override public boolean equals(Object other) { - if (other instanceof Row) { - Row row = (Row) other; - if (row.length() != size()) return false; + if (other instanceof SpecificRow) { + SpecificRow row = (SpecificRow) other; $columnChecks return true; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala index 1a5cde26c9..72b9f23456 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala @@ -261,7 +261,7 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW ${cond.code} if (${keyEval.isNull} && ${cond.isNull} || !${keyEval.isNull} && !${cond.isNull} - && ${ctx.equalFunc(key.dataType)(keyEval.primitive, cond.primitive)}) { + && ${ctx.genEqual(key.dataType, keyEval.primitive, cond.primitive)}) { $got = true; ${res.code} ${ev.isNull} = ${res.isNull}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 833c08a293..ef50c50e13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -92,8 +92,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres // change the isNull and primitive to consts, to inline them if (value == null) { ev.isNull = "true" - ev.primitive = ctx.defaultValue(dataType) - "" + s"final ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};" } else { dataType match { case BooleanType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 2c49352874..7574d1cbda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -250,16 +250,11 @@ abstract class BinaryComparison extends BinaryExpression with Predicate { } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - left.dataType match { - case dt: NumericType if ctx.isNativeType(dt) => defineCodeGen (ctx, ev, { - (c1, c3) => s"$c1 $symbol $c3" - }) - case DateType | TimestampType => defineCodeGen (ctx, ev, { - (c1, c3) => s"$c1 $symbol $c3" - }) - case other => defineCodeGen (ctx, ev, { - (c1, c2) => s"$c1.compare($c2) $symbol 0" - }) + if (ctx.isPrimitiveType(left.dataType)) { + // faster version + defineCodeGen(ctx, ev, (c1, c2) => s"$c1 $symbol $c2") + } else { + defineCodeGen(ctx, ev, (c1, c2) => s"${ctx.genComp(left.dataType, c1, c2)} $symbol 0") } } @@ -280,8 +275,9 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison if (left.dataType != BinaryType) l == r else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, ctx.equalFunc(left.dataType)) + defineCodeGen(ctx, ev, (c1, c2) => ctx.genEqual(left.dataType, c1, c2)) } } @@ -307,7 +303,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) - val equalCode = ctx.equalFunc(left.dataType)(eval1.primitive, eval2.primitive) + val equalCode = ctx.genEqual(left.dataType, eval1.primitive, eval2.primitive) ev.isNull = "false" eval1.code + eval2.code + s""" boolean ${ev.primitive} = (${eval1.isNull} && ${eval2.isNull}) || diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 0bb12d2039..04857a23f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -53,4 +53,12 @@ object TypeUtils { def getOrdering(t: DataType): Ordering[Any] = t.asInstanceOf[AtomicType].ordering.asInstanceOf[Ordering[Any]] + + def compareBinary(x: Array[Byte], y: Array[Byte]): Int = { + for (i <- 0 until x.length; if i < y.length) { + val res = x(i).compareTo(y(i)) + if (res != 0) return res + } + x.length - y.length + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala index a581a9e946..9b58601e5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala @@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.ScalaReflectionLock +import org.apache.spark.sql.catalyst.util.TypeUtils /** @@ -43,11 +44,7 @@ class BinaryType private() extends AtomicType { private[sql] val ordering = new Ordering[InternalType] { def compare(x: Array[Byte], y: Array[Byte]): Int = { - for (i <- 0 until x.length; if i < y.length) { - val res = x(i).compareTo(y(i)) - if (res != 0) return res - } - x.length - y.length + TypeUtils.compareBinary(x, y) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 3aca94db3b..969c6cc15f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -43,7 +43,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("cast from int") { checkCast(0, false) checkCast(1, true) - checkCast(5, true) + checkCast(-5, true) checkCast(1, 1.toByte) checkCast(1, 1.toShort) checkCast(1, 1) @@ -61,7 +61,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("cast from long") { checkCast(0L, false) checkCast(1L, true) - checkCast(5L, true) + checkCast(-5L, true) checkCast(1L, 1.toByte) checkCast(1L, 1.toShort) checkCast(1L, 1) @@ -99,10 +99,28 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } test("cast from float") { - + checkCast(0.0f, false) + checkCast(0.5f, true) + checkCast(-5.0f, true) + checkCast(1.5f, 1.toByte) + checkCast(1.5f, 1.toShort) + checkCast(1.5f, 1) + checkCast(1.5f, 1.toLong) + checkCast(1.5f, 1.5) + checkCast(1.5f, "1.5") } test("cast from double") { + checkCast(0.0, false) + checkCast(0.5, true) + checkCast(-5.0, true) + checkCast(1.5, 1.toByte) + checkCast(1.5, 1.toShort) + checkCast(1.5, 1) + checkCast(1.5, 1.toLong) + checkCast(1.5, 1.5f) + checkCast(1.5, "1.5") + checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble) checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble) } @@ -183,6 +201,19 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Add(Literal(23.toShort), cast(true, ShortType)), 24.toShort) } + test("from decimal") { + checkCast(Decimal(0.0), false) + checkCast(Decimal(0.5), true) + checkCast(Decimal(-5.0), true) + checkCast(Decimal(1.5), 1.toByte) + checkCast(Decimal(1.5), 1.toShort) + checkCast(Decimal(1.5), 1) + checkCast(Decimal(1.5), 1.toLong) + checkCast(Decimal(1.5), 1.5f) + checkCast(Decimal(1.5), 1.5) + checkCast(Decimal(1.5), "1.5") + } + test("casting to fixed-precision decimals") { // Overflow and rounding for casting to fixed-precision decimals: // - Values should round with HALF_UP mode by default when you lower scale diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 87a92b8796..4a241d3603 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -23,6 +23,8 @@ import org.scalatest.Matchers._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection} +import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer +import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} /** * A few helper functions for expression evaluation testing. Mixin this trait to use them. @@ -39,6 +41,7 @@ trait ExpressionEvalHelper { checkEvaluationWithoutCodegen(expression, expected, inputRow) checkEvaluationWithGeneratedMutableProjection(expression, expected, inputRow) checkEvaluationWithGeneratedProjection(expression, expected, inputRow) + checkEvaluationWithOptimization(expression, expected, inputRow) } protected def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = { @@ -122,6 +125,15 @@ trait ExpressionEvalHelper { } } + protected def checkEvaluationWithOptimization( + expression: Expression, + expected: Any, + inputRow: Row = EmptyRow): Unit = { + val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) + val optimizedPlan = DefaultOptimizer.execute(plan) + checkEvaluationWithoutCodegen(optimizedPlan.expressions.head, expected, inputRow) + } + protected def checkDoubleEvaluation( expression: Expression, expected: Spread[Double], diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala deleted file mode 100644 index f33a18d53b..0000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.optimizer - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ - -/** - * Overrides our expression evaluation tests and reruns them after optimization has occured. This - * is to ensure that constant folding and other optimizations do not break anything. - */ -class ExpressionOptimizationSuite extends SparkFunSuite with ExpressionEvalHelper { - override def checkEvaluation( - expression: Expression, - expected: Any, - inputRow: Row = EmptyRow): Unit = { - val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) - val optimizedPlan = DefaultOptimizer.execute(plan) - super.checkEvaluation(optimizedPlan.expressions.head, expected, inputRow) - } -} |