diff options
author | Yijie Shen <henry.yijieshen@gmail.com> | 2015-08-03 00:15:24 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-08-03 00:15:24 -0700 |
commit | 98d6d9c7a996f5456eb2653bb96985a1a05f4ce1 (patch) | |
tree | f773995dce2ae3ce065b05a0976fa2c98b732e43 /sql/catalyst/src/main | |
parent | 608353c8e8e50461fafff91a2c885dca8af3aaa8 (diff) | |
download | spark-98d6d9c7a996f5456eb2653bb96985a1a05f4ce1.tar.gz spark-98d6d9c7a996f5456eb2653bb96985a1a05f4ce1.tar.bz2 spark-98d6d9c7a996f5456eb2653bb96985a1a05f4ce1.zip |
[SPARK-9549][SQL] fix bugs in expressions
JIRA: https://issues.apache.org/jira/browse/SPARK-9549
This PR fix the following bugs:
1. `UnaryMinus`'s codegen version would fail to compile when the input is `Long.MinValue`
2. `BinaryComparison` would fail to compile in codegen mode when comparing Boolean types.
3. `AddMonth` would fail if passed a huge negative month, which would lead accessing negative index of `monthDays` array.
4. `Nanvl` with different type operands.
Author: Yijie Shen <henry.yijieshen@gmail.com>
Closes #7882 from yjshen/minor_bug_fix and squashes the following commits:
41bbd2c [Yijie Shen] fix bug in Nanvl type coercion
3dee204 [Yijie Shen] address comments
4fa5de0 [Yijie Shen] fix bugs in expressions
Diffstat (limited to 'sql/catalyst/src/main')
4 files changed, 18 insertions, 4 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 603afc4032..422d423747 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -562,6 +562,11 @@ object HiveTypeCoercion { case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType))) case None => c } + + case NaNvl(l, r) if l.dataType == DoubleType && r.dataType == FloatType => + NaNvl(l, Cast(r, DoubleType)) + case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType => + NaNvl(Cast(l, DoubleType), r) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 6f8f4dd230..0891b55494 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -36,7 +36,14 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") - case dt: NumericType => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(-($c))") + case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => { + val originValue = ctx.freshName("origin") + // codegen would fail to compile if we just write (-($c)) + // for example, we could not write --9223372036854775808L in code + s""" + ${ctx.javaType(dt)} $originValue = (${ctx.javaType(dt)})($eval); + ${ev.primitive} = (${ctx.javaType(dt)})(-($originValue)); + """}) case dt: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index ab7d3afce8..b69bbabee7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -227,6 +227,7 @@ abstract class BinaryComparison extends BinaryOperator with Predicate { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { if (ctx.isPrimitiveType(left.dataType) + && left.dataType != BooleanType // java boolean doesn't support > or < operator && left.dataType != FloatType && left.dataType != DoubleType) { // faster version diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 6a98f4d9c5..f645eb5f7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -614,8 +614,9 @@ object DateTimeUtils { */ def dateAddMonths(days: Int, months: Int): Int = { val absoluteMonth = (getYear(days) - YearZero) * 12 + getMonth(days) - 1 + months - val currentMonthInYear = absoluteMonth % 12 - val currentYear = absoluteMonth / 12 + val nonNegativeMonth = if (absoluteMonth >= 0) absoluteMonth else 0 + val currentMonthInYear = nonNegativeMonth % 12 + val currentYear = nonNegativeMonth / 12 val leapDay = if (currentMonthInYear == 1 && isLeapYear(currentYear + YearZero)) 1 else 0 val lastDayOfMonth = monthDays(currentMonthInYear) + leapDay @@ -626,7 +627,7 @@ object DateTimeUtils { } else { dayOfMonth } - firstDayOfMonth(absoluteMonth) + currentDayInMonth - 1 + firstDayOfMonth(nonNegativeMonth) + currentDayInMonth - 1 } /** |