author     Hossein <hossein@databricks.com>   2016-04-30 18:11:56 -0700
committer  Reynold Xin <rxin@databricks.com>  2016-04-30 18:12:03 -0700
commit     507bea5ca6d95c995f8152b8473713c136e23754 (patch)
tree       663b7920df1d4fbe0ae5357a029140729a50e40b /sql
parent     0182d9599d15f70eeb6288bf9294fa677004bd14 (diff)
[SPARK-14143] Options for parsing NaNs, Infinity and nulls for numeric types
1. Adds an option for parsing NaNs: nanValue.
2. Adds options for parsing infinity: positiveInf and negativeInf.

`CSVTypeCast.castTo` is unit tested, and an end-to-end test is added to `CSVSuite`.

Author: Hossein <hossein@databricks.com>

Closes #11947 from falaki/SPARK-14143.
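For context, a minimal read sketch using the new options. The option names and
defaults come from this patch; the file path, the sentinel tokens, and the
sqlContext handle are illustrative only:

    // Hedged usage sketch: nullValue, nanValue, positiveInf and negativeInf
    // are the options added by this patch; the file path is illustrative.
    val df = sqlContext.read
      .format("csv")
      .option("header", "true")
      .option("nullValue", "--")     // "--" parses to null for nullable columns
      .option("nanValue", "NAN")     // "NAN" parses to Float.NaN / Double.NaN
      .option("positiveInf", "INF")  // "INF" parses to positive infinity
      .option("negativeInf", "-INF") // "-INF" parses to negative infinity
      .load("numbers.csv")

When these options are omitted, the defaults apply: nullValue = "",
nanValue = "NaN", positiveInf = "Inf" and negativeInf = "-Inf" (see the
CSVOptions.scala diff below).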
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala | 83
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala | 15
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala | 3
-rw-r--r--  sql/core/src/test/resources/numbers.csv | 9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala | 23
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala | 83
6 files changed, 174 insertions(+), 42 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
index a26a8084b6..cfd66af188 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
@@ -190,40 +190,61 @@ private[csv] object CSVTypeCast {
datum: String,
castType: DataType,
nullable: Boolean = true,
- nullValue: String = "",
- dateFormat: SimpleDateFormat = null): Any = {
+ options: CSVOptions = CSVOptions()): Any = {
- if (datum == nullValue && nullable && (!castType.isInstanceOf[StringType])) {
- null
- } else {
- castType match {
- case _: ByteType => datum.toByte
- case _: ShortType => datum.toShort
- case _: IntegerType => datum.toInt
- case _: LongType => datum.toLong
- case _: FloatType => Try(datum.toFloat)
- .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue())
- case _: DoubleType => Try(datum.toDouble)
- .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue())
- case _: BooleanType => datum.toBoolean
- case dt: DecimalType =>
+ castType match {
+ case _: ByteType => if (datum == options.nullValue && nullable) null else datum.toByte
+ case _: ShortType => if (datum == options.nullValue && nullable) null else datum.toShort
+ case _: IntegerType => if (datum == options.nullValue && nullable) null else datum.toInt
+ case _: LongType => if (datum == options.nullValue && nullable) null else datum.toLong
+ case _: FloatType =>
+ if (datum == options.nullValue && nullable) {
+ null
+ } else if (datum == options.nanValue) {
+ Float.NaN
+ } else if (datum == options.negativeInf) {
+ Float.NegativeInfinity
+ } else if (datum == options.positiveInf) {
+ Float.PositiveInfinity
+ } else {
+ Try(datum.toFloat)
+ .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue())
+ }
+ case _: DoubleType =>
+ if (datum == options.nullValue && nullable) {
+ null
+ } else if (datum == options.nanValue) {
+ Double.NaN
+ } else if (datum == options.negativeInf) {
+ Double.NegativeInfinity
+ } else if (datum == options.positiveInf) {
+ Double.PositiveInfinity
+ } else {
+ Try(datum.toDouble)
+ .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue())
+ }
+ case _: BooleanType => datum.toBoolean
+ case dt: DecimalType =>
+ if (datum == options.nullValue && nullable) {
+ null
+ } else {
val value = new BigDecimal(datum.replaceAll(",", ""))
Decimal(value, dt.precision, dt.scale)
- case _: TimestampType if dateFormat != null =>
- // This one will lose microseconds parts.
- // See https://issues.apache.org/jira/browse/SPARK-10681.
- dateFormat.parse(datum).getTime * 1000L
- case _: TimestampType =>
- // This one will lose microseconds parts.
- // See https://issues.apache.org/jira/browse/SPARK-10681.
- DateTimeUtils.stringToTime(datum).getTime * 1000L
- case _: DateType if dateFormat != null =>
- DateTimeUtils.millisToDays(dateFormat.parse(datum).getTime)
- case _: DateType =>
- DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime)
- case _: StringType => UTF8String.fromString(datum)
- case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}")
- }
+ }
+ case _: TimestampType if options.dateFormat != null =>
+ // This one will lose microseconds parts.
+ // See https://issues.apache.org/jira/browse/SPARK-10681.
+ options.dateFormat.parse(datum).getTime * 1000L
+ case _: TimestampType =>
+ // This one will lose microseconds parts.
+ // See https://issues.apache.org/jira/browse/SPARK-10681.
+ DateTimeUtils.stringToTime(datum).getTime * 1000L
+ case _: DateType if options.dateFormat != null =>
+ DateTimeUtils.millisToDays(options.dateFormat.parse(datum).getTime)
+ case _: DateType =>
+ DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime)
+ case _: StringType => UTF8String.fromString(datum)
+ case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}")
}
}
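The important behavior in the new FloatType/DoubleType branches is the order of
the token checks: nullValue wins first, then nanValue, then the infinity
tokens, and only then the plain parse. A simplified standalone sketch of that
precedence (not the patch code itself; the nullable flag and the locale-aware
NumberFormat fallback are elided):

    // Simplified sketch of the precedence in the FloatType branch above.
    // The real code also honors `nullable` and falls back to NumberFormat.
    def parseFloatToken(datum: String, nullValue: String, nanValue: String,
        negativeInf: String, positiveInf: String): Any = datum match {
      case `nullValue`   => null                    // null token checked first
      case `nanValue`    => Float.NaN
      case `negativeInf` => Float.NegativeInfinity
      case `positiveInf` => Float.PositiveInfinity
      case other         => other.toFloat           // plain parse last
    }

Because the checks are ordered, a token that happens to match several options
resolves to the earliest case, exactly as in the if/else chain of the patch.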
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index b87d19f7cf..e4fd09462f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -90,6 +90,12 @@ private[sql] class CSVOptions(@transient private val parameters: Map[String, Str
val nullValue = parameters.getOrElse("nullValue", "")
+ val nanValue = parameters.getOrElse("nanValue", "NaN")
+
+ val positiveInf = parameters.getOrElse("positiveInf", "Inf")
+ val negativeInf = parameters.getOrElse("negativeInf", "-Inf")
+
+
val compressionCodec: Option[String] = {
val name = parameters.get("compression").orElse(parameters.get("codec"))
name.map(CompressionCodecs.getCodecClassName)
@@ -111,3 +117,12 @@ private[sql] class CSVOptions(@transient private val parameters: Map[String, Str
val rowSeparator = "\n"
}
+
+object CSVOptions {
+
+ def apply(): CSVOptions = new CSVOptions(Map.empty)
+
+ def apply(paramName: String, paramValue: String): CSVOptions = {
+ new CSVOptions(Map(paramName -> paramValue))
+ }
+}
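The new companion object exists mainly to keep test code concise; a quick
sketch of how it is used (the default value is taken from the diff above):

    // Companion-object constructors added above, as used by the test suites.
    val defaults = CSVOptions()                  // empty parameter map
    val custom   = CSVOptions("nanValue", "nn")  // single-option override
    assert(defaults.nanValue == "NaN")           // default from this patch
    assert(custom.nanValue == "nn")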
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index 9a723630de..4f2d4387b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -99,8 +99,7 @@ object CSVRelation extends Logging {
indexSafeTokens(index),
field.dataType,
field.nullable,
- params.nullValue,
- params.dateFormat)
+ params)
if (subIndex < requiredSize) {
row(subIndex) = value
}
diff --git a/sql/core/src/test/resources/numbers.csv b/sql/core/src/test/resources/numbers.csv
new file mode 100644
index 0000000000..af8feac784
--- /dev/null
+++ b/sql/core/src/test/resources/numbers.csv
@@ -0,0 +1,9 @@
+int,long,float,double
+8,1000000,1.042,23848545.0374
+--,34232323,98.343,184721.23987223
+34,--,98.343,184721.23987223
+34,43323123,--,184721.23987223
+34,43323123,223823.9484,--
+34,43323123,223823.NAN,NAN
+34,43323123,223823.INF,INF
+34,43323123,223823.-INF,-INF
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 8847c7632f..07f00a0868 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -46,6 +46,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
private val disableCommentsFile = "disable_comments.csv"
private val boolFile = "bool.csv"
private val simpleSparseFile = "simple_sparse.csv"
+ private val numbersFile = "numbers.csv"
private val datesFile = "dates.csv"
private val unescapedQuotesFile = "unescaped-quotes.csv"
@@ -535,4 +536,26 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
verifyCars(cars, withHeader = false, checkTypes = false)
}
+
+ test("nulls, NaNs and Infinity values can be parsed") {
+ val numbers = sqlContext
+ .read
+ .format("csv")
+ .schema(StructType(List(
+ StructField("int", IntegerType, true),
+ StructField("long", LongType, true),
+ StructField("float", FloatType, true),
+ StructField("double", DoubleType, true)
+ )))
+ .options(Map(
+ "header" -> "true",
+ "mode" -> "DROPMALFORMED",
+ "nullValue" -> "--",
+ "nanValue" -> "NAN",
+ "negativeInf" -> "-INF",
+ "positiveInf" -> "INF"))
+ .load(testFile(numbersFile))
+
+ assert(numbers.count() == 8)
+ }
}
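The end-to-end test asserts only the row count; a hypothetical follow-up check
(not part of this patch) could also verify that the sentinel tokens became real
IEEE-754 special values:

    // Hypothetical extra assertions, reusing the `numbers` DataFrame above.
    val doubles = numbers.select("double").collect()
      .filter(!_.isNullAt(0))        // the "--" row parses to null
      .map(_.getDouble(0))
    assert(doubles.exists(_.isNaN))
    assert(doubles.contains(Double.PositiveInfinity))
    assert(doubles.contains(Double.NegativeInfinity))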
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
index 8b59bc148f..26b33b24ef 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
@@ -29,6 +29,8 @@ import org.apache.spark.unsafe.types.UTF8String
class CSVTypeCastSuite extends SparkFunSuite {
+ private def assertNull(v: Any) = assert(v == null)
+
test("Can parse decimal type values") {
val stringValues = Seq("10.05", "1,000.01", "158,058,049.001")
val decimalValues = Seq(10.05, 1000.01, 158058049.001)
@@ -66,17 +68,21 @@ class CSVTypeCastSuite extends SparkFunSuite {
}
test("Nullable types are handled") {
- assert(CSVTypeCast.castTo("", IntegerType, nullable = true) == null)
+ assert(CSVTypeCast.castTo("", IntegerType, nullable = true, CSVOptions()) == null)
}
test("String type should always return the same as the input") {
- assert(CSVTypeCast.castTo("", StringType, nullable = true) == UTF8String.fromString(""))
- assert(CSVTypeCast.castTo("", StringType, nullable = false) == UTF8String.fromString(""))
+ assert(
+ CSVTypeCast.castTo("", StringType, nullable = true, CSVOptions()) ==
+ UTF8String.fromString(""))
+ assert(
+ CSVTypeCast.castTo("", StringType, nullable = false, CSVOptions()) ==
+ UTF8String.fromString(""))
}
test("Throws exception for empty string with non null type") {
val exception = intercept[NumberFormatException]{
- CSVTypeCast.castTo("", IntegerType, nullable = false)
+ CSVTypeCast.castTo("", IntegerType, nullable = false, CSVOptions())
}
assert(exception.getMessage.contains("For input string: \"\""))
}
@@ -90,12 +96,12 @@ class CSVTypeCastSuite extends SparkFunSuite {
assert(CSVTypeCast.castTo("1.00", DoubleType) == 1.0)
assert(CSVTypeCast.castTo("true", BooleanType) == true)
- val dateFormat = new SimpleDateFormat("dd/MM/yyyy hh:mm")
+ val options = CSVOptions("dateFormat", "dd/MM/yyyy hh:mm")
val customTimestamp = "31/01/2015 00:00"
- val expectedTime = dateFormat.parse("31/01/2015 00:00").getTime
- assert(CSVTypeCast.castTo(customTimestamp, TimestampType, dateFormat = dateFormat)
- == expectedTime * 1000L)
- assert(CSVTypeCast.castTo(customTimestamp, DateType, dateFormat = dateFormat) ==
+ val expectedTime = options.dateFormat.parse("31/01/2015 00:00").getTime
+ assert(CSVTypeCast.castTo(customTimestamp, TimestampType, nullable = true, options) ==
+ expectedTime * 1000L)
+ assert(CSVTypeCast.castTo(customTimestamp, DateType, nullable = true, options) ==
DateTimeUtils.millisToDays(expectedTime))
val timestamp = "2015-01-01 00:00:00"
@@ -116,4 +122,63 @@ class CSVTypeCastSuite extends SparkFunSuite {
Locale.setDefault(originalLocale)
}
}
+
+ test("Float NaN values are parsed correctly") {
+ val floatVal: Float = CSVTypeCast.castTo(
+ "nn", FloatType, nullable = true, CSVOptions("nanValue", "nn")).asInstanceOf[Float]
+
+ // Java implements the IEEE-754 floating point standard which guarantees that any comparison
+ // against NaN will return false (except != which returns true)
+ assert(floatVal != floatVal)
+ }
+
+ test("Double NaN values are parsed correctly") {
+ val doubleVal: Double = CSVTypeCast.castTo(
+ "-", DoubleType, nullable = true, CSVOptions("nanValue", "-")).asInstanceOf[Double]
+
+ assert(doubleVal.isNaN)
+ }
+
+ test("Float infinite values can be parsed") {
+ val floatVal1 = CSVTypeCast.castTo(
+ "max", FloatType, nullable = true, CSVOptions("negativeInf", "max")).asInstanceOf[Float]
+
+ assert(floatVal1 == Float.NegativeInfinity)
+
+ val floatVal2 = CSVTypeCast.castTo(
+ "max", FloatType, nullable = true, CSVOptions("positiveInf", "max")).asInstanceOf[Float]
+
+ assert(floatVal2 == Float.PositiveInfinity)
+ }
+
+ test("Double infinite values can be parsed") {
+ val doubleVal1 = CSVTypeCast.castTo(
+ "max", DoubleType, nullable = true, CSVOptions("negativeInf", "max")
+ ).asInstanceOf[Double]
+
+ assert(doubleVal1 == Double.NegativeInfinity)
+
+ val doubleVal2 = CSVTypeCast.castTo(
+ "max", DoubleType, nullable = true, CSVOptions("positiveInf", "max")
+ ).asInstanceOf[Double]
+
+ assert(doubleVal2 == Double.PositiveInfinity)
+ }
+
+ test("Type-specific null values are used for casting") {
+ assertNull(
+ CSVTypeCast.castTo("-", ByteType, nullable = true, CSVOptions("nullValue", "-")))
+ assertNull(
+ CSVTypeCast.castTo("-", ShortType, nullable = true, CSVOptions("nullValue", "-")))
+ assertNull(
+ CSVTypeCast.castTo("-", IntegerType, nullable = true, CSVOptions("nullValue", "-")))
+ assertNull(
+ CSVTypeCast.castTo("-", LongType, nullable = true, CSVOptions("nullValue", "-")))
+ assertNull(
+ CSVTypeCast.castTo("-", FloatType, nullable = true, CSVOptions("nullValue", "-")))
+ assertNull(
+ CSVTypeCast.castTo("-", DoubleType, nullable = true, CSVOptions("nullValue", "-")))
+ assertNull(
+ CSVTypeCast.castTo("-", DecimalType.DoubleDecimal, true, CSVOptions("nullValue", "-")))
+ }
}