aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala5
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala3
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala4
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala19
4 files changed, 22 insertions, 9 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 387e555254..a5b5b91e4a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -290,11 +290,6 @@ object TypeCoercion {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
- case a @ BinaryArithmetic(left @ StringType(), right @ DecimalType.Expression(_, _)) =>
- a.makeCopy(Array(Cast(left, DecimalType.SYSTEM_DEFAULT), right))
- case a @ BinaryArithmetic(left @ DecimalType.Expression(_, _), right @ StringType()) =>
- a.makeCopy(Array(left, Cast(right, DecimalType.SYSTEM_DEFAULT)))
-
case a @ BinaryArithmetic(left @ StringType(), right) =>
a.makeCopy(Array(Cast(left, DoubleType), right))
case a @ BinaryArithmetic(left, right @ StringType()) =>
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index a63d1770f3..77ea29ead9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -182,8 +182,7 @@ class AnalysisSuite extends AnalysisTest {
assert(pl(0).dataType == DoubleType)
assert(pl(1).dataType == DoubleType)
assert(pl(2).dataType == DoubleType)
- // StringType will be promoted into Decimal(38, 18)
- assert(pl(3).dataType == DecimalType(38, 22))
+ assert(pl(3).dataType == DoubleType)
assert(pl(4).dataType == DoubleType)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 500d8ff55a..9f35c02d48 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -446,13 +446,13 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
// Number and String conflict: resolve the type as number in this query.
checkAnswer(
sql("select num_str + 1.2 from jsonTable where num_str > 14"),
- Row(BigDecimal("92233720368547758071.2"))
+ Row(92233720368547758071.2)
)
// Number and String conflict: resolve the type as number in this query.
checkAnswer(
sql("select num_str + 1.2 from jsonTable where num_str >= 92233720368547758060"),
- Row(new java.math.BigDecimal("92233720368547758071.2"))
+ Row(new java.math.BigDecimal("92233720368547758071.2").doubleValue)
)
// String and Boolean conflict: resolve the type as string.
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 4b51f021bf..2a9b06b75e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -1560,4 +1560,23 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
checkAnswer(sql("SELECT * FROM tbl"), Row(1, "a"))
}
}
+
+ test("spark-15557 promote string test") {
+ withTable("tbl") {
+ sql("CREATE TABLE tbl(c1 string, c2 string)")
+ sql("insert into tbl values ('3', '2.3')")
+ checkAnswer(
+ sql("select (cast (99 as decimal(19,6)) + cast('3' as decimal)) * cast('2.3' as decimal)"),
+ Row(204.0)
+ )
+ checkAnswer(
+ sql("select (cast(99 as decimal(19,6)) + '3') *'2.3' from tbl"),
+ Row(234.6)
+ )
+ checkAnswer(
+ sql("select (cast(99 as decimal(19,6)) + c1) * c2 from tbl"),
+ Row(234.6)
+ )
+ }
+ }
}