author     Dongjoon Hyun <dongjoon@apache.org>    2017-01-03 23:06:50 +0800
committer  Wenchen Fan <wenchen@databricks.com>   2017-01-03 23:06:50 +0800
commit     7a2b5f93bc3d3224470837ed3323964ba7cb1dca (patch)
tree       12853f95a3e2bfa2887a1b5bb8b62326521f6d75
parent     52636226dc8cb7fcf00381d65e280d651b25a382 (diff)
[SPARK-18877][SQL] `CSVInferSchema.inferField` on DecimalType should find a common type with `typeSoFar`
## What changes were proposed in this pull request?

CSV type inference throws `IllegalArgumentException` on decimal numbers with heterogeneous precisions and scales because the current logic keeps only the last decimal type seen in a **partition**. Specifically, `inferRowType`, the **seqOp** of **aggregate**, returns the last decimal type. This PR fixes it to use `findTightestCommonType`.

**decimal.csv**
```
9.03E+12
1.19E+11
```

**BEFORE**
```scala
scala> spark.read.format("csv").option("inferSchema", true).load("decimal.csv").printSchema
root
 |-- _c0: decimal(3,-9) (nullable = true)

scala> spark.read.format("csv").option("inferSchema", true).load("decimal.csv").show
16/12/16 14:32:49 ERROR Executor: Exception in task 0.0 in stage 4.0 (TID 4)
java.lang.IllegalArgumentException: requirement failed: Decimal precision 4 exceeds max precision 3
```

**AFTER**
```scala
scala> spark.read.format("csv").option("inferSchema", true).load("decimal.csv").printSchema
root
 |-- _c0: decimal(4,-9) (nullable = true)

scala> spark.read.format("csv").option("inferSchema", true).load("decimal.csv").show
+---------+
|      _c0|
+---------+
|9.030E+12|
| 1.19E+11|
+---------+
```

## How was this patch tested?

Pass the newly added test case.

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #16320 from dongjoon-hyun/SPARK-18877.
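For reference, the kind of merge the fix relies on can be sketched in a few lines of Scala. This is a simplified illustration of widening two decimal types to a common type, not Spark's actual `findTightestCommonType` implementation; `Dec` and `widenDecimals` are illustrative names. The idea is to keep the widest scale and the widest integral range, and to give up (so the caller can fall back, e.g. to `DoubleType`) when the combined precision would exceed the 38-digit decimal limit.

```scala
// Simplified sketch (not Spark's implementation) of merging two decimal types.
object DecimalWideningSketch {
  case class Dec(precision: Int, scale: Int)

  def widenDecimals(t1: Dec, t2: Dec): Option[Dec] = {
    val scale = math.max(t1.scale, t2.scale)                               // widest fractional part
    val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale) // widest integral part
    if (range + scale > 38) None // would exceed max decimal precision; caller falls back
    else Some(Dec(range + scale, scale))
  }

  // 9.03E+12 infers as Dec(3, -10) and 1.19E+11 as Dec(3, -9):
  // widenDecimals(Dec(3, -10), Dec(3, -9)) == Some(Dec(4, -9)),
  // matching the decimal(4,-9) schema in the AFTER output above.
}
```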
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala        4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala   17
2 files changed, 20 insertions(+), 1 deletion(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
index 88c608add1..adc92fe5a3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
@@ -85,7 +85,9 @@ private[csv] object CSVInferSchema {
         case NullType => tryParseInteger(field, options)
         case IntegerType => tryParseInteger(field, options)
         case LongType => tryParseLong(field, options)
-        case _: DecimalType => tryParseDecimal(field, options)
+        case _: DecimalType =>
+          // DecimalTypes have different precisions and scales, so we try to find the common type.
+          findTightestCommonType(typeSoFar, tryParseDecimal(field, options)).getOrElse(StringType)
         case DoubleType => tryParseDouble(field, options)
         case TimestampType => tryParseTimestamp(field, options)
         case BooleanType => tryParseBoolean(field, options)
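For context on where this change takes effect, below is a rough sketch of how `inferField` is folded over the fields of each row during schema inference. It is a simplification of the `inferRowType` seqOp mentioned in the commit message (which updates an `Array[DataType]` inside an RDD aggregate); `InferRowSketch` and `inferRow` are illustrative names, and the snippet assumes it lives in the same `csv` package because `CSVInferSchema` is `private[csv]`.

```scala
package org.apache.spark.sql.execution.datasources.csv

import org.apache.spark.sql.types._

object InferRowSketch {
  // Widen each column's running type with the value seen in the current row.
  def inferRow(typesSoFar: Seq[DataType], row: Seq[String], options: CSVOptions): Seq[DataType] =
    typesSoFar.zip(row).map { case (tpe, field) =>
      CSVInferSchema.inferField(tpe, field, options)
    }

  // Folding decimal.csv row by row, starting from NullType, now widens the decimals
  // instead of keeping only the last one:
  //   val opts = new CSVOptions(Map.empty[String, String])
  //   val afterRow1 = inferRow(Seq(NullType), Seq("9.03E+12"), opts)  // Seq(DecimalType(3, -10))
  //   inferRow(afterRow1, Seq("1.19E+11"), opts)                      // Seq(DecimalType(4, -9))
}
```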
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
index 93f752d107..8620bb9f65 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
@@ -114,4 +114,21 @@ class CSVInferSchemaSuite extends SparkFunSuite {
     val options = new CSVOptions(Map("TiMeStampFormat" -> "yyyy-mm"))
     assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == TimestampType)
   }
+
+  test("SPARK-18877: `inferField` on DecimalType should find a common type with `typeSoFar`") {
+    val options = new CSVOptions(Map.empty[String, String])
+
+    // 9.03E+12 is Decimal(3, -10) and 1.19E+11 is Decimal(3, -9).
+    assert(CSVInferSchema.inferField(DecimalType(3, -10), "1.19E+11", options) ==
+      DecimalType(4, -9))
+
+    // BigDecimal("12345678901234567890.01234567890123456789") is precision 40 and scale 20.
+    val value = "12345678901234567890.01234567890123456789"
+    assert(CSVInferSchema.inferField(DecimalType(3, -10), value, options) == DoubleType)
+
+    // Seq(s"${Long.MaxValue}1", "2015-12-01 00:00:00") should be StringType
+    assert(CSVInferSchema.inferField(NullType, s"${Long.MaxValue}1", options) == DecimalType(20, 0))
+    assert(CSVInferSchema.inferField(DecimalType(20, 0), "2015-12-01 00:00:00", options)
+      == StringType)
+  }
 }