 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala   | 161
 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala      |  34
 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala | 148
 3 files changed, 184 insertions(+), 159 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
index c63aae9d83..88c608add1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
@@ -215,84 +215,121 @@ private[csv] object CSVInferSchema {
}
private[csv] object CSVTypeCast {
+ // A `ValueConverter` is responsible for converting the given value to a desired type.
+ private type ValueConverter = String => Any
/**
- * Casts given string datum to specified type.
- * Currently we do not support complex types (ArrayType, MapType, StructType).
+   * Create converters that cast each string datum to the type of its field in the given schema.
+ * Currently, we do not support complex types (`ArrayType`, `MapType`, `StructType`).
*
- * For string types, this is simply the datum. For other types.
+ * For string types, this is simply the datum.
+   * For other types, the datum is converted into a value of that type.
* For other nullable types, returns null if it is null or equals to the value specified
* in `nullValue` option.
*
- * @param datum string value
- * @param name field name in schema.
- * @param castType data type to cast `datum` into.
- * @param nullable nullability for the field.
+ * @param schema schema that contains data types to cast the given value into.
* @param options CSV options.
*/
- def castTo(
+ def makeConverters(
+ schema: StructType,
+ options: CSVOptions = CSVOptions()): Array[ValueConverter] = {
+ schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray
+ }
+
+ /**
+   * Create a converter that converts a string value to a value of the desired type.
+ */
+ def makeConverter(
+ name: String,
+ dataType: DataType,
+ nullable: Boolean = true,
+ options: CSVOptions = CSVOptions()): ValueConverter = dataType match {
+ case _: ByteType => (d: String) =>
+ nullSafeDatum(d, name, nullable, options)(_.toByte)
+
+ case _: ShortType => (d: String) =>
+ nullSafeDatum(d, name, nullable, options)(_.toShort)
+
+ case _: IntegerType => (d: String) =>
+ nullSafeDatum(d, name, nullable, options)(_.toInt)
+
+ case _: LongType => (d: String) =>
+ nullSafeDatum(d, name, nullable, options)(_.toLong)
+
+ case _: FloatType => (d: String) =>
+ nullSafeDatum(d, name, nullable, options) {
+ case options.nanValue => Float.NaN
+ case options.negativeInf => Float.NegativeInfinity
+ case options.positiveInf => Float.PositiveInfinity
+ case datum =>
+ Try(datum.toFloat)
+ .getOrElse(NumberFormat.getInstance(Locale.US).parse(datum).floatValue())
+ }
+
+ case _: DoubleType => (d: String) =>
+ nullSafeDatum(d, name, nullable, options) {
+ case options.nanValue => Double.NaN
+ case options.negativeInf => Double.NegativeInfinity
+ case options.positiveInf => Double.PositiveInfinity
+ case datum =>
+ Try(datum.toDouble)
+ .getOrElse(NumberFormat.getInstance(Locale.US).parse(datum).doubleValue())
+ }
+
+ case _: BooleanType => (d: String) =>
+ nullSafeDatum(d, name, nullable, options)(_.toBoolean)
+
+ case dt: DecimalType => (d: String) =>
+ nullSafeDatum(d, name, nullable, options) { datum =>
+ val value = new BigDecimal(datum.replaceAll(",", ""))
+ Decimal(value, dt.precision, dt.scale)
+ }
+
+ case _: TimestampType => (d: String) =>
+ nullSafeDatum(d, name, nullable, options) { datum =>
+ // This one will lose microseconds parts.
+ // See https://issues.apache.org/jira/browse/SPARK-10681.
+ Try(options.timestampFormat.parse(datum).getTime * 1000L)
+ .getOrElse {
+ // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards
+ // compatibility.
+ DateTimeUtils.stringToTime(datum).getTime * 1000L
+ }
+ }
+
+ case _: DateType => (d: String) =>
+ nullSafeDatum(d, name, nullable, options) { datum =>
+ // This one will lose microseconds parts.
+        // See https://issues.apache.org/jira/browse/SPARK-10681.
+ Try(DateTimeUtils.millisToDays(options.dateFormat.parse(datum).getTime))
+ .getOrElse {
+ // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards
+ // compatibility.
+ DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime)
+ }
+ }
+
+ case _: StringType => (d: String) =>
+ nullSafeDatum(d, name, nullable, options)(UTF8String.fromString(_))
+
+ case udt: UserDefinedType[_] => (datum: String) =>
+      makeConverter(name, udt.sqlType, nullable, options).apply(datum)
+
+ case _ => throw new RuntimeException(s"Unsupported type: ${dataType.typeName}")
+ }
+
+ private def nullSafeDatum(
datum: String,
name: String,
- castType: DataType,
- nullable: Boolean = true,
- options: CSVOptions = CSVOptions()): Any = {
-
- // datum can be null if the number of fields found is less than the length of the schema
+ nullable: Boolean,
+ options: CSVOptions)(converter: ValueConverter): Any = {
if (datum == options.nullValue || datum == null) {
if (!nullable) {
throw new RuntimeException(s"null value found but field $name is not nullable.")
}
null
} else {
- castType match {
- case _: ByteType => datum.toByte
- case _: ShortType => datum.toShort
- case _: IntegerType => datum.toInt
- case _: LongType => datum.toLong
- case _: FloatType =>
- datum match {
- case options.nanValue => Float.NaN
- case options.negativeInf => Float.NegativeInfinity
- case options.positiveInf => Float.PositiveInfinity
- case _ =>
- Try(datum.toFloat)
- .getOrElse(NumberFormat.getInstance(Locale.US).parse(datum).floatValue())
- }
- case _: DoubleType =>
- datum match {
- case options.nanValue => Double.NaN
- case options.negativeInf => Double.NegativeInfinity
- case options.positiveInf => Double.PositiveInfinity
- case _ =>
- Try(datum.toDouble)
- .getOrElse(NumberFormat.getInstance(Locale.US).parse(datum).doubleValue())
- }
- case _: BooleanType => datum.toBoolean
- case dt: DecimalType =>
- val value = new BigDecimal(datum.replaceAll(",", ""))
- Decimal(value, dt.precision, dt.scale)
- case _: TimestampType =>
- // This one will lose microseconds parts.
- // See https://issues.apache.org/jira/browse/SPARK-10681.
- Try(options.timestampFormat.parse(datum).getTime * 1000L)
- .getOrElse {
- // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards
- // compatibility.
- DateTimeUtils.stringToTime(datum).getTime * 1000L
- }
- case _: DateType =>
- // This one will lose microseconds parts.
- // See https://issues.apache.org/jira/browse/SPARK-10681.x
- Try(DateTimeUtils.millisToDays(options.dateFormat.parse(datum).getTime))
- .getOrElse {
- // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards
- // compatibility.
- DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime)
- }
- case _: StringType => UTF8String.fromString(datum)
- case udt: UserDefinedType[_] => castTo(datum, name, udt.sqlType, nullable, options)
- case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}")
- }
+ converter.apply(datum)
}
}
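
The hunk above replaces the per-value `castTo` dispatch with converters that are built once per field and then applied to every token. Below is a minimal standalone sketch of that pattern; it is illustrative only, using made-up stand-ins (`ConverterSketch`, `SimpleType`, `Field`, `nullSafe`) rather than Spark's `StructType`/`CSVOptions`/`DataType`, but it mirrors how `makeConverters` and `nullSafeDatum` separate type dispatch from null handling.

    // Minimal sketch of the converter-per-field pattern introduced above.
    // `ConverterSketch`, `Field`, `SimpleType` and `nullSafe` are illustrative stand-ins,
    // not Spark's actual API.
    object ConverterSketch {
      type ValueConverter = String => Any

      sealed trait SimpleType
      case object IntT extends SimpleType
      case object DoubleT extends SimpleType
      case object StringT extends SimpleType

      case class Field(name: String, dataType: SimpleType, nullable: Boolean = true)

      // Built once per field, instead of matching on the data type for every token.
      def makeConverter(f: Field, nullValue: String = ""): ValueConverter = f.dataType match {
        case IntT    => nullSafe(f, nullValue)(_.toInt)
        case DoubleT => nullSafe(f, nullValue)(_.toDouble)
        case StringT => nullSafe(f, nullValue)(s => s)
      }

      // Mirrors nullSafeDatum: null handling is shared, the type-specific part is a closure.
      private def nullSafe(f: Field, nullValue: String)(convert: String => Any): ValueConverter =
        datum =>
          if (datum == null || datum == nullValue) {
            if (!f.nullable) sys.error(s"null value found but field ${f.name} is not nullable.")
            null
          } else {
            convert(datum)
          }

      def main(args: Array[String]): Unit = {
        val schema = Seq(Field("id", IntT), Field("score", DoubleT), Field("name", StringT))
        val converters = schema.map(makeConverter(_)).toArray // precomputed once per schema
        val row = Array("10", "1.5", "spark").zipWithIndex.map {
          case (token, i) => converters(i)(token)             // applied per token
        }
        println(row.toList)                                   // List(10, 1.5, spark)
      }
    }
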
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index e4ce7a94be..23c07eb630 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -69,26 +69,26 @@ object CSVRelation extends Logging {
schema: StructType,
requiredColumns: Array[String],
params: CSVOptions): (Array[String], Int) => Option[InternalRow] = {
- val schemaFields = schema.fields
val requiredFields = StructType(requiredColumns.map(schema(_))).fields
val safeRequiredFields = if (params.dropMalformed) {
// If `dropMalformed` is enabled, then it needs to parse all the values
// so that we can decide which row is malformed.
- requiredFields ++ schemaFields.filterNot(requiredFields.contains(_))
+ requiredFields ++ schema.filterNot(requiredFields.contains(_))
} else {
requiredFields
}
val safeRequiredIndices = new Array[Int](safeRequiredFields.length)
- schemaFields.zipWithIndex.filter {
- case (field, _) => safeRequiredFields.contains(field)
- }.foreach {
- case (field, index) => safeRequiredIndices(safeRequiredFields.indexOf(field)) = index
+ schema.zipWithIndex.filter { case (field, _) =>
+ safeRequiredFields.contains(field)
+ }.foreach { case (field, index) =>
+ safeRequiredIndices(safeRequiredFields.indexOf(field)) = index
}
val requiredSize = requiredFields.length
val row = new GenericInternalRow(requiredSize)
+ val converters = CSVTypeCast.makeConverters(schema, params)
(tokens: Array[String], numMalformedRows) => {
- if (params.dropMalformed && schemaFields.length != tokens.length) {
+ if (params.dropMalformed && schema.length != tokens.length) {
if (numMalformedRows < params.maxMalformedLogPerPartition) {
logWarning(s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
}
@@ -98,14 +98,14 @@ object CSVRelation extends Logging {
"found on this partition. Malformed records from now on will not be logged.")
}
None
- } else if (params.failFast && schemaFields.length != tokens.length) {
+ } else if (params.failFast && schema.length != tokens.length) {
throw new RuntimeException(s"Malformed line in FAILFAST mode: " +
s"${tokens.mkString(params.delimiter.toString)}")
} else {
- val indexSafeTokens = if (params.permissive && schemaFields.length > tokens.length) {
- tokens ++ new Array[String](schemaFields.length - tokens.length)
- } else if (params.permissive && schemaFields.length < tokens.length) {
- tokens.take(schemaFields.length)
+ val indexSafeTokens = if (params.permissive && schema.length > tokens.length) {
+ tokens ++ new Array[String](schema.length - tokens.length)
+ } else if (params.permissive && schema.length < tokens.length) {
+ tokens.take(schema.length)
} else {
tokens
}
@@ -114,20 +114,14 @@ object CSVRelation extends Logging {
var subIndex: Int = 0
while (subIndex < safeRequiredIndices.length) {
index = safeRequiredIndices(subIndex)
- val field = schemaFields(index)
// It anyway needs to try to parse since it decides if this row is malformed
// or not after trying to cast in `DROPMALFORMED` mode even if the casted
// value is not stored in the row.
- val value = CSVTypeCast.castTo(
- indexSafeTokens(index),
- field.name,
- field.dataType,
- field.nullable,
- params)
+ val value = converters(index).apply(indexSafeTokens(index))
if (subIndex < requiredSize) {
row(subIndex) = value
}
- subIndex = subIndex + 1
+ subIndex += 1
}
Some(row)
} catch {
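
The PERMISSIVE branch in the hunk above pads short rows with nulls and truncates long rows to the schema length before the converters run. A small self-contained sketch of that alignment step follows; `PermissiveSketch` and `alignTokens` are illustrative names, not part of Spark's API.

    // Rough sketch of the PERMISSIVE-mode token alignment shown in the hunk above:
    // short rows are padded with nulls, long rows are truncated to the schema length.
    object PermissiveSketch {
      def alignTokens(tokens: Array[String], schemaLength: Int): Array[String] =
        if (schemaLength > tokens.length) {
          tokens ++ new Array[String](schemaLength - tokens.length) // pad missing fields with nulls
        } else if (schemaLength < tokens.length) {
          tokens.take(schemaLength)                                 // drop extra tokens
        } else {
          tokens
        }

      def main(args: Array[String]): Unit = {
        println(alignTokens(Array("1", "a"), 3).toList)           // List(1, a, null)
        println(alignTokens(Array("1", "a", "x", "y"), 3).toList) // List(1, a, x)
      }
    }
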
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
index 46333d1213..ffd3d260bc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
@@ -36,7 +36,7 @@ class CSVTypeCastSuite extends SparkFunSuite {
stringValues.zip(decimalValues).foreach { case (strVal, decimalVal) =>
val decimalValue = new BigDecimal(decimalVal.toString)
- assert(CSVTypeCast.castTo(strVal, "_1", decimalType) ===
+ assert(CSVTypeCast.makeConverter("_1", decimalType).apply(strVal) ===
Decimal(decimalValue, decimalType.precision, decimalType.scale))
}
}
@@ -66,92 +66,81 @@ class CSVTypeCastSuite extends SparkFunSuite {
}
test("Nullable types are handled") {
- assertNull(
- CSVTypeCast.castTo("-", "_1", ByteType, nullable = true, CSVOptions("nullValue", "-")))
- assertNull(
- CSVTypeCast.castTo("-", "_1", ShortType, nullable = true, CSVOptions("nullValue", "-")))
- assertNull(
- CSVTypeCast.castTo("-", "_1", IntegerType, nullable = true, CSVOptions("nullValue", "-")))
- assertNull(
- CSVTypeCast.castTo("-", "_1", LongType, nullable = true, CSVOptions("nullValue", "-")))
- assertNull(
- CSVTypeCast.castTo("-", "_1", FloatType, nullable = true, CSVOptions("nullValue", "-")))
- assertNull(
- CSVTypeCast.castTo("-", "_1", DoubleType, nullable = true, CSVOptions("nullValue", "-")))
- assertNull(
- CSVTypeCast.castTo("-", "_1", BooleanType, nullable = true, CSVOptions("nullValue", "-")))
- assertNull(
- CSVTypeCast.castTo("-", "_1", DecimalType.DoubleDecimal, true, CSVOptions("nullValue", "-")))
- assertNull(
- CSVTypeCast.castTo("-", "_1", TimestampType, nullable = true, CSVOptions("nullValue", "-")))
- assertNull(
- CSVTypeCast.castTo("-", "_1", DateType, nullable = true, CSVOptions("nullValue", "-")))
- assertNull(
- CSVTypeCast.castTo("-", "_1", StringType, nullable = true, CSVOptions("nullValue", "-")))
- assertNull(
- CSVTypeCast.castTo(null, "_1", IntegerType, nullable = true, CSVOptions("nullValue", "-")))
-
- // casting a null to not nullable field should throw an exception.
- var message = intercept[RuntimeException] {
- CSVTypeCast.castTo(null, "_1", IntegerType, nullable = false, CSVOptions("nullValue", "-"))
- }.getMessage
- assert(message.contains("null value found but field _1 is not nullable."))
-
- message = intercept[RuntimeException] {
- CSVTypeCast.castTo("-", "_1", StringType, nullable = false, CSVOptions("nullValue", "-"))
- }.getMessage
- assert(message.contains("null value found but field _1 is not nullable."))
- }
-
- test("String type should also respect `nullValue`") {
- assertNull(
- CSVTypeCast.castTo("", "_1", StringType, nullable = true, CSVOptions()))
+ val types = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType,
+ BooleanType, DecimalType.DoubleDecimal, TimestampType, DateType, StringType)
+
+ // Nullable field with nullValue option.
+ types.foreach { t =>
+      // Tests that a custom nullValue is respected.
+ val converter =
+ CSVTypeCast.makeConverter("_1", t, nullable = true, CSVOptions("nullValue", "-"))
+ assertNull(converter.apply("-"))
+ assertNull(converter.apply(null))
+
+ // Tests that the default nullValue is empty string.
+ assertNull(CSVTypeCast.makeConverter("_1", t, nullable = true).apply(""))
+ }
- assert(
- CSVTypeCast.castTo("", "_1", StringType, nullable = true, CSVOptions("nullValue", "null")) ==
- UTF8String.fromString(""))
- assert(
- CSVTypeCast.castTo("", "_1", StringType, nullable = false, CSVOptions("nullValue", "null")) ==
- UTF8String.fromString(""))
+    // Non-nullable field with nullValue option.
+ types.foreach { t =>
+      // Casting null to a non-nullable field should throw an exception.
+ val converter =
+ CSVTypeCast.makeConverter("_1", t, nullable = false, CSVOptions("nullValue", "-"))
+ var message = intercept[RuntimeException] {
+ converter.apply("-")
+ }.getMessage
+ assert(message.contains("null value found but field _1 is not nullable."))
+ message = intercept[RuntimeException] {
+ converter.apply(null)
+ }.getMessage
+ assert(message.contains("null value found but field _1 is not nullable."))
+ }
- assertNull(
- CSVTypeCast.castTo(null, "_1", StringType, nullable = true, CSVOptions("nullValue", "null")))
+    // If nullValue is different from the empty string, then an empty string should not be cast
+    // to null.
+ Seq(true, false).foreach { b =>
+ val converter =
+ CSVTypeCast.makeConverter("_1", StringType, nullable = b, CSVOptions("nullValue", "null"))
+ assert(converter.apply("") == UTF8String.fromString(""))
+ }
}
test("Throws exception for empty string with non null type") {
val exception = intercept[RuntimeException]{
- CSVTypeCast.castTo("", "_1", IntegerType, nullable = false, CSVOptions())
+ CSVTypeCast.makeConverter("_1", IntegerType, nullable = false, CSVOptions()).apply("")
}
assert(exception.getMessage.contains("null value found but field _1 is not nullable."))
}
test("Types are cast correctly") {
- assert(CSVTypeCast.castTo("10", "_1", ByteType) == 10)
- assert(CSVTypeCast.castTo("10", "_1", ShortType) == 10)
- assert(CSVTypeCast.castTo("10", "_1", IntegerType) == 10)
- assert(CSVTypeCast.castTo("10", "_1", LongType) == 10)
- assert(CSVTypeCast.castTo("1.00", "_1", FloatType) == 1.0)
- assert(CSVTypeCast.castTo("1.00", "_1", DoubleType) == 1.0)
- assert(CSVTypeCast.castTo("true", "_1", BooleanType) == true)
+ assert(CSVTypeCast.makeConverter("_1", ByteType).apply("10") == 10)
+ assert(CSVTypeCast.makeConverter("_1", ShortType).apply("10") == 10)
+ assert(CSVTypeCast.makeConverter("_1", IntegerType).apply("10") == 10)
+ assert(CSVTypeCast.makeConverter("_1", LongType).apply("10") == 10)
+ assert(CSVTypeCast.makeConverter("_1", FloatType).apply("1.00") == 1.0)
+ assert(CSVTypeCast.makeConverter("_1", DoubleType).apply("1.00") == 1.0)
+ assert(CSVTypeCast.makeConverter("_1", BooleanType).apply("true") == true)
val timestampsOptions = CSVOptions("timestampFormat", "dd/MM/yyyy hh:mm")
val customTimestamp = "31/01/2015 00:00"
val expectedTime = timestampsOptions.timestampFormat.parse(customTimestamp).getTime
val castedTimestamp =
- CSVTypeCast.castTo(customTimestamp, "_1", TimestampType, nullable = true, timestampsOptions)
+ CSVTypeCast.makeConverter("_1", TimestampType, nullable = true, timestampsOptions)
+ .apply(customTimestamp)
assert(castedTimestamp == expectedTime * 1000L)
val customDate = "31/01/2015"
val dateOptions = CSVOptions("dateFormat", "dd/MM/yyyy")
val expectedDate = dateOptions.dateFormat.parse(customDate).getTime
val castedDate =
- CSVTypeCast.castTo(customTimestamp, "_1", DateType, nullable = true, dateOptions)
+ CSVTypeCast.makeConverter("_1", DateType, nullable = true, dateOptions)
+ .apply(customTimestamp)
assert(castedDate == DateTimeUtils.millisToDays(expectedDate))
val timestamp = "2015-01-01 00:00:00"
- assert(CSVTypeCast.castTo(timestamp, "_1", TimestampType) ==
+ assert(CSVTypeCast.makeConverter("_1", TimestampType).apply(timestamp) ==
DateTimeUtils.stringToTime(timestamp).getTime * 1000L)
- assert(CSVTypeCast.castTo("2015-01-01", "_1", DateType) ==
+ assert(CSVTypeCast.makeConverter("_1", DateType).apply("2015-01-01") ==
DateTimeUtils.millisToDays(DateTimeUtils.stringToTime("2015-01-01").getTime))
}
@@ -159,16 +148,18 @@ class CSVTypeCastSuite extends SparkFunSuite {
val originalLocale = Locale.getDefault
try {
Locale.setDefault(new Locale("fr", "FR"))
- assert(CSVTypeCast.castTo("1,00", "_1", FloatType) == 100.0) // Would parse as 1.0 in fr-FR
- assert(CSVTypeCast.castTo("1,00", "_1", DoubleType) == 100.0)
+ // Would parse as 1.0 in fr-FR
+ assert(CSVTypeCast.makeConverter("_1", FloatType).apply("1,00") == 100.0)
+ assert(CSVTypeCast.makeConverter("_1", DoubleType).apply("1,00") == 100.0)
} finally {
Locale.setDefault(originalLocale)
}
}
test("Float NaN values are parsed correctly") {
- val floatVal: Float = CSVTypeCast.castTo(
- "nn", "_1", FloatType, nullable = true, CSVOptions("nanValue", "nn")).asInstanceOf[Float]
+ val floatVal: Float = CSVTypeCast.makeConverter(
+ "_1", FloatType, nullable = true, CSVOptions("nanValue", "nn")
+ ).apply("nn").asInstanceOf[Float]
// Java implements the IEEE-754 floating point standard which guarantees that any comparison
// against NaN will return false (except != which returns true)
@@ -176,34 +167,37 @@ class CSVTypeCastSuite extends SparkFunSuite {
}
test("Double NaN values are parsed correctly") {
- val doubleVal: Double = CSVTypeCast.castTo(
- "-", "_1", DoubleType, nullable = true, CSVOptions("nanValue", "-")).asInstanceOf[Double]
+ val doubleVal: Double = CSVTypeCast.makeConverter(
+ "_1", DoubleType, nullable = true, CSVOptions("nanValue", "-")
+ ).apply("-").asInstanceOf[Double]
assert(doubleVal.isNaN)
}
test("Float infinite values can be parsed") {
- val floatVal1 = CSVTypeCast.castTo(
- "max", "_1", FloatType, nullable = true, CSVOptions("negativeInf", "max")).asInstanceOf[Float]
+ val floatVal1 = CSVTypeCast.makeConverter(
+ "_1", FloatType, nullable = true, CSVOptions("negativeInf", "max")
+ ).apply("max").asInstanceOf[Float]
assert(floatVal1 == Float.NegativeInfinity)
- val floatVal2 = CSVTypeCast.castTo(
- "max", "_1", FloatType, nullable = true, CSVOptions("positiveInf", "max")).asInstanceOf[Float]
+ val floatVal2 = CSVTypeCast.makeConverter(
+ "_1", FloatType, nullable = true, CSVOptions("positiveInf", "max")
+ ).apply("max").asInstanceOf[Float]
assert(floatVal2 == Float.PositiveInfinity)
}
test("Double infinite values can be parsed") {
- val doubleVal1 = CSVTypeCast.castTo(
- "max", "_1", DoubleType, nullable = true, CSVOptions("negativeInf", "max")
- ).asInstanceOf[Double]
+ val doubleVal1 = CSVTypeCast.makeConverter(
+ "_1", DoubleType, nullable = true, CSVOptions("negativeInf", "max")
+ ).apply("max").asInstanceOf[Double]
assert(doubleVal1 == Double.NegativeInfinity)
- val doubleVal2 = CSVTypeCast.castTo(
- "max", "_1", DoubleType, nullable = true, CSVOptions("positiveInf", "max")
- ).asInstanceOf[Double]
+ val doubleVal2 = CSVTypeCast.makeConverter(
+ "_1", DoubleType, nullable = true, CSVOptions("positiveInf", "max")
+ ).apply("max").asInstanceOf[Double]
assert(doubleVal2 == Double.PositiveInfinity)
}
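
The locale tests above depend on the `NumberFormat` fallback in the float/double converters: "1,00" is not a valid Scala numeric literal, and the explicitly US-locale formatter treats the comma as a grouping separator, so the value parses as 100.0 no matter what the JVM default locale is. A small standalone check of that behaviour (illustrative only, not part of the patch):

    // Standalone check of the java.text.NumberFormat fallback the fr-FR test relies on.
    import java.text.NumberFormat
    import java.util.Locale

    import scala.util.Try

    object LocaleParseSketch {
      def main(args: Array[String]): Unit = {
        val datum = "1,00"
        // "1,00".toFloat throws, so the US-locale NumberFormat is used; the comma is
        // read as a grouping separator and the result is 100.0.
        val parsed = Try(datum.toFloat)
          .getOrElse(NumberFormat.getInstance(Locale.US).parse(datum).floatValue())
        println(parsed) // 100.0, regardless of the JVM default locale
      }
    }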