diff options
author | hyukjinkwon <gurwls223@gmail.com> | 2016-05-12 22:31:14 -0700 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2016-05-12 22:31:14 -0700 |
commit | 51841d77d99a858f8fa1256e923b0364b9b28fa0 (patch) | |
tree | 8c190a69054ed9ed4da636db7029a0eef7a29188 | |
parent | eda2800d44843b6478e22d2c99bca4af7e9c9613 (diff) | |
download | spark-51841d77d99a858f8fa1256e923b0364b9b28fa0.tar.gz spark-51841d77d99a858f8fa1256e923b0364b9b28fa0.tar.bz2 spark-51841d77d99a858f8fa1256e923b0364b9b28fa0.zip |
[SPARK-13866] [SQL] Handle decimal type in CSV inference at CSV data source.
## What changes were proposed in this pull request?
https://issues.apache.org/jira/browse/SPARK-13866
This PR adds the support to infer `DecimalType`.
Here are the rules between `IntegerType`, `LongType` and `DecimalType`.
#### Infering Types
1. `IntegerType` and then `LongType`are tried first.
```scala
Int.MaxValue => IntegerType
Long.MaxValue => LongType
```
2. If it fails, try `DecimalType`.
```scala
(Long.MaxValue + 1) => DecimalType(20, 0)
```
This does not try to infer this as `DecimalType` when scale is less than 0.
3. if it fails, try `DoubleType`
```scala
0.1 => DoubleType // This is failed to be inferred as `DecimalType` because it has the scale, 1.
```
#### Compatible Types (Merging Types)
For merging types, this is the same with JSON data source. If `DecimalType` is not capable, then it becomes `DoubleType`
## How was this patch tested?
Unit tests were used and `./dev/run_tests` for code style test.
Author: hyukjinkwon <gurwls223@gmail.com>
Author: Hyukjin Kwon <gurwls223@gmail.com>
Closes #11724 from HyukjinKwon/SPARK-13866.
4 files changed, 81 insertions, 4 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index cfd66af188..05c8d8ee15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.csv import java.math.BigDecimal -import java.text.{NumberFormat, SimpleDateFormat} +import java.text.NumberFormat import java.util.Locale import scala.util.control.Exception._ @@ -85,6 +85,7 @@ private[csv] object CSVInferSchema { case NullType => tryParseInteger(field, options) case IntegerType => tryParseInteger(field, options) case LongType => tryParseLong(field, options) + case _: DecimalType => tryParseDecimal(field, options) case DoubleType => tryParseDouble(field, options) case TimestampType => tryParseTimestamp(field, options) case BooleanType => tryParseBoolean(field, options) @@ -107,10 +108,28 @@ private[csv] object CSVInferSchema { if ((allCatch opt field.toLong).isDefined) { LongType } else { - tryParseDouble(field, options) + tryParseDecimal(field, options) } } + private def tryParseDecimal(field: String, options: CSVOptions): DataType = { + val decimalTry = allCatch opt { + // `BigDecimal` conversion can fail when the `field` is not a form of number. + val bigDecimal = new BigDecimal(field) + // Because many other formats do not support decimal, it reduces the cases for + // decimals by disallowing values having scale (eg. `1.1`). + if (bigDecimal.scale <= 0) { + // `DecimalType` conversion can fail when + // 1. The precision is bigger than 38. + // 2. scale is bigger than precision. + DecimalType(bigDecimal.precision, bigDecimal.scale) + } else { + tryParseDouble(field, options) + } + } + decimalTry.getOrElse(tryParseDouble(field, options)) + } + private def tryParseDouble(field: String, options: CSVOptions): DataType = { if ((allCatch opt field.toDouble).isDefined) { DoubleType @@ -170,6 +189,33 @@ private[csv] object CSVInferSchema { val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2) Some(numericPrecedence(index)) + // These two cases below deal with when `DecimalType` is larger than `IntegralType`. + case (t1: IntegralType, t2: DecimalType) if t2.isWiderThan(t1) => + Some(t2) + case (t1: DecimalType, t2: IntegralType) if t1.isWiderThan(t2) => + Some(t1) + + // These two cases below deal with when `IntegralType` is larger than `DecimalType`. + case (t1: IntegralType, t2: DecimalType) => + findTightestCommonType(DecimalType.forType(t1), t2) + case (t1: DecimalType, t2: IntegralType) => + findTightestCommonType(t1, DecimalType.forType(t2)) + + // Double support larger range than fixed decimal, DecimalType.Maximum should be enough + // in most case, also have better precision. + case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) => + Some(DoubleType) + + case (t1: DecimalType, t2: DecimalType) => + val scale = math.max(t1.scale, t2.scale) + val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale) + if (range + scale > 38) { + // DecimalType can't support precision > 38 + Some(DoubleType) + } else { + Some(DecimalType(range + scale, scale)) + } + case _ => None } } diff --git a/sql/core/src/test/resources/decimal.csv b/sql/core/src/test/resources/decimal.csv new file mode 100644 index 0000000000..870f6aaf1b --- /dev/null +++ b/sql/core/src/test/resources/decimal.csv @@ -0,0 +1,7 @@ +~ decimal field has integer, integer and decimal values. The last value cannot fit to a long +~ long field has integer, long and integer values. +~ double field has double, double and decimal values. +decimal,long,double +1,1,0.1 +1,9223372036854775807,1.0 +92233720368547758070,1,92233720368547758070 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala index daf85be56f..dbe3af49c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.datasources.csv -import java.text.SimpleDateFormat - import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ @@ -35,6 +33,11 @@ class CSVInferSchemaSuite extends SparkFunSuite { assert(CSVInferSchema.inferField(NullType, "2015-08-20 15:57:00", options) == TimestampType) assert(CSVInferSchema.inferField(NullType, "True", options) == BooleanType) assert(CSVInferSchema.inferField(NullType, "FAlSE", options) == BooleanType) + + val textValueOne = Long.MaxValue.toString + "0" + val decimalValueOne = new java.math.BigDecimal(textValueOne) + val expectedTypeOne = DecimalType(decimalValueOne.precision, decimalValueOne.scale) + assert(CSVInferSchema.inferField(NullType, textValueOne, options) == expectedTypeOne) } test("String fields types are inferred correctly from other types") { @@ -49,6 +52,11 @@ class CSVInferSchemaSuite extends SparkFunSuite { assert(CSVInferSchema.inferField(LongType, "True", options) == BooleanType) assert(CSVInferSchema.inferField(IntegerType, "FALSE", options) == BooleanType) assert(CSVInferSchema.inferField(TimestampType, "FALSE", options) == BooleanType) + + val textValueOne = Long.MaxValue.toString + "0" + val decimalValueOne = new java.math.BigDecimal(textValueOne) + val expectedTypeOne = DecimalType(decimalValueOne.precision, decimalValueOne.scale) + assert(CSVInferSchema.inferField(IntegerType, textValueOne, options) == expectedTypeOne) } test("Timestamp field types are inferred correctly via custom data format") { @@ -94,6 +102,7 @@ class CSVInferSchemaSuite extends SparkFunSuite { assert(CSVInferSchema.inferField(DoubleType, "\\N", options) == DoubleType) assert(CSVInferSchema.inferField(TimestampType, "\\N", options) == TimestampType) assert(CSVInferSchema.inferField(BooleanType, "\\N", options) == BooleanType) + assert(CSVInferSchema.inferField(DecimalType(1, 1), "\\N", options) == DecimalType(1, 1)) } test("Merging Nulltypes should yield Nulltype.") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index ae91e0f606..27d6dc9197 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -43,6 +43,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { private val commentsFile = "comments.csv" private val disableCommentsFile = "disable_comments.csv" private val boolFile = "bool.csv" + private val decimalFile = "decimal.csv" private val simpleSparseFile = "simple_sparse.csv" private val numbersFile = "numbers.csv" private val datesFile = "dates.csv" @@ -133,6 +134,20 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { assert(result.schema === expectedSchema) } + test("test inferring decimals") { + val result = sqlContext.read + .format("csv") + .option("comment", "~") + .option("header", "true") + .option("inferSchema", "true") + .load(testFile(decimalFile)) + val expectedSchema = StructType(List( + StructField("decimal", DecimalType(20, 0), nullable = true), + StructField("long", LongType, nullable = true), + StructField("double", DoubleType, nullable = true))) + assert(result.schema === expectedSchema) + } + test("test with alternative delimiter and quote") { val cars = spark.read .format("csv") |