author     Davies Liu <davies@databricks.com>   2015-08-19 14:03:47 -0700
committer  Michael Armbrust <michael@databricks.com>   2015-08-19 14:04:09 -0700
commit     d9dfd43d463cd5e3c9e72197850382ad8897b8ac (patch)
tree       310e21fafec57c0d1ae958e620666cb75733f2c5
parent     77269fcb5143598ff4537e821a073fb8e5b22562 (diff)
[SPARK-10090] [SQL] fix decimal scale of division
We should round the result of decimal multiplication/division to the expected precision/scale, and also check for overflow.
Author: Davies Liu <davies@databricks.com>
Closes #8287 from davies/decimal_division.
(cherry picked from commit 1f4c4fe6dfd8cc52b5fddfd67a31a77edbb1a036)
Signed-off-by: Michael Armbrust <michael@databricks.com>
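For context, the patch derives result types with the rules visible in the HiveTypeCoercion diff below. A minimal, self-contained Scala sketch of those formulas (DecType and bounded here are simplified stand-ins for Catalyst's DecimalType.bounded, not the real API):

object DecimalRuleSketch {
  val MaxPrecision = 38 // Spark caps decimals at 38 digits

  case class DecType(precision: Int, scale: Int)

  // Simplified stand-in for DecimalType.bounded: clamp to the 38-digit cap.
  def bounded(precision: Int, scale: Int): DecType =
    DecType(math.min(precision, MaxPrecision), math.min(scale, MaxPrecision))

  // multiply: precision p1 + p2 + 1, scale s1 + s2
  def multiplyType(p1: Int, s1: Int, p2: Int, s2: Int): DecType =
    bounded(p1 + p2 + 1, s1 + s2)

  // divide: scale max(6, s1 + p2 + 1), plus p1 - s1 + s2 integral digits
  def divideType(p1: Int, s1: Int, p2: Int, s2: Int): DecType = {
    val scale = math.max(6, s1 + p2 + 1)
    bounded(p1 - s1 + s2 + scale, scale)
  }

  def main(args: Array[String]): Unit = {
    // 10.3 is DECIMAL(3,1) and 3.0 is DECIMAL(2,1):
    println(multiplyType(3, 1, 2, 1)) // DecType(6,2): 10.3 * 3.0 = 30.90
    println(divideType(3, 1, 2, 1))   // DecType(9,6): 10.3 / 3.0 = 3.433333
  }
}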
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 28
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala | 32
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala | 38
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala | 4
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala | 63
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 23
6 files changed, 157 insertions, 31 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 8581d6b496..62c27ee0b9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -371,8 +371,8 @@ object HiveTypeCoercion {
DecimalType.bounded(range + scale, scale)
}
- private def changePrecision(e: Expression, dataType: DataType): Expression = {
- ChangeDecimalPrecision(Cast(e, dataType))
+ private def promotePrecision(e: Expression, dataType: DataType): Expression = {
+ PromotePrecision(Cast(e, dataType))
}
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
@@ -383,36 +383,42 @@ object HiveTypeCoercion {
case e if !e.childrenResolved => e
// Skip nodes who is already promoted
- case e: BinaryArithmetic if e.left.isInstanceOf[ChangeDecimalPrecision] => e
+ case e: BinaryArithmetic if e.left.isInstanceOf[PromotePrecision] => e
case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
- Add(changePrecision(e1, dt), changePrecision(e2, dt))
+ CheckOverflow(Add(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt)
case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
- Subtract(changePrecision(e1, dt), changePrecision(e2, dt))
+ CheckOverflow(Subtract(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt)
case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
- val dt = DecimalType.bounded(p1 + p2 + 1, s1 + s2)
- Multiply(changePrecision(e1, dt), changePrecision(e2, dt))
+ val resultType = DecimalType.bounded(p1 + p2 + 1, s1 + s2)
+ val widerType = widerDecimalType(p1, s1, p2, s2)
+ CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
+ resultType)
case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
- val dt = DecimalType.bounded(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1))
- Divide(changePrecision(e1, dt), changePrecision(e2, dt))
+ val resultType = DecimalType.bounded(p1 - s1 + s2 + max(6, s1 + p2 + 1),
+ max(6, s1 + p2 + 1))
+ val widerType = widerDecimalType(p1, s1, p2, s2)
+ CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
+ resultType)
case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
// resultType may have lower precision, so we cast them into wider type first.
val widerType = widerDecimalType(p1, s1, p2, s2)
- Cast(Remainder(changePrecision(e1, widerType), changePrecision(e2, widerType)),
+ CheckOverflow(Remainder(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
resultType)
case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
// resultType may have lower precision, so we cast them into wider type first.
val widerType = widerDecimalType(p1, s1, p2, s2)
- Cast(Pmod(changePrecision(e1, widerType), changePrecision(e2, widerType)), resultType)
+ CheckOverflow(Pmod(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
+ resultType)
case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
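The key behavior change in this rule: operands used to be cast straight to the (often much wider) result type and the raw arithmetic result was returned as-is; now each operand is only promoted to the wider of the two operand types, and the whole operation is wrapped in CheckOverflow, which rounds to the result type's scale and yields null on overflow. A toy model of the tree the rule now builds for a division (hypothetical case classes, not the real Catalyst expressions):

sealed trait Expr
case class Col(name: String) extends Expr
case class Cast(child: Expr, to: String) extends Expr
case class PromotePrecision(child: Expr) extends Expr
case class Divide(left: Expr, right: Expr) extends Expr
case class CheckOverflow(child: Expr, resultType: String) extends Expr

object RewriteShape extends App {
  // a: DECIMAL(3,1), b: DECIMAL(2,1); the wider operand type is DECIMAL(3,1),
  // and the divide rule gives a result type of DECIMAL(9,6).
  val rewritten = CheckOverflow(
    Divide(
      PromotePrecision(Cast(Col("a"), "decimal(3,1)")),
      PromotePrecision(Cast(Col("b"), "decimal(3,1)"))),
    "decimal(9,6)")
  println(rewritten)
}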
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 616b9e0e65..2db954257b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -447,7 +447,7 @@ case class Cast(child: Expression, dataType: DataType)
case StringType => castToStringCode(from, ctx)
case BinaryType => castToBinaryCode(from)
case DateType => castToDateCode(from, ctx)
- case decimal: DecimalType => castToDecimalCode(from, decimal)
+ case decimal: DecimalType => castToDecimalCode(from, decimal, ctx)
case TimestampType => castToTimestampCode(from, ctx)
case CalendarIntervalType => castToIntervalCode(from)
case BooleanType => castToBooleanCode(from)
@@ -528,14 +528,18 @@ case class Cast(child: Expression, dataType: DataType)
}
"""
- private[this] def castToDecimalCode(from: DataType, target: DecimalType): CastFunction = {
+ private[this] def castToDecimalCode(
+ from: DataType,
+ target: DecimalType,
+ ctx: CodeGenContext): CastFunction = {
+ val tmp = ctx.freshName("tmpDecimal")
from match {
case StringType =>
(c, evPrim, evNull) =>
s"""
try {
- Decimal tmpDecimal = Decimal.apply(new java.math.BigDecimal($c.toString()));
- ${changePrecision("tmpDecimal", target, evPrim, evNull)}
+ Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString()));
+ ${changePrecision(tmp, target, evPrim, evNull)}
} catch (java.lang.NumberFormatException e) {
$evNull = true;
}
@@ -543,8 +547,8 @@ case class Cast(child: Expression, dataType: DataType)
case BooleanType =>
(c, evPrim, evNull) =>
s"""
- Decimal tmpDecimal = $c ? Decimal.apply(1) : Decimal.apply(0);
- ${changePrecision("tmpDecimal", target, evPrim, evNull)}
+ Decimal $tmp = $c ? Decimal.apply(1) : Decimal.apply(0);
+ ${changePrecision(tmp, target, evPrim, evNull)}
"""
case DateType =>
// date can't cast to decimal in Hive
@@ -553,29 +557,29 @@ case class Cast(child: Expression, dataType: DataType)
// Note that we lose precision here.
(c, evPrim, evNull) =>
s"""
- Decimal tmpDecimal = Decimal.apply(
+ Decimal $tmp = Decimal.apply(
scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)}));
- ${changePrecision("tmpDecimal", target, evPrim, evNull)}
+ ${changePrecision(tmp, target, evPrim, evNull)}
"""
case DecimalType() =>
(c, evPrim, evNull) =>
s"""
- Decimal tmpDecimal = $c.clone();
- ${changePrecision("tmpDecimal", target, evPrim, evNull)}
+ Decimal $tmp = $c.clone();
+ ${changePrecision(tmp, target, evPrim, evNull)}
"""
case x: IntegralType =>
(c, evPrim, evNull) =>
s"""
- Decimal tmpDecimal = Decimal.apply((long) $c);
- ${changePrecision("tmpDecimal", target, evPrim, evNull)}
+ Decimal $tmp = Decimal.apply((long) $c);
+ ${changePrecision(tmp, target, evPrim, evNull)}
"""
case x: FractionalType =>
// All other numeric types can be represented precisely as Doubles
(c, evPrim, evNull) =>
s"""
try {
- Decimal tmpDecimal = Decimal.apply(scala.math.BigDecimal.valueOf((double) $c));
- ${changePrecision("tmpDecimal", target, evPrim, evNull)}
+ Decimal $tmp = Decimal.apply(scala.math.BigDecimal.valueOf((double) $c));
+ ${changePrecision(tmp, target, evPrim, evNull)}
} catch (java.lang.NumberFormatException e) {
$evNull = true;
}
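Note why castToDecimalCode now takes a CodeGenContext: the generated Java previously declared a fixed variable named tmpDecimal, which collides when two decimal casts are inlined into the same generated method; ctx.freshName("tmpDecimal") makes each declaration unique. A minimal sketch of the fresh-name idea (the real CodeGenContext does more than this):

class FreshNames {
  private var counter = 0
  // Append a monotonically increasing suffix so names never collide.
  def freshName(prefix: String): String = {
    counter += 1
    s"$prefix$counter"
  }
}

object FreshNameDemo extends App {
  val ctx = new FreshNames
  println(ctx.freshName("tmpDecimal")) // tmpDecimal1
  println(ctx.freshName("tmpDecimal")) // tmpDecimal2 -- no redeclaration clash
}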
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
index adb33e4c8d..b7be12f7aa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
@@ -66,10 +66,44 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un
* An expression used to wrap the children when promote the precision of DecimalType to avoid
* promote multiple times.
*/
-case class ChangeDecimalPrecision(child: Expression) extends UnaryExpression {
+case class PromotePrecision(child: Expression) extends UnaryExpression {
override def dataType: DataType = child.dataType
override def eval(input: InternalRow): Any = child.eval(input)
override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx)
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = ""
- override def prettyName: String = "change_decimal_precision"
+ override def prettyName: String = "promote_precision"
+}
+
+/**
+ * Rounds the decimal to given scale and check whether the decimal can fit in provided precision
+ * or not, returns null if not.
+ */
+case class CheckOverflow(child: Expression, dataType: DecimalType) extends UnaryExpression {
+
+ override def nullable: Boolean = true
+
+ override def nullSafeEval(input: Any): Any = {
+ val d = input.asInstanceOf[Decimal].clone()
+ if (d.changePrecision(dataType.precision, dataType.scale)) {
+ d
+ } else {
+ null
+ }
+ }
+
+ override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ nullSafeCodeGen(ctx, ev, eval => {
+ val tmp = ctx.freshName("tmp")
+ s"""
+ | Decimal $tmp = $eval.clone();
+ | if ($tmp.changePrecision(${dataType.precision}, ${dataType.scale})) {
+ | ${ev.primitive} = $tmp;
+ | } else {
+ | ${ev.isNull} = true;
+ | }
+ """.stripMargin
+ })
+ }
+
+ override def toString: String = s"CheckOverflow($child, $dataType)"
}
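CheckOverflow's contract in plain terms: round the input to the target scale, then return null if the rounded value no longer fits in the target precision. A standalone sketch with java.math.BigDecimal (assuming HALF_UP rounding, which matches Decimal.changePrecision):

import java.math.{BigDecimal => JBigDecimal, RoundingMode}

object CheckOverflowSketch {
  // Returns null on overflow, mirroring the expression's nullable contract.
  def checkOverflow(d: JBigDecimal, precision: Int, scale: Int): JBigDecimal = {
    val rounded = d.setScale(scale, RoundingMode.HALF_UP)
    if (rounded.precision > precision) null else rounded
  }

  def main(args: Array[String]): Unit = {
    val d = new JBigDecimal("10.1")
    println(checkOverflow(d, 4, 0)) // 10    (rounded down to scale 0)
    println(checkOverflow(d, 4, 1)) // 10.1  (fits as-is)
    println(checkOverflow(d, 4, 3)) // null  (10.100 needs 5 digits > 4)
  }
}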
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index d95805c245..c988f1d1b9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -267,7 +267,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) {
Decimal(longVal + that.longVal, Math.max(precision, that.precision), scale)
} else {
- Decimal(toBigDecimal + that.toBigDecimal, precision, scale)
+ Decimal(toBigDecimal + that.toBigDecimal)
}
}
@@ -275,7 +275,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) {
Decimal(longVal - that.longVal, Math.max(precision, that.precision), scale)
} else {
- Decimal(toBigDecimal - that.toBigDecimal, precision, scale)
+ Decimal(toBigDecimal - that.toBigDecimal)
}
}
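Why + and - stop forcing the left operand's precision/scale: the natural result of an addition can need one more integral digit and the wider of the two scales, so clamping it early could silently round. After this patch the result keeps its natural precision and the single rounding/overflow check is deferred to CheckOverflow. A quick check with scala.math.BigDecimal:

object NaturalPrecisionDemo extends App {
  val a = BigDecimal("99.99")  // precision 4, scale 2
  val b = BigDecimal("0.015")  // precision 2, scale 3
  val sum = a + b
  println(sum)        // 100.005 -- needs precision 6 and scale 3
  println(sum.scale)  // 3, the max of the two input scales
}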
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala
new file mode 100644
index 0000000000..511f030790
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types.{LongType, DecimalType, Decimal}
+
+
+class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
+
+ test("UnscaledValue") {
+ val d1 = Decimal("10.1")
+ checkEvaluation(UnscaledValue(Literal(d1)), 101L)
+ val d2 = Decimal(101, 3, 1)
+ checkEvaluation(UnscaledValue(Literal(d2)), 101L)
+ checkEvaluation(UnscaledValue(Literal.create(null, DecimalType(2, 1))), null)
+ }
+
+ test("MakeDecimal") {
+ checkEvaluation(MakeDecimal(Literal(101L), 3, 1), Decimal("10.1"))
+ checkEvaluation(MakeDecimal(Literal.create(null, LongType), 3, 1), null)
+ }
+
+ test("PromotePrecision") {
+ val d1 = Decimal("10.1")
+ checkEvaluation(PromotePrecision(Literal(d1)), d1)
+ val d2 = Decimal(101, 3, 1)
+ checkEvaluation(PromotePrecision(Literal(d2)), d2)
+ checkEvaluation(PromotePrecision(Literal.create(null, DecimalType(2, 1))), null)
+ }
+
+ test("CheckOverflow") {
+ val d1 = Decimal("10.1")
+ checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 0)), Decimal("10"))
+ checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 1)), d1)
+ checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 2)), d1)
+ checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 3)), null)
+
+ val d2 = Decimal(101, 3, 1)
+ checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 0)), Decimal("10"))
+ checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 1)), d2)
+ checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 2)), d2)
+ checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 3)), null)
+
+ checkEvaluation(CheckOverflow(Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2)), null)
+ }
+
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index c329fdb2a6..141468ca00 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -17,16 +17,17 @@
package org.apache.spark.sql
+import java.math.MathContext
import java.sql.Timestamp
import org.apache.spark.AccumulatorSuite
-import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.DefaultParserDialect
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.test.SQLTestData._
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
/** A SQL Dialect for testing purpose, and it can not be nested type */
@@ -1608,6 +1609,24 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
}
+ test("decimal precision with multiply/division") {
+ checkAnswer(sql("select 10.3 * 3.0"), Row(BigDecimal("30.90")))
+ checkAnswer(sql("select 10.3000 * 3.0"), Row(BigDecimal("30.90000")))
+ checkAnswer(sql("select 10.30000 * 30.0"), Row(BigDecimal("309.000000")))
+ checkAnswer(sql("select 10.300000000000000000 * 3.000000000000000000"),
+ Row(BigDecimal("30.900000000000000000000000000000000000", new MathContext(38))))
+ checkAnswer(sql("select 10.300000000000000000 * 3.0000000000000000000"),
+ Row(null))
+
+ checkAnswer(sql("select 10.3 / 3.0"), Row(BigDecimal("3.433333")))
+ checkAnswer(sql("select 10.3000 / 3.0"), Row(BigDecimal("3.4333333")))
+ checkAnswer(sql("select 10.30000 / 30.0"), Row(BigDecimal("0.343333333")))
+ checkAnswer(sql("select 10.300000000000000000 / 3.00000000000000000"),
+ Row(BigDecimal("3.4333333333333333333333333333333333333", new MathContext(38))))
+ checkAnswer(sql("select 10.3000000000000000000 / 3.00000000000000000"),
+ Row(null))
+ }
+
test("external sorting updates peak execution memory") {
withSQLConf((SQLConf.EXTERNAL_SORT.key, "true")) {
val sc = sqlContext.sparkContext
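The decimal scales asserted in the new test follow directly from the divide rule in HiveTypeCoercion; a quick sketch confirming the first three cases:

object DivideScaleCheck extends App {
  // Result scale of DECIMAL(p1,s1) / DECIMAL(p2,s2) per the rule above.
  def divideScale(s1: Int, p2: Int): Int = math.max(6, s1 + p2 + 1)

  println(divideScale(1, 2)) // 10.3     / 3.0  -> scale 6: 3.433333
  println(divideScale(4, 2)) // 10.3000  / 3.0  -> scale 7: 3.4333333
  println(divideScale(5, 3)) // 10.30000 / 30.0 -> scale 9: 0.343333333
}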