aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorYijie Shen <henry.yijieshen@gmail.com>2015-08-03 00:15:24 -0700
committerReynold Xin <rxin@databricks.com>2015-08-03 00:15:24 -0700
commit98d6d9c7a996f5456eb2653bb96985a1a05f4ce1 (patch)
treef773995dce2ae3ce065b05a0976fa2c98b732e43 /sql
parent608353c8e8e50461fafff91a2c885dca8af3aaa8 (diff)
downloadspark-98d6d9c7a996f5456eb2653bb96985a1a05f4ce1.tar.gz
spark-98d6d9c7a996f5456eb2653bb96985a1a05f4ce1.tar.bz2
spark-98d6d9c7a996f5456eb2653bb96985a1a05f4ce1.zip
[SPARK-9549][SQL] fix bugs in expressions
JIRA: https://issues.apache.org/jira/browse/SPARK-9549 This PR fix the following bugs: 1. `UnaryMinus`'s codegen version would fail to compile when the input is `Long.MinValue` 2. `BinaryComparison` would fail to compile in codegen mode when comparing Boolean types. 3. `AddMonth` would fail if passed a huge negative month, which would lead accessing negative index of `monthDays` array. 4. `Nanvl` with different type operands. Author: Yijie Shen <henry.yijieshen@gmail.com> Closes #7882 from yjshen/minor_bug_fix and squashes the following commits: 41bbd2c [Yijie Shen] fix bug in Nanvl type coercion 3dee204 [Yijie Shen] address comments 4fa5de0 [Yijie Shen] fix bugs in expressions
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala5
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala9
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala7
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala12
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala6
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala2
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala62
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala18
9 files changed, 79 insertions, 43 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 603afc4032..422d423747 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -562,6 +562,11 @@ object HiveTypeCoercion {
case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
case None => c
}
+
+ case NaNvl(l, r) if l.dataType == DoubleType && r.dataType == FloatType =>
+ NaNvl(l, Cast(r, DoubleType))
+ case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType =>
+ NaNvl(Cast(l, DoubleType), r)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 6f8f4dd230..0891b55494 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -36,7 +36,14 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()")
- case dt: NumericType => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(-($c))")
+ case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => {
+ val originValue = ctx.freshName("origin")
+ // codegen would fail to compile if we just write (-($c))
+ // for example, we could not write --9223372036854775808L in code
+ s"""
+ ${ctx.javaType(dt)} $originValue = (${ctx.javaType(dt)})($eval);
+ ${ev.primitive} = (${ctx.javaType(dt)})(-($originValue));
+ """})
case dt: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()")
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index ab7d3afce8..b69bbabee7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -227,6 +227,7 @@ abstract class BinaryComparison extends BinaryOperator with Predicate {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
if (ctx.isPrimitiveType(left.dataType)
+ && left.dataType != BooleanType // java boolean doesn't support > or < operator
&& left.dataType != FloatType
&& left.dataType != DoubleType) {
// faster version
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index 6a98f4d9c5..f645eb5f7b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -614,8 +614,9 @@ object DateTimeUtils {
*/
def dateAddMonths(days: Int, months: Int): Int = {
val absoluteMonth = (getYear(days) - YearZero) * 12 + getMonth(days) - 1 + months
- val currentMonthInYear = absoluteMonth % 12
- val currentYear = absoluteMonth / 12
+ val nonNegativeMonth = if (absoluteMonth >= 0) absoluteMonth else 0
+ val currentMonthInYear = nonNegativeMonth % 12
+ val currentYear = nonNegativeMonth / 12
val leapDay = if (currentMonthInYear == 1 && isLeapYear(currentYear + YearZero)) 1 else 0
val lastDayOfMonth = monthDays(currentMonthInYear) + leapDay
@@ -626,7 +627,7 @@ object DateTimeUtils {
} else {
dayOfMonth
}
- firstDayOfMonth(absoluteMonth) + currentDayInMonth - 1
+ firstDayOfMonth(nonNegativeMonth) + currentDayInMonth - 1
}
/**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 70608771dd..cbdf453f60 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -251,6 +251,18 @@ class HiveTypeCoercionSuite extends PlanTest {
:: Nil))
}
+ test("nanvl casts") {
+ ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
+ NaNvl(Literal.create(1.0, FloatType), Literal.create(1.0, DoubleType)),
+ NaNvl(Cast(Literal.create(1.0, FloatType), DoubleType), Literal.create(1.0, DoubleType)))
+ ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
+ NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, FloatType)),
+ NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0, FloatType), DoubleType)))
+ ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
+ NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)),
+ NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)))
+ }
+
test("type coercion for If") {
val rule = HiveTypeCoercion.IfCoercion
ruleTest(rule,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index d03b0fbbfb..0bae8fe2fd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.types.Decimal
+import org.apache.spark.sql.types._
class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -56,6 +56,10 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(UnaryMinus(input), convert(-1))
checkEvaluation(UnaryMinus(Literal.create(null, dataType)), null)
}
+ checkEvaluation(UnaryMinus(Literal(Long.MinValue)), Long.MinValue)
+ checkEvaluation(UnaryMinus(Literal(Int.MinValue)), Int.MinValue)
+ checkEvaluation(UnaryMinus(Literal(Short.MinValue)), Short.MinValue)
+ checkEvaluation(UnaryMinus(Literal(Byte.MinValue)), Byte.MinValue)
}
test("- (Minus)") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
index 3bff8e012a..e6e8790e90 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
@@ -280,6 +280,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(AddMonths(Literal.create(null, DateType), Literal(1)), null)
checkEvaluation(AddMonths(Literal.create(null, DateType), Literal.create(null, IntegerType)),
null)
+ checkEvaluation(
+ AddMonths(Literal(Date.valueOf("2015-01-30")), Literal(Int.MinValue)), -7293498)
}
test("months_between") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index 0bc2812a5d..d7eb13c50b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -136,60 +136,60 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true)
}
- private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d).map(Literal(_))
+ private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d, false).map(Literal(_))
private val largeValues =
- Seq(2, Decimal(2), Array(2.toByte), "b", Float.NaN, Double.NaN).map(Literal(_))
+ Seq(2, Decimal(2), Array(2.toByte), "b", Float.NaN, Double.NaN, true).map(Literal(_))
private val equalValues1 =
- Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN).map(Literal(_))
+ Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN, true).map(Literal(_))
private val equalValues2 =
- Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN).map(Literal(_))
+ Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN, true).map(Literal(_))
- test("BinaryComparison: <") {
+ test("BinaryComparison: lessThan") {
for (i <- 0 until smallValues.length) {
- checkEvaluation(smallValues(i) < largeValues(i), true)
- checkEvaluation(equalValues1(i) < equalValues2(i), false)
- checkEvaluation(largeValues(i) < smallValues(i), false)
+ checkEvaluation(LessThan(smallValues(i), largeValues(i)), true)
+ checkEvaluation(LessThan(equalValues1(i), equalValues2(i)), false)
+ checkEvaluation(LessThan(largeValues(i), smallValues(i)), false)
}
}
- test("BinaryComparison: <=") {
+ test("BinaryComparison: LessThanOrEqual") {
for (i <- 0 until smallValues.length) {
- checkEvaluation(smallValues(i) <= largeValues(i), true)
- checkEvaluation(equalValues1(i) <= equalValues2(i), true)
- checkEvaluation(largeValues(i) <= smallValues(i), false)
+ checkEvaluation(LessThanOrEqual(smallValues(i), largeValues(i)), true)
+ checkEvaluation(LessThanOrEqual(equalValues1(i), equalValues2(i)), true)
+ checkEvaluation(LessThanOrEqual(largeValues(i), smallValues(i)), false)
}
}
- test("BinaryComparison: >") {
+ test("BinaryComparison: GreaterThan") {
for (i <- 0 until smallValues.length) {
- checkEvaluation(smallValues(i) > largeValues(i), false)
- checkEvaluation(equalValues1(i) > equalValues2(i), false)
- checkEvaluation(largeValues(i) > smallValues(i), true)
+ checkEvaluation(GreaterThan(smallValues(i), largeValues(i)), false)
+ checkEvaluation(GreaterThan(equalValues1(i), equalValues2(i)), false)
+ checkEvaluation(GreaterThan(largeValues(i), smallValues(i)), true)
}
}
- test("BinaryComparison: >=") {
+ test("BinaryComparison: GreaterThanOrEqual") {
for (i <- 0 until smallValues.length) {
- checkEvaluation(smallValues(i) >= largeValues(i), false)
- checkEvaluation(equalValues1(i) >= equalValues2(i), true)
- checkEvaluation(largeValues(i) >= smallValues(i), true)
+ checkEvaluation(GreaterThanOrEqual(smallValues(i), largeValues(i)), false)
+ checkEvaluation(GreaterThanOrEqual(equalValues1(i), equalValues2(i)), true)
+ checkEvaluation(GreaterThanOrEqual(largeValues(i), smallValues(i)), true)
}
}
- test("BinaryComparison: ===") {
+ test("BinaryComparison: EqualTo") {
for (i <- 0 until smallValues.length) {
- checkEvaluation(smallValues(i) === largeValues(i), false)
- checkEvaluation(equalValues1(i) === equalValues2(i), true)
- checkEvaluation(largeValues(i) === smallValues(i), false)
+ checkEvaluation(EqualTo(smallValues(i), largeValues(i)), false)
+ checkEvaluation(EqualTo(equalValues1(i), equalValues2(i)), true)
+ checkEvaluation(EqualTo(largeValues(i), smallValues(i)), false)
}
}
- test("BinaryComparison: <=>") {
+ test("BinaryComparison: EqualNullSafe") {
for (i <- 0 until smallValues.length) {
- checkEvaluation(smallValues(i) <=> largeValues(i), false)
- checkEvaluation(equalValues1(i) <=> equalValues2(i), true)
- checkEvaluation(largeValues(i) <=> smallValues(i), false)
+ checkEvaluation(EqualNullSafe(smallValues(i), largeValues(i)), false)
+ checkEvaluation(EqualNullSafe(equalValues1(i), equalValues2(i)), true)
+ checkEvaluation(EqualNullSafe(largeValues(i), smallValues(i)), false)
}
}
@@ -209,8 +209,8 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
nullTest(GreaterThanOrEqual)
nullTest(EqualTo)
- checkEvaluation(normalInt <=> nullInt, false)
- checkEvaluation(nullInt <=> normalInt, false)
- checkEvaluation(nullInt <=> nullInt, true)
+ checkEvaluation(EqualNullSafe(normalInt, nullInt), false)
+ checkEvaluation(EqualNullSafe(nullInt, normalInt), false)
+ checkEvaluation(EqualNullSafe(nullInt, nullInt), true)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index eb64684ae0..35ca0b4c7c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -227,20 +227,24 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
test("nanvl") {
val testData = ctx.createDataFrame(ctx.sparkContext.parallelize(
- Row(null, 3.0, Double.NaN, Double.PositiveInfinity) :: Nil),
+ Row(null, 3.0, Double.NaN, Double.PositiveInfinity, 1.0f, 4) :: Nil),
StructType(Seq(StructField("a", DoubleType), StructField("b", DoubleType),
- StructField("c", DoubleType), StructField("d", DoubleType))))
+ StructField("c", DoubleType), StructField("d", DoubleType),
+ StructField("e", FloatType), StructField("f", IntegerType))))
checkAnswer(
testData.select(
- nanvl($"a", lit(5)), nanvl($"b", lit(10)),
- nanvl($"c", lit(null).cast(DoubleType)), nanvl($"d", lit(10))),
- Row(null, 3.0, null, Double.PositiveInfinity)
+ nanvl($"a", lit(5)), nanvl($"b", lit(10)), nanvl(lit(10), $"b"),
+ nanvl($"c", lit(null).cast(DoubleType)), nanvl($"d", lit(10)),
+ nanvl($"b", $"e"), nanvl($"e", $"f")),
+ Row(null, 3.0, 10.0, null, Double.PositiveInfinity, 3.0, 1.0)
)
testData.registerTempTable("t")
checkAnswer(
- ctx.sql("select nanvl(a, 5), nanvl(b, 10), nanvl(c, null), nanvl(d, 10) from t"),
- Row(null, 3.0, null, Double.PositiveInfinity)
+ ctx.sql(
+ "select nanvl(a, 5), nanvl(b, 10), nanvl(10, b), nanvl(c, null), nanvl(d, 10), " +
+ " nanvl(b, e), nanvl(e, f) from t"),
+ Row(null, 3.0, 10.0, null, Double.PositiveInfinity, 3.0, 1.0)
)
}