author    hyukjinkwon <gurwls223@gmail.com>  2016-05-12 22:31:14 -0700
committer Davies Liu <davies.liu@gmail.com>  2016-05-12 22:31:14 -0700
commit    51841d77d99a858f8fa1256e923b0364b9b28fa0 (patch)
tree      8c190a69054ed9ed4da636db7029a0eef7a29188
parent    eda2800d44843b6478e22d2c99bca4af7e9c9613 (diff)
[SPARK-13866] [SQL] Handle decimal type in CSV inference at CSV data source.
## What changes were proposed in this pull request?

https://issues.apache.org/jira/browse/SPARK-13866

This PR adds support for inferring `DecimalType`. Here are the rules for choosing among `IntegerType`, `LongType` and `DecimalType`.

#### Inferring Types

1. `IntegerType` and then `LongType` are tried first.

    ```scala
    Int.MaxValue => IntegerType
    Long.MaxValue => LongType
    ```

2. If that fails, try `DecimalType`.

    ```scala
    (Long.MaxValue + 1) => DecimalType(20, 0)
    ```

    A value is not inferred as `DecimalType` when its scale is greater than 0 (see rule 3).

3. If that fails too, try `DoubleType`.

    ```scala
    0.1 => DoubleType // Not inferred as `DecimalType` because it has a scale of 1.
    ```

#### Compatible Types (Merging Types)

Merging types works the same way as in the JSON data source: if `DecimalType` cannot hold the merged value, it becomes `DoubleType`.

## How was this patch tested?

Unit tests were added, and `./dev/run_tests` was run for the code style checks.

Author: hyukjinkwon <gurwls223@gmail.com>
Author: Hyukjin Kwon <gurwls223@gmail.com>

Closes #11724 from HyukjinKwon/SPARK-13866.
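As a concrete illustration of the inference rules above, here is a minimal sketch (not part of this patch) of how the new behavior surfaces to users. It assumes a `SparkSession` named `spark` and a hypothetical headered CSV file `big.csv` whose single `value` column holds a number larger than `Long.MaxValue`, mirroring the `decimal.csv` fixture added below:

```scala
// big.csv (hypothetical fixture):
//   value
//   92233720368547758070   <- exceeds Long.MaxValue, so Int/Long inference fails
val df = spark.read
  .format("csv")
  .option("header", "true")
  .option("inferSchema", "true")
  .load("big.csv")
df.printSchema()
// root
//  |-- value: decimal(20,0) (nullable = true)
```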
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala | 50
-rw-r--r--  sql/core/src/test/resources/decimal.csv | 7
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala | 13
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala | 15
4 files changed, 81 insertions(+), 4 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
index cfd66af188..05c8d8ee15 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.datasources.csv

import java.math.BigDecimal
-import java.text.{NumberFormat, SimpleDateFormat}
+import java.text.NumberFormat
import java.util.Locale

import scala.util.control.Exception._
@@ -85,6 +85,7 @@ private[csv] object CSVInferSchema {
case NullType => tryParseInteger(field, options)
case IntegerType => tryParseInteger(field, options)
case LongType => tryParseLong(field, options)
+ case _: DecimalType => tryParseDecimal(field, options)
case DoubleType => tryParseDouble(field, options)
case TimestampType => tryParseTimestamp(field, options)
case BooleanType => tryParseBoolean(field, options)
@@ -107,10 +108,28 @@ private[csv] object CSVInferSchema {
if ((allCatch opt field.toLong).isDefined) {
LongType
} else {
- tryParseDouble(field, options)
+ tryParseDecimal(field, options)
}
}

+ private def tryParseDecimal(field: String, options: CSVOptions): DataType = {
+ val decimalTry = allCatch opt {
+ // `BigDecimal` conversion can fail when `field` is not in the form of a number.
+ val bigDecimal = new BigDecimal(field)
+ // Because many other formats do not support decimals, this reduces the cases inferred
+ // as decimal by disallowing values that have a scale (e.g. `1.1`).
+ if (bigDecimal.scale <= 0) {
+ // `DecimalType` conversion can fail when
+ // 1. The precision is bigger than 38.
+ // 2. The scale is bigger than the precision.
+ DecimalType(bigDecimal.precision, bigDecimal.scale)
+ } else {
+ tryParseDouble(field, options)
+ }
+ }
+ decimalTry.getOrElse(tryParseDouble(field, options))
+ }
+
private def tryParseDouble(field: String, options: CSVOptions): DataType = {
if ((allCatch opt field.toDouble).isDefined) {
DoubleType
@@ -170,6 +189,33 @@ private[csv] object CSVInferSchema {
val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2)
Some(numericPrecedence(index))

+ // These two cases handle when `DecimalType` is wide enough to hold `IntegralType`.
+ case (t1: IntegralType, t2: DecimalType) if t2.isWiderThan(t1) =>
+ Some(t2)
+ case (t1: DecimalType, t2: IntegralType) if t1.isWiderThan(t2) =>
+ Some(t1)
+
+ // These two cases handle when `IntegralType` does not fit in `DecimalType`; the
+ // integral type is first converted to its equivalent `DecimalType` and retried.
+ case (t1: IntegralType, t2: DecimalType) =>
+ findTightestCommonType(DecimalType.forType(t1), t2)
+ case (t1: DecimalType, t2: IntegralType) =>
+ findTightestCommonType(t1, DecimalType.forType(t2))
+
+ // Double supports a larger range than fixed decimal, so it is preferred here.
+ // DecimalType.Maximum should be enough in most cases and also has better precision.
+ case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) =>
+ Some(DoubleType)
+
+ case (t1: DecimalType, t2: DecimalType) =>
+ val scale = math.max(t1.scale, t2.scale)
+ val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale)
+ if (range + scale > 38) {
+ // DecimalType can't support precision > 38
+ Some(DoubleType)
+ } else {
+ Some(DecimalType(range + scale, scale))
+ }
+
case _ => None
}
}
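The decimal-merging arithmetic in the last case added above is easy to check by hand. A small sketch (illustrative only; the example precisions and scales are this editor's, not from the patch) for merging `DecimalType(5, 2)` with `DecimalType(4, 3)`:

```scala
val (p1, s1) = (5, 2) // DecimalType(5, 2): 3 integral digits, 2 fractional digits
val (p2, s2) = (4, 3) // DecimalType(4, 3): 1 integral digit, 3 fractional digits
val scale = math.max(s1, s2)           // 3: keep the finer fractional part
val range = math.max(p1 - s1, p2 - s2) // max(3, 1) = 3: keep the wider integral part
assert(range + scale == 6)             // 6 <= 38, so the merged type is DecimalType(6, 3)
```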
diff --git a/sql/core/src/test/resources/decimal.csv b/sql/core/src/test/resources/decimal.csv
new file mode 100644
index 0000000000..870f6aaf1b
--- /dev/null
+++ b/sql/core/src/test/resources/decimal.csv
@@ -0,0 +1,7 @@
+~ decimal field has integer, integer and decimal values. The last value cannot fit in a long
+~ long field has integer, long and integer values.
+~ double field has double, double and decimal values.
+decimal,long,double
+1,1,0.1
+1,9223372036854775807,1.0
+92233720368547758070,1,92233720368547758070
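For reference, a small sketch (not part of the patch) of why `92233720368547758070` in this fixture is inferred as `DecimalType(20, 0)` under the new `tryParseDecimal` rule:

```scala
import java.math.BigDecimal

val d = new BigDecimal("92233720368547758070")
// Twenty digits and no fractional part: precision 20, scale 0.
assert(d.precision == 20 && d.scale == 0)
// scale <= 0, so tryParseDecimal yields DecimalType(20, 0) instead of falling
// back to DoubleType.
```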
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
index daf85be56f..dbe3af49c9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.execution.datasources.csv

-import java.text.SimpleDateFormat
-
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._
@@ -35,6 +33,11 @@ class CSVInferSchemaSuite extends SparkFunSuite {
assert(CSVInferSchema.inferField(NullType, "2015-08-20 15:57:00", options) == TimestampType)
assert(CSVInferSchema.inferField(NullType, "True", options) == BooleanType)
assert(CSVInferSchema.inferField(NullType, "FAlSE", options) == BooleanType)
+
+ val textValueOne = Long.MaxValue.toString + "0"
+ val decimalValueOne = new java.math.BigDecimal(textValueOne)
+ val expectedTypeOne = DecimalType(decimalValueOne.precision, decimalValueOne.scale)
+ assert(CSVInferSchema.inferField(NullType, textValueOne, options) == expectedTypeOne)
}

test("String fields types are inferred correctly from other types") {
@@ -49,6 +52,11 @@ class CSVInferSchemaSuite extends SparkFunSuite {
assert(CSVInferSchema.inferField(LongType, "True", options) == BooleanType)
assert(CSVInferSchema.inferField(IntegerType, "FALSE", options) == BooleanType)
assert(CSVInferSchema.inferField(TimestampType, "FALSE", options) == BooleanType)
+
+ val textValueOne = Long.MaxValue.toString + "0"
+ val decimalValueOne = new java.math.BigDecimal(textValueOne)
+ val expectedTypeOne = DecimalType(decimalValueOne.precision, decimalValueOne.scale)
+ assert(CSVInferSchema.inferField(IntegerType, textValueOne, options) == expectedTypeOne)
}

test("Timestamp field types are inferred correctly via custom data format") {
@@ -94,6 +102,7 @@ class CSVInferSchemaSuite extends SparkFunSuite {
assert(CSVInferSchema.inferField(DoubleType, "\\N", options) == DoubleType)
assert(CSVInferSchema.inferField(TimestampType, "\\N", options) == TimestampType)
assert(CSVInferSchema.inferField(BooleanType, "\\N", options) == BooleanType)
+ assert(CSVInferSchema.inferField(DecimalType(1, 1), "\\N", options) == DecimalType(1, 1))
}

test("Merging Nulltypes should yield Nulltype.") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index ae91e0f606..27d6dc9197 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -43,6 +43,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
private val commentsFile = "comments.csv"
private val disableCommentsFile = "disable_comments.csv"
private val boolFile = "bool.csv"
+ private val decimalFile = "decimal.csv"
private val simpleSparseFile = "simple_sparse.csv"
private val numbersFile = "numbers.csv"
private val datesFile = "dates.csv"
@@ -133,6 +134,20 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
assert(result.schema === expectedSchema)
}

+ test("test inferring decimals") {
+ val result = sqlContext.read
+ .format("csv")
+ .option("comment", "~")
+ .option("header", "true")
+ .option("inferSchema", "true")
+ .load(testFile(decimalFile))
+ val expectedSchema = StructType(List(
+ StructField("decimal", DecimalType(20, 0), nullable = true),
+ StructField("long", LongType, nullable = true),
+ StructField("double", DoubleType, nullable = true)))
+ assert(result.schema === expectedSchema)
+ }
+
test("test with alternative delimiter and quote") {
val cars = spark.read
.format("csv")