aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Davies Liu <davies@databricks.com> 2015-08-25 15:19:41 -0700
committer Yin Huai <yhuai@databricks.com> 2015-08-25 15:20:42 -0700
commit ab7d46d1d6e7e6705a3348a0cab2d05fe62951cf (patch)
tree 2071e88526862bc966e3f174f2d6910f25119921
parent 8925896b1eb0a13d723d38fb263d3bec0a01ec10 (diff)
download spark-ab7d46d1d6e7e6705a3348a0cab2d05fe62951cf.tar.gz
spark-ab7d46d1d6e7e6705a3348a0cab2d05fe62951cf.tar.bz2
spark-ab7d46d1d6e7e6705a3348a0cab2d05fe62951cf.zip
[SPARK-10215] [SQL] Fix precision of division (follow the rule in Hive)
Follow the rule in Hive for decimal division. See https://github.com/apache/hive/blob/ac755ebe26361a4647d53db2a28500f71697b276/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPDivide.java#L113

cc chenghao-intel

Author: Davies Liu <davies@databricks.com>

Closes #8415 from davies/decimal_div2.

(cherry picked from commit 7467b52ed07f174d93dfc4cb544dc4b69a2c2826)
Signed-off-by: Yin Huai <yhuai@databricks.com>
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala 10
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala 9
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala 8
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala 25
4 files changed, 39 insertions, 13 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index a1aa2a2b2c..87c11abbad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -396,8 +396,14 @@ object HiveTypeCoercion {
resultType)
case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
- val resultType = DecimalType.bounded(p1 - s1 + s2 + max(6, s1 + p2 + 1),
- max(6, s1 + p2 + 1))
+ var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2)
+ var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1))
+ val diff = (intDig + decDig) - DecimalType.MAX_SCALE
+ if (diff > 0) {
+ decDig -= diff / 2 + 1
+ intDig = DecimalType.MAX_SCALE - decDig
+ }
+ val resultType = DecimalType.bounded(intDig + decDig, decDig)
val widerType = widerDecimalType(p1, s1, p2, s2)
CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
resultType)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 1e0cc81dae..820b336aac 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.sql.catalyst.analysis
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._
-import org.apache.spark.sql.catalyst.SimpleCatalystConf
-import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.dsl.plans._
class AnalysisSuite extends AnalysisTest {
- import TestRelations._
+ import org.apache.spark.sql.catalyst.analysis.TestRelations._
test("union project *") {
val plan = (1 to 100)
@@ -96,7 +95,7 @@ class AnalysisSuite extends AnalysisTest {
assert(pl(1).dataType == DoubleType)
assert(pl(2).dataType == DoubleType)
// StringType will be promoted into Decimal(38, 18)
- assert(pl(3).dataType == DecimalType(38, 29))
+ assert(pl(3).dataType == DecimalType(38, 22))
assert(pl(4).dataType == DoubleType)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
index fc11627da6..b4ad618c23 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
@@ -136,10 +136,10 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter {
checkType(Multiply(i, u), DecimalType(38, 18))
checkType(Multiply(u, u), DecimalType(38, 36))
- checkType(Divide(u, d1), DecimalType(38, 21))
- checkType(Divide(u, d2), DecimalType(38, 24))
- checkType(Divide(u, i), DecimalType(38, 29))
- checkType(Divide(u, u), DecimalType(38, 38))
+ checkType(Divide(u, d1), DecimalType(38, 18))
+ checkType(Divide(u, d2), DecimalType(38, 19))
+ checkType(Divide(u, i), DecimalType(38, 23))
+ checkType(Divide(u, u), DecimalType(38, 18))
checkType(Remainder(d1, u), DecimalType(19, 18))
checkType(Remainder(d2, u), DecimalType(21, 18))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index aa07665c6b..9e172b2c26 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1622,9 +1622,30 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
checkAnswer(sql("select 10.3000 / 3.0"), Row(BigDecimal("3.4333333")))
checkAnswer(sql("select 10.30000 / 30.0"), Row(BigDecimal("0.343333333")))
checkAnswer(sql("select 10.300000000000000000 / 3.00000000000000000"),
- Row(BigDecimal("3.4333333333333333333333333333333333333", new MathContext(38))))
+ Row(BigDecimal("3.433333333333333333333333333", new MathContext(38))))
checkAnswer(sql("select 10.3000000000000000000 / 3.00000000000000000"),
- Row(null))
+ Row(BigDecimal("3.4333333333333333333333333333", new MathContext(38))))
+ }
+
+ test("SPARK-10215 Div of Decimal returns null") {
+ val d = Decimal(1.12321)
+ val df = Seq((d, 1)).toDF("a", "b")
+
+ checkAnswer(
+ df.selectExpr("b * a / b"),
+ Seq(Row(d.toBigDecimal)))
+ checkAnswer(
+ df.selectExpr("b * a / b / b"),
+ Seq(Row(d.toBigDecimal)))
+ checkAnswer(
+ df.selectExpr("b * a + b"),
+ Seq(Row(BigDecimal(2.12321))))
+ checkAnswer(
+ df.selectExpr("b * a - b"),
+ Seq(Row(BigDecimal(0.12321))))
+ checkAnswer(
+ df.selectExpr("b * a * b"),
+ Seq(Row(d.toBigDecimal)))
}
test("precision smaller than scale") {