diff options
author | Hossein <hossein@databricks.com> | 2016-04-30 18:11:56 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-04-30 18:12:03 -0700 |
commit | 507bea5ca6d95c995f8152b8473713c136e23754 (patch) | |
tree | 663b7920df1d4fbe0ae5357a029140729a50e40b /sql | |
parent | 0182d9599d15f70eeb6288bf9294fa677004bd14 (diff) | |
download | spark-507bea5ca6d95c995f8152b8473713c136e23754.tar.gz spark-507bea5ca6d95c995f8152b8473713c136e23754.tar.bz2 spark-507bea5ca6d95c995f8152b8473713c136e23754.zip |
[SPARK-14143] Options for parsing NaNs, Infinity and nulls for numeric types
1. Adds the following option for parsing NaNs: `nanValue`
2. Adds the following options for parsing infinity: `positiveInf`, `negativeInf`.
`TypeCast.castTo` is unit tested and an end-to-end test is added to `CSVSuite`
Author: Hossein <hossein@databricks.com>
Closes #11947 from falaki/SPARK-14143.
Diffstat (limited to 'sql')
6 files changed, 174 insertions, 42 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index a26a8084b6..cfd66af188 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -190,40 +190,61 @@ private[csv] object CSVTypeCast { datum: String, castType: DataType, nullable: Boolean = true, - nullValue: String = "", - dateFormat: SimpleDateFormat = null): Any = { + options: CSVOptions = CSVOptions()): Any = { - if (datum == nullValue && nullable && (!castType.isInstanceOf[StringType])) { - null - } else { - castType match { - case _: ByteType => datum.toByte - case _: ShortType => datum.toShort - case _: IntegerType => datum.toInt - case _: LongType => datum.toLong - case _: FloatType => Try(datum.toFloat) - .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue()) - case _: DoubleType => Try(datum.toDouble) - .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue()) - case _: BooleanType => datum.toBoolean - case dt: DecimalType => + castType match { + case _: ByteType => if (datum == options.nullValue && nullable) null else datum.toByte + case _: ShortType => if (datum == options.nullValue && nullable) null else datum.toShort + case _: IntegerType => if (datum == options.nullValue && nullable) null else datum.toInt + case _: LongType => if (datum == options.nullValue && nullable) null else datum.toLong + case _: FloatType => + if (datum == options.nullValue && nullable) { + null + } else if (datum == options.nanValue) { + Float.NaN + } else if (datum == options.negativeInf) { + Float.NegativeInfinity + } else if (datum == options.positiveInf) { + Float.PositiveInfinity + } else { + Try(datum.toFloat) + 
.getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue()) + } + case _: DoubleType => + if (datum == options.nullValue && nullable) { + null + } else if (datum == options.nanValue) { + Double.NaN + } else if (datum == options.negativeInf) { + Double.NegativeInfinity + } else if (datum == options.positiveInf) { + Double.PositiveInfinity + } else { + Try(datum.toDouble) + .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue()) + } + case _: BooleanType => datum.toBoolean + case dt: DecimalType => + if (datum == options.nullValue && nullable) { + null + } else { val value = new BigDecimal(datum.replaceAll(",", "")) Decimal(value, dt.precision, dt.scale) - case _: TimestampType if dateFormat != null => - // This one will lose microseconds parts. - // See https://issues.apache.org/jira/browse/SPARK-10681. - dateFormat.parse(datum).getTime * 1000L - case _: TimestampType => - // This one will lose microseconds parts. - // See https://issues.apache.org/jira/browse/SPARK-10681. - DateTimeUtils.stringToTime(datum).getTime * 1000L - case _: DateType if dateFormat != null => - DateTimeUtils.millisToDays(dateFormat.parse(datum).getTime) - case _: DateType => - DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime) - case _: StringType => UTF8String.fromString(datum) - case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}") - } + } + case _: TimestampType if options.dateFormat != null => + // This one will lose microseconds parts. + // See https://issues.apache.org/jira/browse/SPARK-10681. + options.dateFormat.parse(datum).getTime * 1000L + case _: TimestampType => + // This one will lose microseconds parts. + // See https://issues.apache.org/jira/browse/SPARK-10681. 
+ DateTimeUtils.stringToTime(datum).getTime * 1000L + case _: DateType if options.dateFormat != null => + DateTimeUtils.millisToDays(options.dateFormat.parse(datum).getTime) + case _: DateType => + DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime) + case _: StringType => UTF8String.fromString(datum) + case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index b87d19f7cf..e4fd09462f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -90,6 +90,12 @@ private[sql] class CSVOptions(@transient private val parameters: Map[String, Str val nullValue = parameters.getOrElse("nullValue", "") + val nanValue = parameters.getOrElse("nanValue", "NaN") + + val positiveInf = parameters.getOrElse("positiveInf", "Inf") + val negativeInf = parameters.getOrElse("negativeInf", "-Inf") + + val compressionCodec: Option[String] = { val name = parameters.get("compression").orElse(parameters.get("codec")) name.map(CompressionCodecs.getCodecClassName) @@ -111,3 +117,12 @@ private[sql] class CSVOptions(@transient private val parameters: Map[String, Str val rowSeparator = "\n" } + +object CSVOptions { + + def apply(): CSVOptions = new CSVOptions(Map.empty) + + def apply(paramName: String, paramValue: String): CSVOptions = { + new CSVOptions(Map(paramName -> paramValue)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala index 9a723630de..4f2d4387b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -99,8 +99,7 @@ object CSVRelation extends Logging { indexSafeTokens(index), field.dataType, field.nullable, - params.nullValue, - params.dateFormat) + params) if (subIndex < requiredSize) { row(subIndex) = value } diff --git a/sql/core/src/test/resources/numbers.csv b/sql/core/src/test/resources/numbers.csv new file mode 100644 index 0000000000..af8feac784 --- /dev/null +++ b/sql/core/src/test/resources/numbers.csv @@ -0,0 +1,9 @@ +int,long,float,double +8,1000000,1.042,23848545.0374 +--,34232323,98.343,184721.23987223 +34,--,98.343,184721.23987223 +34,43323123,--,184721.23987223 +34,43323123,223823.9484,-- +34,43323123,223823.NAN,NAN +34,43323123,223823.INF,INF +34,43323123,223823.-INF,-INF diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 8847c7632f..07f00a0868 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -46,6 +46,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { private val disableCommentsFile = "disable_comments.csv" private val boolFile = "bool.csv" private val simpleSparseFile = "simple_sparse.csv" + private val numbersFile = "numbers.csv" private val datesFile = "dates.csv" private val unescapedQuotesFile = "unescaped-quotes.csv" @@ -535,4 +536,26 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { verifyCars(cars, withHeader = false, checkTypes = false) } + + test("nulls, NaNs and Infinity values can be parsed") { + val numbers = sqlContext + .read + .format("csv") + .schema(StructType(List( + StructField("int", IntegerType, true), + StructField("long", LongType, true), + StructField("float", FloatType, true), + 
StructField("double", DoubleType, true) + ))) + .options(Map( + "header" -> "true", + "mode" -> "DROPMALFORMED", + "nullValue" -> "--", + "nanValue" -> "NAN", + "negativeInf" -> "-INF", + "positiveInf" -> "INF")) + .load(testFile(numbersFile)) + + assert(numbers.count() == 8) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala index 8b59bc148f..26b33b24ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala @@ -29,6 +29,8 @@ import org.apache.spark.unsafe.types.UTF8String class CSVTypeCastSuite extends SparkFunSuite { + private def assertNull(v: Any) = assert(v == null) + test("Can parse decimal type values") { val stringValues = Seq("10.05", "1,000.01", "158,058,049.001") val decimalValues = Seq(10.05, 1000.01, 158058049.001) @@ -66,17 +68,21 @@ class CSVTypeCastSuite extends SparkFunSuite { } test("Nullable types are handled") { - assert(CSVTypeCast.castTo("", IntegerType, nullable = true) == null) + assert(CSVTypeCast.castTo("", IntegerType, nullable = true, CSVOptions()) == null) } test("String type should always return the same as the input") { - assert(CSVTypeCast.castTo("", StringType, nullable = true) == UTF8String.fromString("")) - assert(CSVTypeCast.castTo("", StringType, nullable = false) == UTF8String.fromString("")) + assert( + CSVTypeCast.castTo("", StringType, nullable = true, CSVOptions()) == + UTF8String.fromString("")) + assert( + CSVTypeCast.castTo("", StringType, nullable = false, CSVOptions()) == + UTF8String.fromString("")) } test("Throws exception for empty string with non null type") { val exception = intercept[NumberFormatException]{ - CSVTypeCast.castTo("", IntegerType, nullable = false) + CSVTypeCast.castTo("", IntegerType, nullable 
= false, CSVOptions()) } assert(exception.getMessage.contains("For input string: \"\"")) } @@ -90,12 +96,12 @@ class CSVTypeCastSuite extends SparkFunSuite { assert(CSVTypeCast.castTo("1.00", DoubleType) == 1.0) assert(CSVTypeCast.castTo("true", BooleanType) == true) - val dateFormat = new SimpleDateFormat("dd/MM/yyyy hh:mm") + val options = CSVOptions("dateFormat", "dd/MM/yyyy hh:mm") val customTimestamp = "31/01/2015 00:00" - val expectedTime = dateFormat.parse("31/01/2015 00:00").getTime - assert(CSVTypeCast.castTo(customTimestamp, TimestampType, dateFormat = dateFormat) - == expectedTime * 1000L) - assert(CSVTypeCast.castTo(customTimestamp, DateType, dateFormat = dateFormat) == + val expectedTime = options.dateFormat.parse("31/01/2015 00:00").getTime + assert(CSVTypeCast.castTo(customTimestamp, TimestampType, nullable = true, options) == + expectedTime * 1000L) + assert(CSVTypeCast.castTo(customTimestamp, DateType, nullable = true, options) == DateTimeUtils.millisToDays(expectedTime)) val timestamp = "2015-01-01 00:00:00" @@ -116,4 +122,63 @@ class CSVTypeCastSuite extends SparkFunSuite { Locale.setDefault(originalLocale) } } + + test("Float NaN values are parsed correctly") { + val floatVal: Float = CSVTypeCast.castTo( + "nn", FloatType, nullable = true, CSVOptions("nanValue", "nn")).asInstanceOf[Float] + + // Java implements the IEEE-754 floating point standard which guarantees that any comparison + // against NaN will return false (except != which returns true) + assert(floatVal != floatVal) + } + + test("Double NaN values are parsed correctly") { + val doubleVal: Double = CSVTypeCast.castTo( + "-", DoubleType, nullable = true, CSVOptions("nanValue", "-")).asInstanceOf[Double] + + assert(doubleVal.isNaN) + } + + test("Float infinite values can be parsed") { + val floatVal1 = CSVTypeCast.castTo( + "max", FloatType, nullable = true, CSVOptions("negativeInf", "max")).asInstanceOf[Float] + + assert(floatVal1 == Float.NegativeInfinity) + + val floatVal2 = 
CSVTypeCast.castTo( + "max", FloatType, nullable = true, CSVOptions("positiveInf", "max")).asInstanceOf[Float] + + assert(floatVal2 == Float.PositiveInfinity) + } + + test("Double infinite values can be parsed") { + val doubleVal1 = CSVTypeCast.castTo( + "max", DoubleType, nullable = true, CSVOptions("negativeInf", "max") + ).asInstanceOf[Double] + + assert(doubleVal1 == Double.NegativeInfinity) + + val doubleVal2 = CSVTypeCast.castTo( + "max", DoubleType, nullable = true, CSVOptions("positiveInf", "max") + ).asInstanceOf[Double] + + assert(doubleVal2 == Double.PositiveInfinity) + } + + test("Type-specific null values are used for casting") { + assertNull( + CSVTypeCast.castTo("-", ByteType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", ShortType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", IntegerType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", LongType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", FloatType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", DoubleType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", DecimalType.DoubleDecimal, true, CSVOptions("nullValue", "-"))) + } } |