author     hyukjinkwon <gurwls223@gmail.com>    2016-03-15 23:31:46 -0700
committer  Reynold Xin <rxin@databricks.com>    2016-03-15 23:31:46 -0700
commit     92024797a4fad594b5314f3f3be5c6be2434de8a (patch)
tree       f0c4bfa13a48aca1ce2bbc2ccc2b0ea9d6cd4fb9 /sql/core/src
parent     3c578c594ebe124eb1a1b21e41a64e7661b04de9 (diff)
[SPARK-13899][SQL] Produce InternalRow instead of external Row at CSV data source
## What changes were proposed in this pull request?

https://issues.apache.org/jira/browse/SPARK-13899

This PR makes the CSV data source produce `InternalRow` instead of `Row`. Basically, this resembles the JSON data source. It uses the same code for casting.

## How was this patch tested?

Unit tests were run within the IDE and code style was checked by `./dev/run_tests`.

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #11717 from HyukjinKwon/SPARK-13899.
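For reference, a minimal sketch (not part of this commit) of what the switch means at the value level: an external `Row` carries JVM-friendly types (`String`, `java.sql.Timestamp`, `java.math.BigDecimal`), while an `InternalRow` carries Catalyst's internal representations (`UTF8String`, microseconds as `Long`, `Decimal`). The example values and the `InternalRowSketch` object below are hypothetical, assuming the Catalyst internals of this era are on the classpath.

```scala
import java.math.BigDecimal

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types.Decimal
import org.apache.spark.unsafe.types.UTF8String

object InternalRowSketch {
  def main(args: Array[String]): Unit = {
    // External Row: the types a user-facing API hands back.
    val external = Row(
      "a",
      java.sql.Timestamp.valueOf("2015-01-01 00:00:00"),
      new BigDecimal("1.00"))

    // InternalRow: Catalyst's internal representations of the same values.
    val internal = new GenericInternalRow(Array[Any](
      UTF8String.fromString("a"),
      DateTimeUtils.stringToTime("2015-01-01 00:00:00").getTime * 1000L,
      Decimal(new BigDecimal("1.00"), 10, 2)))

    println(external)
    println(internal.numFields) // 3
  }
}
```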
Diffstat (limited to 'sql/core/src')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala  | 19
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala     | 15
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala   | 13
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala | 17
4 files changed, 42 insertions(+), 22 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
index edead9b21b..797f740dc5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.execution.datasources.csv
import java.math.BigDecimal
-import java.sql.{Date, Timestamp}
import java.text.NumberFormat
import java.util.Locale
@@ -27,7 +26,9 @@ import scala.util.Try
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
private[csv] object CSVInferSchema {
@@ -116,7 +117,7 @@ private[csv] object CSVInferSchema {
}
def tryParseTimestamp(field: String): DataType = {
- if ((allCatch opt Timestamp.valueOf(field)).isDefined) {
+ if ((allCatch opt DateTimeUtils.stringToTime(field)).isDefined) {
TimestampType
} else {
tryParseBoolean(field)
@@ -191,12 +192,18 @@ private[csv] object CSVTypeCast {
case _: DoubleType => Try(datum.toDouble)
.getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue())
case _: BooleanType => datum.toBoolean
- case _: DecimalType => new BigDecimal(datum.replaceAll(",", ""))
+ case dt: DecimalType =>
+ val value = new BigDecimal(datum.replaceAll(",", ""))
+ Decimal(value, dt.precision, dt.scale)
// TODO(hossein): would be good to support other common timestamp formats
- case _: TimestampType => Timestamp.valueOf(datum)
+ case _: TimestampType =>
+ // This one will lose microseconds parts.
+ // See https://issues.apache.org/jira/browse/SPARK-10681.
+ DateTimeUtils.stringToTime(datum).getTime * 1000L
// TODO(hossein): would be good to support other common date formats
- case _: DateType => Date.valueOf(datum)
- case _: StringType => datum
+ case _: DateType =>
+ DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime)
+ case _: StringType => UTF8String.fromString(datum)
case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}")
}
}
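The `castTo` changes above are the core of the switch: every branch now returns Catalyst's internal representation instead of an external JVM type. Below is a condensed, hypothetical standalone version of just those branches (the numeric, boolean, and null-value handling from the real method is omitted); `CastSketch` and `castToInternal` are illustrative names, not part of the commit.

```scala
import java.math.BigDecimal

import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

object CastSketch {
  // Hypothetical helper mirroring the new CSVTypeCast.castTo branches.
  def castToInternal(datum: String, castType: DataType): Any = castType match {
    case dt: DecimalType =>
      // Strip thousands separators, then wrap in Catalyst's Decimal with the target precision/scale.
      Decimal(new BigDecimal(datum.replaceAll(",", "")), dt.precision, dt.scale)
    case _: TimestampType =>
      // Internally, timestamps are microseconds since the epoch.
      DateTimeUtils.stringToTime(datum).getTime * 1000L
    case _: DateType =>
      // Internally, dates are days since the epoch.
      DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime)
    case _: StringType =>
      // Internally, strings are UTF8String.
      UTF8String.fromString(datum)
    case _ =>
      throw new RuntimeException(s"Unsupported type: ${castType.typeName}")
  }
}
```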
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index c96a508cf1..eeb56f7c6c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -29,6 +29,7 @@ import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
@@ -54,7 +55,7 @@ object CSVRelation extends Logging {
requiredColumns: Array[String],
inputs: Seq[FileStatus],
sqlContext: SQLContext,
- params: CSVOptions): RDD[Row] = {
+ params: CSVOptions): RDD[InternalRow] = {
val schemaFields = schema.fields
val requiredFields = StructType(requiredColumns.map(schema(_))).fields
@@ -71,8 +72,8 @@ object CSVRelation extends Logging {
}.foreach {
case (field, index) => safeRequiredIndices(safeRequiredFields.indexOf(field)) = index
}
- val rowArray = new Array[Any](safeRequiredIndices.length)
val requiredSize = requiredFields.length
+ val row = new GenericMutableRow(requiredSize)
tokenizedRDD.flatMap { tokens =>
if (params.dropMalformed && schemaFields.length != tokens.length) {
logWarning(s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
@@ -94,14 +95,20 @@ object CSVRelation extends Logging {
while (subIndex < safeRequiredIndices.length) {
index = safeRequiredIndices(subIndex)
val field = schemaFields(index)
- rowArray(subIndex) = CSVTypeCast.castTo(
+ // It anyway needs to try to parse since it decides if this row is malformed
+ // or not after trying to cast in `DROPMALFORMED` mode even if the casted
+ // value is not stored in the row.
+ val value = CSVTypeCast.castTo(
indexSafeTokens(index),
field.dataType,
field.nullable,
params.nullValue)
+ if (subIndex < requiredSize) {
+ row(subIndex) = value
+ }
subIndex = subIndex + 1
}
- Some(Row.fromSeq(rowArray.take(requiredSize)))
+ Some(row)
} catch {
case NonFatal(e) if params.dropMalformed =>
logWarning("Parse exception. " +
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
index a5f94262ff..54e4c1a2c9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
@@ -28,7 +28,7 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.execution.datasources.CompressionCodecs
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{StringType, StructField, StructType}
@@ -113,13 +113,14 @@ class DefaultSource extends FileFormat with DataSourceRegister {
val pathsString = csvFiles.map(_.getPath.toUri.toString)
val header = dataSchema.fields.map(_.name)
val tokenizedRdd = tokenRdd(sqlContext, csvOptions, header, pathsString)
- val external = CSVRelation.parseCsv(
+ val rows = CSVRelation.parseCsv(
tokenizedRdd, dataSchema, requiredColumns, csvFiles, sqlContext, csvOptions)
- // TODO: Generate InternalRow in parseCsv
- val outputSchema = StructType(requiredColumns.map(c => dataSchema.find(_.name == c).get))
- val encoder = RowEncoder(outputSchema)
- external.map(encoder.toRow)
+ val requiredDataSchema = StructType(requiredColumns.map(c => dataSchema.find(_.name == c).get))
+ rows.mapPartitions { iterator =>
+ val unsafeProjection = UnsafeProjection.create(requiredDataSchema)
+ iterator.map(unsafeProjection)
+ }
}
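`UnsafeProjection` replaces the `RowEncoder` round trip: the parsed `InternalRow`s are converted straight into `UnsafeRow`s, with one projection created per partition. A minimal, hypothetical standalone version of that `mapPartitions` step is sketched below; `UnsafeProjectionSketch` and `toUnsafe` are illustrative names.

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types.StructType

object UnsafeProjectionSketch {
  // Convert internal rows to UnsafeRows, creating one projection per partition
  // so the generated projection is constructed on the executors.
  def toUnsafe(rows: RDD[InternalRow], requiredDataSchema: StructType): RDD[InternalRow] = {
    rows.mapPartitions { iterator =>
      val unsafeProjection = UnsafeProjection.create(requiredDataSchema)
      iterator.map(unsafeProjection)
    }
  }
}
```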
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
index c28a25057e..5702a1b4ea 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
@@ -18,11 +18,12 @@
package org.apache.spark.sql.execution.datasources.csv
import java.math.BigDecimal
-import java.sql.{Date, Timestamp}
import java.util.Locale
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
class CSVTypeCastSuite extends SparkFunSuite {
@@ -32,7 +33,9 @@ class CSVTypeCastSuite extends SparkFunSuite {
val decimalType = new DecimalType()
stringValues.zip(decimalValues).foreach { case (strVal, decimalVal) =>
- assert(CSVTypeCast.castTo(strVal, decimalType) === new BigDecimal(decimalVal.toString))
+ val decimalValue = new BigDecimal(decimalVal.toString)
+ assert(CSVTypeCast.castTo(strVal, decimalType) ===
+ Decimal(decimalValue, decimalType.precision, decimalType.scale))
}
}
@@ -65,8 +68,8 @@ class CSVTypeCastSuite extends SparkFunSuite {
}
test("String type should always return the same as the input") {
- assert(CSVTypeCast.castTo("", StringType, nullable = true) == "")
- assert(CSVTypeCast.castTo("", StringType, nullable = false) == "")
+ assert(CSVTypeCast.castTo("", StringType, nullable = true) == UTF8String.fromString(""))
+ assert(CSVTypeCast.castTo("", StringType, nullable = false) == UTF8String.fromString(""))
}
test("Throws exception for empty string with non null type") {
@@ -85,8 +88,10 @@ class CSVTypeCastSuite extends SparkFunSuite {
assert(CSVTypeCast.castTo("1.00", DoubleType) == 1.0)
assert(CSVTypeCast.castTo("true", BooleanType) == true)
val timestamp = "2015-01-01 00:00:00"
- assert(CSVTypeCast.castTo(timestamp, TimestampType) == Timestamp.valueOf(timestamp))
- assert(CSVTypeCast.castTo("2015-01-01", DateType) == Date.valueOf("2015-01-01"))
+ assert(CSVTypeCast.castTo(timestamp, TimestampType) ==
+ DateTimeUtils.stringToTime(timestamp).getTime * 1000L)
+ assert(CSVTypeCast.castTo("2015-01-01", DateType) ==
+ DateTimeUtils.millisToDays(DateTimeUtils.stringToTime("2015-01-01").getTime))
}
test("Float and Double Types are cast correctly with Locale") {