aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2015-06-11 12:57:33 -0700
committerReynold Xin <rxin@databricks.com>2015-06-11 12:57:33 -0700
commit1191c3efc605d9c6d1df4b38ddae8d210a361b5b (patch)
tree82aa65b93da8666af66ef36f0909402c12a842f3
parent424b0075a1a31c251451c6a75c6ba8e81c39453d (diff)
downloadspark-1191c3efc605d9c6d1df4b38ddae8d210a361b5b.tar.gz
spark-1191c3efc605d9c6d1df4b38ddae8d210a361b5b.tar.bz2
spark-1191c3efc605d9c6d1df4b38ddae8d210a361b5b.zip
[SPARK-8305] [SPARK-8190] [SQL] improve codegen
This PR fixes a few small issues in codegen: 1. cast decimal to boolean 2. do not inline literal with null 3. improve SpecificRow.equals() 4. test expressions with optimized expressions 5. fix compare with BinaryType cc rxin chenghao-intel Author: Davies Liu <davies@databricks.com> Closes #6755 from davies/fix_codegen and squashes the following commits: ef27343 [Davies Liu] address comments 6617ea6 [Davies Liu] fix scala tyle 70b7dda [Davies Liu] improve codegen
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java21
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala34
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala39
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala45
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala3
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala20
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala8
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala7
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala37
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala12
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala37
14 files changed, 141 insertions, 129 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java
index 6584882a62..e91daf17f8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java
@@ -154,6 +154,27 @@ public abstract class BaseRow implements Row {
throw new UnsupportedOperationException();
}
+ /**
+ * A generic version of Row.equals(Row), which is used for tests.
+ */
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof Row) {
+ Row row = (Row) other;
+ int n = size();
+ if (n != row.size()) {
+ return false;
+ }
+ for (int i = 0; i < n; i ++) {
+ if (isNullAt(i) != row.isNullAt(i) || (!isNullAt(i) && !get(i).equals(row.get(i)))) {
+ return false;
+ }
+ }
+ return true;
+ }
+ return false;
+ }
+
@Override
public Row copy() {
final int n = size();
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 8d93957fea..037efd7558 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -141,7 +141,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case ByteType =>
buildCast[Byte](_, _ != 0)
case DecimalType() =>
- buildCast[Decimal](_, _ != 0)
+ buildCast[Decimal](_, _ != Decimal(0))
case DoubleType =>
buildCast[Double](_, _ != 0)
case FloatType =>
@@ -454,7 +454,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case (BooleanType, dt: NumericType) =>
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c ? 1 : 0)")
case (dt: DecimalType, BooleanType) =>
- defineCodeGen(ctx, ev, c => s"$c.isZero()")
+ defineCodeGen(ctx, ev, c => s"!$c.isZero()")
case (dt: NumericType, BooleanType) =>
defineCodeGen(ctx, ev, c => s"$c != 0")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 80aa8fa056..ecf8e0d1a7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -161,15 +161,23 @@ class CodeGenContext {
}
/**
- * Returns a function to generate equal expression in Java
+ * Generate code for equal expression in Java
*/
- def equalFunc(dataType: DataType): ((String, String) => String) = dataType match {
- case BinaryType => { case (eval1, eval2) =>
- s"java.util.Arrays.equals($eval1, $eval2)" }
- case IntegerType | BooleanType | LongType | DoubleType | FloatType | ShortType | ByteType =>
- { case (eval1, eval2) => s"$eval1 == $eval2" }
- case other =>
- { case (eval1, eval2) => s"$eval1.equals($eval2)" }
+ def genEqual(dataType: DataType, c1: String, c2: String): String = dataType match {
+ case BinaryType => s"java.util.Arrays.equals($c1, $c2)"
+ case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2"
+ case other => s"$c1.equals($c2)"
+ }
+
+ /**
+ * Generate code for compare expression in Java
+ */
+ def genComp(dataType: DataType, c1: String, c2: String): String = dataType match {
+ // Use signum() to keep any small difference between float/double
+ case FloatType | DoubleType => s"(int)java.lang.Math.signum($c1 - $c2)"
+ case dt: DataType if isPrimitiveType(dt) => s"(int)($c1 - $c2)"
+ case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
+ case other => s"$c1.compare($c2)"
}
/**
@@ -182,6 +190,16 @@ class CodeGenContext {
* Returns true if the data type has a special accessor and setter in [[Row]].
*/
def isNativeType(dt: DataType): Boolean = nativeTypes.contains(dt)
+
+ /**
+ * List of data types whose Java type is a primitive type
+ */
+ val primitiveTypes = nativeTypes ++ Seq(DateType, TimestampType)
+
+ /**
+ * Returns true if the Java type of the given data type is a primitive type
+ */
+ def isPrimitiveType(dt: DataType): Boolean = primitiveTypes.contains(dt)
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index e5ee2accd8..ed3df547d1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -82,7 +82,6 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
}
"""
-
logDebug(s"code for ${expressions.mkString(",")}:\n$code")
val c = compile(code)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
index 36e155d164..56ecc5fc06 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -21,7 +21,6 @@ import org.apache.spark.Logging
import org.apache.spark.annotation.Private
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{BinaryType, NumericType}
/**
* Inherits some default implementation for Java from `Ordering[Row]`
@@ -55,39 +54,6 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit
val evalA = order.child.gen(ctx)
val evalB = order.child.gen(ctx)
val asc = order.direction == Ascending
- val compare = order.child.dataType match {
- case BinaryType =>
- s"""
- {
- byte[] x = ${if (asc) evalA.primitive else evalB.primitive};
- byte[] y = ${if (!asc) evalB.primitive else evalA.primitive};
- int j = 0;
- while (j < x.length && j < y.length) {
- if (x[j] != y[j]) return x[j] - y[j];
- j = j + 1;
- }
- int d = x.length - y.length;
- if (d != 0) {
- return d;
- }
- }"""
- case _: NumericType =>
- s"""
- if (${evalA.primitive} != ${evalB.primitive}) {
- if (${evalA.primitive} > ${evalB.primitive}) {
- return ${if (asc) "1" else "-1"};
- } else {
- return ${if (asc) "-1" else "1"};
- }
- }"""
- case _ =>
- s"""
- int comp = ${evalA.primitive}.compare(${evalB.primitive});
- if (comp != 0) {
- return ${if (asc) "comp" else "-comp"};
- }"""
- }
-
s"""
i = $a;
${evalA.code}
@@ -100,7 +66,10 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit
} else if (${evalB.isNull}) {
return ${if (order.direction == Ascending) "1" else "-1"};
} else {
- $compare
+ int comp = ${ctx.genComp(order.child.dataType, evalA.primitive, evalB.primitive)};
+ if (comp != 0) {
+ return ${if (asc) "comp" else "-comp"};
+ }
}
"""
}.mkString("\n")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index 274429cd1c..9b906c3ff5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -72,14 +72,12 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
}.mkString("\n ")
val specificAccessorFunctions = ctx.nativeTypes.map { dataType =>
- val cases = expressions.zipWithIndex.map {
- case (e, i) if e.dataType == dataType
- || dataType == IntegerType && e.dataType == DateType
- || dataType == LongType && e.dataType == TimestampType =>
- s"case $i: return c$i;"
- case _ => ""
+ val cases = expressions.zipWithIndex.flatMap {
+ case (e, i) if ctx.javaType(e.dataType) == ctx.javaType(dataType) =>
+ List(s"case $i: return c$i;")
+ case _ => Nil
}.mkString("\n ")
- if (cases.count(_ != '\n') > 0) {
+ if (cases.length > 0) {
s"""
@Override
public ${ctx.javaType(dataType)} ${ctx.accessorForType(dataType)}(int i) {
@@ -89,7 +87,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
switch (i) {
$cases
}
- return ${ctx.defaultValue(dataType)};
+ throw new IllegalArgumentException("Invalid index: " + i
+ + " in ${ctx.accessorForType(dataType)}");
}"""
} else {
""
@@ -97,14 +96,12 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
}.mkString("\n")
val specificMutatorFunctions = ctx.nativeTypes.map { dataType =>
- val cases = expressions.zipWithIndex.map {
- case (e, i) if e.dataType == dataType
- || dataType == IntegerType && e.dataType == DateType
- || dataType == LongType && e.dataType == TimestampType =>
- s"case $i: { c$i = value; return; }"
- case _ => ""
- }.mkString("\n")
- if (cases.count(_ != '\n') > 0) {
+ val cases = expressions.zipWithIndex.flatMap {
+ case (e, i) if ctx.javaType(e.dataType) == ctx.javaType(dataType) =>
+ List(s"case $i: { c$i = value; return; }")
+ case _ => Nil
+ }.mkString("\n ")
+ if (cases.length > 0) {
s"""
@Override
public void ${ctx.mutatorForType(dataType)}(int i, ${ctx.javaType(dataType)} value) {
@@ -112,6 +109,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
switch (i) {
$cases
}
+ throw new IllegalArgumentException("Invalid index: " + i +
+ " in ${ctx.mutatorForType(dataType)}");
}"""
} else {
""
@@ -139,9 +138,10 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
val columnChecks = expressions.zipWithIndex.map { case (e, i) =>
s"""
- if (isNullAt($i) != row.isNullAt($i) || !isNullAt($i) && !get($i).equals(row.get($i))) {
- return false;
- }
+ if (nullBits[$i] != row.nullBits[$i] ||
+ (!nullBits[$i] && !(${ctx.genEqual(e.dataType, s"c$i", s"row.c$i")}))) {
+ return false;
+ }
"""
}.mkString("\n")
@@ -174,7 +174,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
}
public int size() { return ${expressions.length};}
- private boolean[] nullBits = new boolean[${expressions.length}];
+ protected boolean[] nullBits = new boolean[${expressions.length}];
public void setNullAt(int i) { nullBits[i] = true; }
public boolean isNullAt(int i) { return nullBits[i]; }
@@ -207,9 +207,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
@Override
public boolean equals(Object other) {
- if (other instanceof Row) {
- Row row = (Row) other;
- if (row.length() != size()) return false;
+ if (other instanceof SpecificRow) {
+ SpecificRow row = (SpecificRow) other;
$columnChecks
return true;
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
index 1a5cde26c9..72b9f23456 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
@@ -261,7 +261,7 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
${cond.code}
if (${keyEval.isNull} && ${cond.isNull} ||
!${keyEval.isNull} && !${cond.isNull}
- && ${ctx.equalFunc(key.dataType)(keyEval.primitive, cond.primitive)}) {
+ && ${ctx.genEqual(key.dataType, keyEval.primitive, cond.primitive)}) {
$got = true;
${res.code}
${ev.isNull} = ${res.isNull};
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 833c08a293..ef50c50e13 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -92,8 +92,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres
// change the isNull and primitive to consts, to inline them
if (value == null) {
ev.isNull = "true"
- ev.primitive = ctx.defaultValue(dataType)
- ""
+ s"final ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};"
} else {
dataType match {
case BooleanType =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 2c49352874..7574d1cbda 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -250,16 +250,11 @@ abstract class BinaryComparison extends BinaryExpression with Predicate {
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- left.dataType match {
- case dt: NumericType if ctx.isNativeType(dt) => defineCodeGen (ctx, ev, {
- (c1, c3) => s"$c1 $symbol $c3"
- })
- case DateType | TimestampType => defineCodeGen (ctx, ev, {
- (c1, c3) => s"$c1 $symbol $c3"
- })
- case other => defineCodeGen (ctx, ev, {
- (c1, c2) => s"$c1.compare($c2) $symbol 0"
- })
+ if (ctx.isPrimitiveType(left.dataType)) {
+ // faster version
+ defineCodeGen(ctx, ev, (c1, c2) => s"$c1 $symbol $c2")
+ } else {
+ defineCodeGen(ctx, ev, (c1, c2) => s"${ctx.genComp(left.dataType, c1, c2)} $symbol 0")
}
}
@@ -280,8 +275,9 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
if (left.dataType != BinaryType) l == r
else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]])
}
+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- defineCodeGen(ctx, ev, ctx.equalFunc(left.dataType))
+ defineCodeGen(ctx, ev, (c1, c2) => ctx.genEqual(left.dataType, c1, c2))
}
}
@@ -307,7 +303,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
- val equalCode = ctx.equalFunc(left.dataType)(eval1.primitive, eval2.primitive)
+ val equalCode = ctx.genEqual(left.dataType, eval1.primitive, eval2.primitive)
ev.isNull = "false"
eval1.code + eval2.code + s"""
boolean ${ev.primitive} = (${eval1.isNull} && ${eval2.isNull}) ||
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index 0bb12d2039..04857a23f4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -53,4 +53,12 @@ object TypeUtils {
def getOrdering(t: DataType): Ordering[Any] =
t.asInstanceOf[AtomicType].ordering.asInstanceOf[Ordering[Any]]
+
+ def compareBinary(x: Array[Byte], y: Array[Byte]): Int = {
+ for (i <- 0 until x.length; if i < y.length) {
+ val res = x(i).compareTo(y(i))
+ if (res != 0) return res
+ }
+ x.length - y.length
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala
index a581a9e946..9b58601e5e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala
@@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.typeTag
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.ScalaReflectionLock
+import org.apache.spark.sql.catalyst.util.TypeUtils
/**
@@ -43,11 +44,7 @@ class BinaryType private() extends AtomicType {
private[sql] val ordering = new Ordering[InternalType] {
def compare(x: Array[Byte], y: Array[Byte]): Int = {
- for (i <- 0 until x.length; if i < y.length) {
- val res = x(i).compareTo(y(i))
- if (res != 0) return res
- }
- x.length - y.length
+ TypeUtils.compareBinary(x, y)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index 3aca94db3b..969c6cc15f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -43,7 +43,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
test("cast from int") {
checkCast(0, false)
checkCast(1, true)
- checkCast(5, true)
+ checkCast(-5, true)
checkCast(1, 1.toByte)
checkCast(1, 1.toShort)
checkCast(1, 1)
@@ -61,7 +61,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
test("cast from long") {
checkCast(0L, false)
checkCast(1L, true)
- checkCast(5L, true)
+ checkCast(-5L, true)
checkCast(1L, 1.toByte)
checkCast(1L, 1.toShort)
checkCast(1L, 1)
@@ -99,10 +99,28 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
}
test("cast from float") {
-
+ checkCast(0.0f, false)
+ checkCast(0.5f, true)
+ checkCast(-5.0f, true)
+ checkCast(1.5f, 1.toByte)
+ checkCast(1.5f, 1.toShort)
+ checkCast(1.5f, 1)
+ checkCast(1.5f, 1.toLong)
+ checkCast(1.5f, 1.5)
+ checkCast(1.5f, "1.5")
}
test("cast from double") {
+ checkCast(0.0, false)
+ checkCast(0.5, true)
+ checkCast(-5.0, true)
+ checkCast(1.5, 1.toByte)
+ checkCast(1.5, 1.toShort)
+ checkCast(1.5, 1)
+ checkCast(1.5, 1.toLong)
+ checkCast(1.5, 1.5f)
+ checkCast(1.5, "1.5")
+
checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble)
checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble)
}
@@ -183,6 +201,19 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Add(Literal(23.toShort), cast(true, ShortType)), 24.toShort)
}
+ test("from decimal") {
+ checkCast(Decimal(0.0), false)
+ checkCast(Decimal(0.5), true)
+ checkCast(Decimal(-5.0), true)
+ checkCast(Decimal(1.5), 1.toByte)
+ checkCast(Decimal(1.5), 1.toShort)
+ checkCast(Decimal(1.5), 1)
+ checkCast(Decimal(1.5), 1.toLong)
+ checkCast(Decimal(1.5), 1.5f)
+ checkCast(Decimal(1.5), 1.5)
+ checkCast(Decimal(1.5), "1.5")
+ }
+
test("casting to fixed-precision decimals") {
// Overflow and rounding for casting to fixed-precision decimals:
// - Values should round with HALF_UP mode by default when you lower scale
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 87a92b8796..4a241d3603 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -23,6 +23,8 @@ import org.scalatest.Matchers._
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection}
+import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer
+import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
/**
* A few helper functions for expression evaluation testing. Mixin this trait to use them.
@@ -39,6 +41,7 @@ trait ExpressionEvalHelper {
checkEvaluationWithoutCodegen(expression, expected, inputRow)
checkEvaluationWithGeneratedMutableProjection(expression, expected, inputRow)
checkEvaluationWithGeneratedProjection(expression, expected, inputRow)
+ checkEvaluationWithOptimization(expression, expected, inputRow)
}
protected def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = {
@@ -122,6 +125,15 @@ trait ExpressionEvalHelper {
}
}
+ protected def checkEvaluationWithOptimization(
+ expression: Expression,
+ expected: Any,
+ inputRow: Row = EmptyRow): Unit = {
+ val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation)
+ val optimizedPlan = DefaultOptimizer.execute(plan)
+ checkEvaluationWithoutCodegen(optimizedPlan.expressions.head, expected, inputRow)
+ }
+
protected def checkDoubleEvaluation(
expression: Expression,
expected: Spread[Double],
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala
deleted file mode 100644
index f33a18d53b..0000000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala
+++ /dev/null
@@ -1,37 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.optimizer
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical._
-
-/**
- * Overrides our expression evaluation tests and reruns them after optimization has occured. This
- * is to ensure that constant folding and other optimizations do not break anything.
- */
-class ExpressionOptimizationSuite extends SparkFunSuite with ExpressionEvalHelper {
- override def checkEvaluation(
- expression: Expression,
- expected: Any,
- inputRow: Row = EmptyRow): Unit = {
- val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation)
- val optimizedPlan = DefaultOptimizer.execute(plan)
- super.checkEvaluation(optimizedPlan.expressions.head, expected, inputRow)
- }
-}