From 92024797a4fad594b5314f3f3be5c6be2434de8a Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Tue, 15 Mar 2016 23:31:46 -0700
Subject: [SPARK-13899][SQL] Produce InternalRow instead of external Row at CSV data source

## What changes were proposed in this pull request?

https://issues.apache.org/jira/browse/SPARK-13899

This PR makes the CSV data source produce `InternalRow` instead of `Row`.
It resembles the JSON data source and uses the same code for casting.

## How was this patch tested?

Unit tests were run within the IDE, and code style was checked with `./dev/run_tests`.

Author: hyukjinkwon

Closes #11717 from HyukjinKwon/SPARK-13899.
---
 .../execution/datasources/csv/CSVInferSchema.scala | 19 +++++++++++++------
 .../sql/execution/datasources/csv/CSVRelation.scala | 15 +++++++++++----
 .../sql/execution/datasources/csv/DefaultSource.scala | 13 +++++++------
 .../execution/datasources/csv/CSVTypeCastSuite.scala | 17 +++++++++++------
 4 files changed, 42 insertions(+), 22 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
index edead9b21b..797f740dc5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.execution.datasources.csv
 
 import java.math.BigDecimal
-import java.sql.{Date, Timestamp}
 import java.text.NumberFormat
 import java.util.Locale
 
@@ -27,7 +26,9 @@ import scala.util.Try
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 private[csv] object CSVInferSchema {
 
@@ -116,7 +117,7 @@ private[csv] object CSVInferSchema {
   }
 
   def tryParseTimestamp(field: String): DataType = {
-    if ((allCatch opt Timestamp.valueOf(field)).isDefined) {
+    if ((allCatch opt DateTimeUtils.stringToTime(field)).isDefined) {
       TimestampType
     } else {
       tryParseBoolean(field)
@@ -191,12 +192,18 @@ private[csv] object CSVTypeCast {
         case _: DoubleType => Try(datum.toDouble)
           .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue())
         case _: BooleanType => datum.toBoolean
-        case _: DecimalType => new BigDecimal(datum.replaceAll(",", ""))
+        case dt: DecimalType =>
+          val value = new BigDecimal(datum.replaceAll(",", ""))
+          Decimal(value, dt.precision, dt.scale)
         // TODO(hossein): would be good to support other common timestamp formats
-        case _: TimestampType => Timestamp.valueOf(datum)
+        case _: TimestampType =>
+          // This one will lose microseconds parts.
+          // See https://issues.apache.org/jira/browse/SPARK-10681.
+          DateTimeUtils.stringToTime(datum).getTime * 1000L
         // TODO(hossein): would be good to support other common date formats
-        case _: DateType => Date.valueOf(datum)
-        case _: StringType => datum
+        case _: DateType =>
+          DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime)
+        case _: StringType => UTF8String.fromString(datum)
         case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}")
       }
     }
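The hunk above is where the external-to-internal switch actually happens: `castTo` now returns Catalyst's internal representations instead of `java.sql`/`java.lang` values. The sketch below is illustrative only and not part of the patch; `castToInternal` is a hypothetical helper that mirrors the new match arms so the resulting value types are easy to see.

```scala
import java.math.BigDecimal

import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

object InternalCastSketch {
  // Hypothetical mirror of the patched CSVTypeCast.castTo arms.
  def castToInternal(datum: String, castType: DataType): Any = castType match {
    case dt: DecimalType =>
      // Decimals become org.apache.spark.sql.types.Decimal, bounded by precision/scale.
      Decimal(new BigDecimal(datum.replaceAll(",", "")), dt.precision, dt.scale)
    case _: TimestampType =>
      // Timestamps become microseconds since the epoch (a Long), not java.sql.Timestamp.
      DateTimeUtils.stringToTime(datum).getTime * 1000L
    case _: DateType =>
      // Dates become days since the epoch (an Int), not java.sql.Date.
      DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime)
    case _: StringType =>
      // Strings become UTF8String, the binary-backed type InternalRow stores.
      UTF8String.fromString(datum)
    case _ =>
      datum
  }

  def main(args: Array[String]): Unit = {
    println(castToInternal("2015-01-01 00:00:00", TimestampType)) // Long, timezone-dependent
    println(castToInternal("2015-01-01", DateType))               // Int, days since 1970-01-01
    println(castToInternal("1,000.01", DecimalType(10, 2)))       // 1000.01 as a Decimal
    println(castToInternal("abc", StringType))                    // "abc" as a UTF8String
  }
}
```

Note that schema inference stays consistent with casting: `tryParseTimestamp` above now also probes timestamps through `DateTimeUtils.stringToTime`.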
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index c96a508cf1..eeb56f7c6c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -29,6 +29,7 @@ import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 
@@ -54,7 +55,7 @@ object CSVRelation extends Logging {
       requiredColumns: Array[String],
       inputs: Seq[FileStatus],
       sqlContext: SQLContext,
-      params: CSVOptions): RDD[Row] = {
+      params: CSVOptions): RDD[InternalRow] = {
     val schemaFields = schema.fields
     val requiredFields = StructType(requiredColumns.map(schema(_))).fields
 
@@ -71,8 +72,8 @@ object CSVRelation extends Logging {
     }.foreach {
       case (field, index) => safeRequiredIndices(safeRequiredFields.indexOf(field)) = index
     }
-    val rowArray = new Array[Any](safeRequiredIndices.length)
     val requiredSize = requiredFields.length
+    val row = new GenericMutableRow(requiredSize)
     tokenizedRDD.flatMap { tokens =>
       if (params.dropMalformed && schemaFields.length != tokens.length) {
         logWarning(s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
@@ -94,14 +95,20 @@ object CSVRelation extends Logging {
           while (subIndex < safeRequiredIndices.length) {
            index = safeRequiredIndices(subIndex)
            val field = schemaFields(index)
-           rowArray(subIndex) = CSVTypeCast.castTo(
+           // It anyway needs to try to parse since it decides if this row is malformed
+           // or not after trying to cast in `DROPMALFORMED` mode even if the casted
+           // value is not stored in the row.
+           val value = CSVTypeCast.castTo(
              indexSafeTokens(index),
              field.dataType,
              field.nullable,
              params.nullValue)
+           if (subIndex < requiredSize) {
+             row(subIndex) = value
+           }
            subIndex = subIndex + 1
          }
-         Some(Row.fromSeq(rowArray.take(requiredSize)))
+         Some(row)
       } catch {
         case NonFatal(e) if params.dropMalformed =>
           logWarning("Parse exception. " +
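A note on the `GenericMutableRow` introduced above: one mutable row is allocated per call to `parseCsv` and overwritten for every record, instead of building a fresh `Row` per line. Below is a minimal, hypothetical sketch of that reuse pattern (the two-column schema and the `parseTokens` helper are invented for illustration); consumers that buffer the rows must copy them first, since the same instance is handed out repeatedly.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.unsafe.types.UTF8String

object RowReuseSketch {
  // Hypothetical two-column (string, int) parser showing the reuse pattern.
  def parseTokens(lines: Iterator[Array[String]]): Iterator[InternalRow] = {
    val row = new GenericMutableRow(2) // allocated once, like `row` in parseCsv
    lines.map { tokens =>
      row(0) = UTF8String.fromString(tokens(0)) // row(i) = v desugars to MutableRow.update
      row(1) = tokens(1).toInt
      row // the same instance is returned for every record; copy before buffering
    }
  }

  def main(args: Array[String]): Unit = {
    val parsed = parseTokens(Iterator(Array("a", "1"), Array("b", "2")))
    // Copy each row because the iterator reuses the underlying instance.
    parsed.map(_.copy()).foreach(r => println(r.getInt(1)))
  }
}
```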
" + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala index a5f94262ff..54e4c1a2c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala @@ -28,7 +28,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.execution.datasources.CompressionCodecs import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructField, StructType} @@ -113,13 +113,14 @@ class DefaultSource extends FileFormat with DataSourceRegister { val pathsString = csvFiles.map(_.getPath.toUri.toString) val header = dataSchema.fields.map(_.name) val tokenizedRdd = tokenRdd(sqlContext, csvOptions, header, pathsString) - val external = CSVRelation.parseCsv( + val rows = CSVRelation.parseCsv( tokenizedRdd, dataSchema, requiredColumns, csvFiles, sqlContext, csvOptions) - // TODO: Generate InternalRow in parseCsv - val outputSchema = StructType(requiredColumns.map(c => dataSchema.find(_.name == c).get)) - val encoder = RowEncoder(outputSchema) - external.map(encoder.toRow) + val requiredDataSchema = StructType(requiredColumns.map(c => dataSchema.find(_.name == c).get)) + rows.mapPartitions { iterator => + val unsafeProjection = UnsafeProjection.create(requiredDataSchema) + iterator.map(unsafeProjection) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala index c28a25057e..5702a1b4ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala @@ -18,11 +18,12 @@ package org.apache.spark.sql.execution.datasources.csv import java.math.BigDecimal -import java.sql.{Date, Timestamp} import java.util.Locale import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String class CSVTypeCastSuite extends SparkFunSuite { @@ -32,7 +33,9 @@ class CSVTypeCastSuite extends SparkFunSuite { val decimalType = new DecimalType() stringValues.zip(decimalValues).foreach { case (strVal, decimalVal) => - assert(CSVTypeCast.castTo(strVal, decimalType) === new BigDecimal(decimalVal.toString)) + val decimalValue = new BigDecimal(decimalVal.toString) + assert(CSVTypeCast.castTo(strVal, decimalType) === + Decimal(decimalValue, decimalType.precision, decimalType.scale)) } } @@ -65,8 +68,8 @@ class CSVTypeCastSuite extends SparkFunSuite { } test("String type should always return the same as the input") { - assert(CSVTypeCast.castTo("", StringType, nullable = true) == "") - assert(CSVTypeCast.castTo("", StringType, nullable = false) == "") + assert(CSVTypeCast.castTo("", StringType, nullable = true) == UTF8String.fromString("")) + assert(CSVTypeCast.castTo("", StringType, nullable = false) == UTF8String.fromString("")) } test("Throws exception for empty string with non null 
type") { @@ -85,8 +88,10 @@ class CSVTypeCastSuite extends SparkFunSuite { assert(CSVTypeCast.castTo("1.00", DoubleType) == 1.0) assert(CSVTypeCast.castTo("true", BooleanType) == true) val timestamp = "2015-01-01 00:00:00" - assert(CSVTypeCast.castTo(timestamp, TimestampType) == Timestamp.valueOf(timestamp)) - assert(CSVTypeCast.castTo("2015-01-01", DateType) == Date.valueOf("2015-01-01")) + assert(CSVTypeCast.castTo(timestamp, TimestampType) == + DateTimeUtils.stringToTime(timestamp).getTime * 1000L) + assert(CSVTypeCast.castTo("2015-01-01", DateType) == + DateTimeUtils.millisToDays(DateTimeUtils.stringToTime("2015-01-01").getTime)) } test("Float and Double Types are cast correctly with Locale") { -- cgit v1.2.3