aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorWenchen Fan <cloud0fan@outlook.com>2015-06-23 23:11:42 -0700
committerDavies Liu <davies@databricks.com>2015-06-23 23:11:42 -0700
commit09fcf96b8f881988a4bc7fe26a3f6ed12dfb6adb (patch)
tree3b7a32a150313ae7c5b7c0860541de2e614132f2 /sql
parent13ae806b255cfb2bd5470b599a95c87a2cd5e978 (diff)
downloadspark-09fcf96b8f881988a4bc7fe26a3f6ed12dfb6adb.tar.gz
spark-09fcf96b8f881988a4bc7fe26a3f6ed12dfb6adb.tar.bz2
spark-09fcf96b8f881988a4bc7fe26a3f6ed12dfb6adb.zip
[SPARK-8371] [SQL] improve unit test for MaxOf and MinOf and fix bugs
a follow up of https://github.com/apache/spark/pull/6813 Author: Wenchen Fan <cloud0fan@outlook.com> Closes #6825 from cloud-fan/cg and squashes the following commits: 43170cc [Wenchen Fan] fix bugs in code gen
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala4
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala46
2 files changed, 34 insertions, 16 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index bd5475d206..47c5455435 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -175,8 +175,10 @@ class CodeGenContext {
* Generate code for compare expression in Java
*/
def genComp(dataType: DataType, c1: String, c2: String): String = dataType match {
+ // java boolean doesn't support > or < operator
+ case BooleanType => s"($c1 == $c2 ? 0 : ($c1 ? 1 : -1))"
// use c1 - c2 may overflow
- case dt: DataType if isPrimitiveType(dt) => s"(int)($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
+ case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
case other => s"$c1.compare($c2)"
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 4bbbbe6c7f..6c93698f80 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType}
+import org.apache.spark.sql.types.Decimal
class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -123,23 +123,39 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
}
}
- test("MaxOf") {
- checkEvaluation(MaxOf(1, 2), 2)
- checkEvaluation(MaxOf(2, 1), 2)
- checkEvaluation(MaxOf(1L, 2L), 2L)
- checkEvaluation(MaxOf(2L, 1L), 2L)
+ test("MaxOf basic") {
+ testNumericDataTypes { convert =>
+ val small = Literal(convert(1))
+ val large = Literal(convert(2))
+ checkEvaluation(MaxOf(small, large), convert(2))
+ checkEvaluation(MaxOf(large, small), convert(2))
+ checkEvaluation(MaxOf(Literal.create(null, small.dataType), large), convert(2))
+ checkEvaluation(MaxOf(large, Literal.create(null, small.dataType)), convert(2))
+ }
+ }
- checkEvaluation(MaxOf(Literal.create(null, IntegerType), 2), 2)
- checkEvaluation(MaxOf(2, Literal.create(null, IntegerType)), 2)
+ test("MaxOf for atomic type") {
+ checkEvaluation(MaxOf(true, false), true)
+ checkEvaluation(MaxOf("abc", "bcd"), "bcd")
+ checkEvaluation(MaxOf(Array(1.toByte, 2.toByte), Array(1.toByte, 3.toByte)),
+ Array(1.toByte, 3.toByte))
}
- test("MinOf") {
- checkEvaluation(MinOf(1, 2), 1)
- checkEvaluation(MinOf(2, 1), 1)
- checkEvaluation(MinOf(1L, 2L), 1L)
- checkEvaluation(MinOf(2L, 1L), 1L)
+ test("MinOf basic") {
+ testNumericDataTypes { convert =>
+ val small = Literal(convert(1))
+ val large = Literal(convert(2))
+ checkEvaluation(MinOf(small, large), convert(1))
+ checkEvaluation(MinOf(large, small), convert(1))
+ checkEvaluation(MinOf(Literal.create(null, small.dataType), large), convert(2))
+ checkEvaluation(MinOf(small, Literal.create(null, small.dataType)), convert(1))
+ }
+ }
- checkEvaluation(MinOf(Literal.create(null, IntegerType), 1), 1)
- checkEvaluation(MinOf(1, Literal.create(null, IntegerType)), 1)
+ test("MinOf for atomic type") {
+ checkEvaluation(MinOf(true, false), false)
+ checkEvaluation(MinOf("abc", "bcd"), "abc")
+ checkEvaluation(MinOf(Array(1.toByte, 2.toByte), Array(1.toByte, 3.toByte)),
+ Array(1.toByte, 2.toByte))
}
}