about | summary | refs | log | tree | commit | diff
path: root/sql
diff options
context:
space:
mode:
Diffstat (limited to 'sql')
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala | 6
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala | 8
2 files changed, 13 insertions, 1 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
index 485b186c7c..3fa30fe240 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
@@ -150,6 +150,10 @@ private[csv] object CSVInferSchema {
}
}
+ /** True when `field` equals the NaN, negative-infinity or positive-infinity token in `options`. */
+ private def isInfOrNan(field: String, options: CSVOptions): Boolean =
+   Seq(options.nanValue, options.negativeInf, options.positiveInf).contains(field)
+
private def tryParseInteger(field: String, options: CSVOptions): DataType = {
if ((allCatch opt field.toInt).isDefined) {
IntegerType
@@ -185,7 +189,7 @@ private[csv] object CSVInferSchema {
}
private def tryParseDouble(field: String, options: CSVOptions): DataType = {
- if ((allCatch opt field.toDouble).isDefined) {
+ if ((allCatch opt field.toDouble).isDefined || isInfOrNan(field, options)) {
DoubleType
} else {
tryParseTimestamp(field, options)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
index 8620bb9f65..d8c6c25504 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
@@ -131,4 +131,12 @@ class CSVInferSchemaSuite extends SparkFunSuite {
assert(CSVInferSchema.inferField(DecimalType(20, 0), "2015-12-01 00:00:00", options)
== StringType)
}
+
+ test("DoubleType should be infered when user defined nan/inf are provided") {
+   // Each custom token configured below must be inferred as DoubleType from a NullType start.
+   val options = new CSVOptions(
+     Map("nanValue" -> "nan", "negativeInf" -> "-inf", "positiveInf" -> "inf"))
+   for (token <- Seq("nan", "inf", "-inf"))
+     assert(CSVInferSchema.inferField(NullType, token, options) == DoubleType)
+ }
}