aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2015-07-28 22:51:08 -0700
committerReynold Xin <rxin@databricks.com>2015-07-28 22:51:08 -0700
commit15667a0afa5fb17f4cc6fbf32b2ddb573630f20a (patch)
tree660096b33ec70e315fc9e6fdee5f5eb6e1f39c5c /sql
parent6309b93467b06f27cd76d4662b51b47de100c677 (diff)
downloadspark-15667a0afa5fb17f4cc6fbf32b2ddb573630f20a.tar.gz
spark-15667a0afa5fb17f4cc6fbf32b2ddb573630f20a.tar.bz2
spark-15667a0afa5fb17f4cc6fbf32b2ddb573630f20a.zip
[SPARK-9281] [SQL] use decimal or double when parsing SQL
Right now, we use double to parse all the float number in SQL. When it's used in expression together with DecimalType, it will turn the decimal into double as well. Also it will loss some precision when using double. This PR change to parse float number to decimal or double, based on it's using scientific notation or not, see https://msdn.microsoft.com/en-us/library/ms179899.aspx This is a break change, should we doc it somewhere? Author: Davies Liu <davies@databricks.com> Closes #7642 from davies/parse_decimal and squashes the following commits: 1f576d9 [Davies Liu] Merge branch 'master' of github.com:apache/spark into parse_decimal 5e142b6 [Davies Liu] fix scala style eca99de [Davies Liu] fix tests 2afe702 [Davies Liu] Merge branch 'master' of github.com:apache/spark into parse_decimal f4a320b [Davies Liu] Update SqlParser.scala 1c48e34 [Davies Liu] use decimal or double when parsing SQL
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala14
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala50
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala4
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala3
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala14
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala14
6 files changed, 62 insertions, 37 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index b423f0fa04..e5f115f74b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -332,8 +332,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
protected lazy val numericLiteral: Parser[Literal] =
( integral ^^ { case i => Literal(toNarrowestIntegerType(i)) }
| sign.? ~ unsignedFloat ^^ {
- // TODO(davies): some precisions may loss, we should create decimal literal
- case s ~ f => Literal(BigDecimal(s.getOrElse("") + f).doubleValue())
+ case s ~ f => Literal(toDecimalOrDouble(s.getOrElse("") + f))
}
)
@@ -420,6 +419,17 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
}
}
+ private def toDecimalOrDouble(value: String): Any = {
+ val decimal = BigDecimal(value)
+ // follow the behavior in MS SQL Server
+ // https://msdn.microsoft.com/en-us/library/ms179899.aspx
+ if (value.contains('E') || value.contains('e')) {
+ decimal.doubleValue()
+ } else {
+ decimal.underlying()
+ }
+ }
+
protected lazy val baseExpression: Parser[Expression] =
( "*" ^^^ UnresolvedStar(None)
| ident <~ "." ~ "*" ^^ { case tableName => UnresolvedStar(Option(tableName)) }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index e052750344..ecc48986e3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -109,13 +109,35 @@ object HiveTypeCoercion {
* Find the tightest common type of a set of types by continuously applying
* `findTightestCommonTypeOfTwo` on these types.
*/
- private def findTightestCommonType(types: Seq[DataType]) = {
+ private def findTightestCommonType(types: Seq[DataType]): Option[DataType] = {
types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
case None => None
case Some(d) => findTightestCommonTypeOfTwo(d, c)
})
}
+ private def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = (t1, t2) match {
+ case (t1: DecimalType, t2: DecimalType) =>
+ Some(DecimalPrecision.widerDecimalType(t1, t2))
+ case (t: IntegralType, d: DecimalType) =>
+ Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
+ case (d: DecimalType, t: IntegralType) =>
+ Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
+ case (t: FractionalType, d: DecimalType) =>
+ Some(DoubleType)
+ case (d: DecimalType, t: FractionalType) =>
+ Some(DoubleType)
+ case _ =>
+ findTightestCommonTypeToString(t1, t2)
+ }
+
+ private def findWiderCommonType(types: Seq[DataType]) = {
+ types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
+ case Some(d) => findWiderTypeForTwo(d, c)
+ case None => None
+ })
+ }
+
/**
* Applies any changes to [[AttributeReference]] data types that are made by other rules to
* instances higher in the query tree.
@@ -182,20 +204,7 @@ object HiveTypeCoercion {
val castedTypes = left.output.zip(right.output).map {
case (lhs, rhs) if lhs.dataType != rhs.dataType =>
- (lhs.dataType, rhs.dataType) match {
- case (t1: DecimalType, t2: DecimalType) =>
- Some(DecimalPrecision.widerDecimalType(t1, t2))
- case (t: IntegralType, d: DecimalType) =>
- Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
- case (d: DecimalType, t: IntegralType) =>
- Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
- case (t: FractionalType, d: DecimalType) =>
- Some(DoubleType)
- case (d: DecimalType, t: FractionalType) =>
- Some(DoubleType)
- case _ =>
- findTightestCommonTypeToString(lhs.dataType, rhs.dataType)
- }
+ findWiderTypeForTwo(lhs.dataType, rhs.dataType)
case other => None
}
@@ -236,8 +245,13 @@ object HiveTypeCoercion {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
- case a @ BinaryArithmetic(left @ StringType(), r) =>
- a.makeCopy(Array(Cast(left, DoubleType), r))
+ case a @ BinaryArithmetic(left @ StringType(), right @ DecimalType.Expression(_, _)) =>
+ a.makeCopy(Array(Cast(left, DecimalType.SYSTEM_DEFAULT), right))
+ case a @ BinaryArithmetic(left @ DecimalType.Expression(_, _), right @ StringType()) =>
+ a.makeCopy(Array(left, Cast(right, DecimalType.SYSTEM_DEFAULT)))
+
+ case a @ BinaryArithmetic(left @ StringType(), right) =>
+ a.makeCopy(Array(Cast(left, DoubleType), right))
case a @ BinaryArithmetic(left, right @ StringType()) =>
a.makeCopy(Array(left, Cast(right, DoubleType)))
@@ -543,7 +557,7 @@ object HiveTypeCoercion {
// compatible with every child column.
case c @ Coalesce(es) if es.map(_.dataType).distinct.size > 1 =>
val types = es.map(_.dataType)
- findTightestCommonTypeAndPromoteToString(types) match {
+ findWiderCommonType(types) match {
case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
case None => c
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 4589facb49..221b4e92f0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -145,11 +145,11 @@ class AnalysisSuite extends AnalysisTest {
'e / 'e as 'div5))
val pl = plan.asInstanceOf[Project].projectList
- // StringType will be promoted into Double
assert(pl(0).dataType == DoubleType)
assert(pl(1).dataType == DoubleType)
assert(pl(2).dataType == DoubleType)
- assert(pl(3).dataType == DoubleType)
+ // StringType will be promoted into Decimal(38, 18)
+ assert(pl(3).dataType == DecimalType(38, 29))
assert(pl(4).dataType == DoubleType)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index 21256704a5..8cf2ef5957 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -216,7 +216,8 @@ class MathExpressionsSuite extends QueryTest {
checkAnswer(
ctx.sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " +
s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"),
- Seq(Row(0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142))
+ Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3),
+ BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142")))
)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 42724ed766..d13dde1cdc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -368,7 +368,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
Row(1))
checkAnswer(
sql("SELECT COALESCE(null, 1, 1.5)"),
- Row(1.toDouble))
+ Row(BigDecimal(1)))
checkAnswer(
sql("SELECT COALESCE(null, null, null)"),
Row(null))
@@ -1234,19 +1234,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
test("Floating point number format") {
checkAnswer(
- sql("SELECT 0.3"), Row(0.3)
+ sql("SELECT 0.3"), Row(BigDecimal(0.3).underlying())
)
checkAnswer(
- sql("SELECT -0.8"), Row(-0.8)
+ sql("SELECT -0.8"), Row(BigDecimal(-0.8).underlying())
)
checkAnswer(
- sql("SELECT .5"), Row(0.5)
+ sql("SELECT .5"), Row(BigDecimal(0.5))
)
checkAnswer(
- sql("SELECT -.18"), Row(-0.18)
+ sql("SELECT -.18"), Row(BigDecimal(-0.18))
)
}
@@ -1279,11 +1279,11 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
)
checkAnswer(
- sql("SELECT -5.2"), Row(-5.2)
+ sql("SELECT -5.2"), Row(BigDecimal(-5.2))
)
checkAnswer(
- sql("SELECT +6.8"), Row(6.8)
+ sql("SELECT +6.8"), Row(BigDecimal(6.8))
)
checkAnswer(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 3ac312d6f4..f19f22fca7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -422,14 +422,14 @@ class JsonSuite extends QueryTest with TestJsonData {
Row(-89) :: Row(21474836370L) :: Row(21474836470L) :: Nil
)
- // Widening to DoubleType
+ // Widening to DecimalType
checkAnswer(
sql("select num_num_2 + 1.3 from jsonTable where num_num_2 > 1.1"),
- Row(21474836472.2) ::
- Row(92233720368547758071.3) :: Nil
+ Row(BigDecimal("21474836472.2")) ::
+ Row(BigDecimal("92233720368547758071.3")) :: Nil
)
- // Widening to DoubleType
+ // Widening to Double
checkAnswer(
sql("select num_num_3 + 1.2 from jsonTable where num_num_3 > 1.1"),
Row(101.2) :: Row(21474836471.2) :: Nil
@@ -438,13 +438,13 @@ class JsonSuite extends QueryTest with TestJsonData {
// Number and String conflict: resolve the type as number in this query.
checkAnswer(
sql("select num_str + 1.2 from jsonTable where num_str > 14"),
- Row(92233720368547758071.2)
+ Row(BigDecimal("92233720368547758071.2"))
)
// Number and String conflict: resolve the type as number in this query.
checkAnswer(
sql("select num_str + 1.2 from jsonTable where num_str >= 92233720368547758060"),
- Row(new java.math.BigDecimal("92233720368547758071.2").doubleValue)
+ Row(new java.math.BigDecimal("92233720368547758071.2"))
)
// String and Boolean conflict: resolve the type as string.
@@ -503,7 +503,7 @@ class JsonSuite extends QueryTest with TestJsonData {
// Number and String conflict: resolve the type as number in this query.
checkAnswer(
sql("select num_str + 1.2 from jsonTable where num_str > 13"),
- Row(14.3) :: Row(92233720368547758071.2) :: Nil
+ Row(BigDecimal("14.3")) :: Row(BigDecimal("92233720368547758071.2")) :: Nil
)
}