aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorDilip Biswal <dbiswal@us.ibm.com>2016-05-31 15:49:45 -0700
committerDavies Liu <davies.liu@gmail.com>2016-05-31 15:49:45 -0700
commitdfe2cbeb437a4fa69bec3eca4ac9242f3eb51c81 (patch)
treed8e4dc60ec5aaf443ac50c13456d0eb67b7eb793 /sql
parent2df6ca848e99b90acd11c3a3de342fa4d77015d6 (diff)
downloadspark-dfe2cbeb437a4fa69bec3eca4ac9242f3eb51c81.tar.gz
spark-dfe2cbeb437a4fa69bec3eca4ac9242f3eb51c81.tar.bz2
spark-dfe2cbeb437a4fa69bec3eca4ac9242f3eb51c81.zip
[SPARK-15557] [SQL] cast the string into DoubleType when it's used together with decimal
In this case, the result type of the expression becomes DECIMAL(38, 36) as we promote the individual string literals to DECIMAL(38, 18) when we handle string promotions for `BinaryArthmaticExpression`. I think we need to cast the string literals to Double type instead. I looked at the history and found that this was changed to use decimal instead of double to avoid potential loss of precision when we cast decimal to double. To double check i ran the query against hive, mysql. This query returns non NULL result for both the databases and both promote the expression to use double. Here is the output. - Hive ```SQL hive> create table l2 as select (cast(99 as decimal(19,6)) + '2') from l1; OK hive> describe l2; OK _c0 double ``` - MySQL ```SQL mysql> create table foo2 as select (cast(99 as decimal(19,6)) + '2') from test; Query OK, 1 row affected (0.01 sec) Records: 1 Duplicates: 0 Warnings: 0 mysql> describe foo2; +-----------------------------------+--------+------+-----+---------+-------+ | Field | Type | Null | Key | Default | Extra | +-----------------------------------+--------+------+-----+---------+-------+ | (cast(99 as decimal(19,6)) + '2') | double | NO | | 0 | | +-----------------------------------+--------+------+-----+---------+-------+ ``` ## How was this patch tested? Added a new test in SQLQuerySuite Author: Dilip Biswal <dbiswal@us.ibm.com> Closes #13368 from dilipbiswal/spark-15557.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala5
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala3
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala4
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala19
4 files changed, 22 insertions, 9 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 387e555254..a5b5b91e4a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -290,11 +290,6 @@ object TypeCoercion {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
- case a @ BinaryArithmetic(left @ StringType(), right @ DecimalType.Expression(_, _)) =>
- a.makeCopy(Array(Cast(left, DecimalType.SYSTEM_DEFAULT), right))
- case a @ BinaryArithmetic(left @ DecimalType.Expression(_, _), right @ StringType()) =>
- a.makeCopy(Array(left, Cast(right, DecimalType.SYSTEM_DEFAULT)))
-
case a @ BinaryArithmetic(left @ StringType(), right) =>
a.makeCopy(Array(Cast(left, DoubleType), right))
case a @ BinaryArithmetic(left, right @ StringType()) =>
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index a63d1770f3..77ea29ead9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -182,8 +182,7 @@ class AnalysisSuite extends AnalysisTest {
assert(pl(0).dataType == DoubleType)
assert(pl(1).dataType == DoubleType)
assert(pl(2).dataType == DoubleType)
- // StringType will be promoted into Decimal(38, 18)
- assert(pl(3).dataType == DecimalType(38, 22))
+ assert(pl(3).dataType == DoubleType)
assert(pl(4).dataType == DoubleType)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 500d8ff55a..9f35c02d48 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -446,13 +446,13 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
// Number and String conflict: resolve the type as number in this query.
checkAnswer(
sql("select num_str + 1.2 from jsonTable where num_str > 14"),
- Row(BigDecimal("92233720368547758071.2"))
+ Row(92233720368547758071.2)
)
// Number and String conflict: resolve the type as number in this query.
checkAnswer(
sql("select num_str + 1.2 from jsonTable where num_str >= 92233720368547758060"),
- Row(new java.math.BigDecimal("92233720368547758071.2"))
+ Row(new java.math.BigDecimal("92233720368547758071.2").doubleValue)
)
// String and Boolean conflict: resolve the type as string.
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 4b51f021bf..2a9b06b75e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -1560,4 +1560,23 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
checkAnswer(sql("SELECT * FROM tbl"), Row(1, "a"))
}
}
+
+ test("spark-15557 promote string test") {
+ withTable("tbl") {
+ sql("CREATE TABLE tbl(c1 string, c2 string)")
+ sql("insert into tbl values ('3', '2.3')")
+ checkAnswer(
+ sql("select (cast (99 as decimal(19,6)) + cast('3' as decimal)) * cast('2.3' as decimal)"),
+ Row(204.0)
+ )
+ checkAnswer(
+ sql("select (cast(99 as decimal(19,6)) + '3') *'2.3' from tbl"),
+ Row(234.6)
+ )
+ checkAnswer(
+ sql("select (cast(99 as decimal(19,6)) + c1) * c2 from tbl"),
+ Row(234.6)
+ )
+ }
+ }
}