aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorhyukjinkwon <gurwls223@gmail.com>2016-04-29 22:52:21 -0700
committerReynold Xin <rxin@databricks.com>2016-04-29 22:52:21 -0700
commit4bac703eb9dcc286d6b89630cf433f95b63a4a1f (patch)
tree62fa102cdfcff3b96a5b77c4393cf8101f42b791 /sql
parentac41fc648de584f08863313fbac0c5bb6fc6a65e (diff)
downloadspark-4bac703eb9dcc286d6b89630cf433f95b63a4a1f.tar.gz
spark-4bac703eb9dcc286d6b89630cf433f95b63a4a1f.tar.bz2
spark-4bac703eb9dcc286d6b89630cf433f95b63a4a1f.zip
[SPARK-13667][SQL] Support for specifying custom date format for date and timestamp types at CSV datasource.
## What changes were proposed in this pull request? This PR adds the support to specify custom date format for `DateType` and `TimestampType`. For `TimestampType`, this uses the given format to infer schema and also to convert the values For `DateType`, this uses the given format to convert the values. If the `dateFormat` is not given, then it works with `DateTimeUtils.stringToTime()` for backwords compatibility. When it's given, then it uses `SimpleDateFormat` for parsing data. In addition, `IntegerType`, `DoubleType` and `LongType` have a higher priority than `TimestampType` in type inference. This means even if the given format is `yyyy` or `yyyy.MM`, it will be inferred as `IntegerType` or `DoubleType`. Since it is type inference, I think it is okay to give such precedences. In addition, I renamed `csv.CSVInferSchema` to `csv.InferSchema` as JSON datasource has `json.InferSchema`. Although they have the same names, I did this because I thought the parent package name can still differentiate each. Accordingly, the suite name was also changed from `CSVInferSchemaSuite` to `InferSchemaSuite`. ## How was this patch tested? unit tests are used and `./dev/run_tests` for coding style tests. Author: hyukjinkwon <gurwls223@gmail.com> Closes #11550 from HyukjinKwon/SPARK-13667.
Diffstat (limited to 'sql')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala82
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala7
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala3
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala2
-rw-r--r--sql/core/src/test/resources/dates.csv4
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala78
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala52
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala11
8 files changed, 173 insertions, 66 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
index ea843a1013..a26a8084b6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.datasources.csv
import java.math.BigDecimal
-import java.text.NumberFormat
+import java.text.{NumberFormat, SimpleDateFormat}
import java.util.Locale
import scala.util.control.Exception._
@@ -41,11 +41,10 @@ private[csv] object CSVInferSchema {
def infer(
tokenRdd: RDD[Array[String]],
header: Array[String],
- nullValue: String = ""): StructType = {
-
+ options: CSVOptions): StructType = {
val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
val rootTypes: Array[DataType] =
- tokenRdd.aggregate(startType)(inferRowType(nullValue), mergeRowTypes)
+ tokenRdd.aggregate(startType)(inferRowType(options), mergeRowTypes)
val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) =>
val dType = rootType match {
@@ -58,11 +57,11 @@ private[csv] object CSVInferSchema {
StructType(structFields)
}
- private def inferRowType(nullValue: String)
+ private def inferRowType(options: CSVOptions)
(rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = {
var i = 0
while (i < math.min(rowSoFar.length, next.length)) { // May have columns on right missing.
- rowSoFar(i) = inferField(rowSoFar(i), next(i), nullValue)
+ rowSoFar(i) = inferField(rowSoFar(i), next(i), options)
i+=1
}
rowSoFar
@@ -78,17 +77,17 @@ private[csv] object CSVInferSchema {
* Infer type of string field. Given known type Double, and a string "1", there is no
* point checking if it is an Int, as the final type must be Double or higher.
*/
- def inferField(typeSoFar: DataType, field: String, nullValue: String = ""): DataType = {
- if (field == null || field.isEmpty || field == nullValue) {
+ def inferField(typeSoFar: DataType, field: String, options: CSVOptions): DataType = {
+ if (field == null || field.isEmpty || field == options.nullValue) {
typeSoFar
} else {
typeSoFar match {
- case NullType => tryParseInteger(field)
- case IntegerType => tryParseInteger(field)
- case LongType => tryParseLong(field)
- case DoubleType => tryParseDouble(field)
- case TimestampType => tryParseTimestamp(field)
- case BooleanType => tryParseBoolean(field)
+ case NullType => tryParseInteger(field, options)
+ case IntegerType => tryParseInteger(field, options)
+ case LongType => tryParseLong(field, options)
+ case DoubleType => tryParseDouble(field, options)
+ case TimestampType => tryParseTimestamp(field, options)
+ case BooleanType => tryParseBoolean(field, options)
case StringType => StringType
case other: DataType =>
throw new UnsupportedOperationException(s"Unexpected data type $other")
@@ -96,35 +95,49 @@ private[csv] object CSVInferSchema {
}
}
- private def tryParseInteger(field: String): DataType = if ((allCatch opt field.toInt).isDefined) {
- IntegerType
- } else {
- tryParseLong(field)
+ private def tryParseInteger(field: String, options: CSVOptions): DataType = {
+ if ((allCatch opt field.toInt).isDefined) {
+ IntegerType
+ } else {
+ tryParseLong(field, options)
+ }
}
- private def tryParseLong(field: String): DataType = if ((allCatch opt field.toLong).isDefined) {
- LongType
- } else {
- tryParseDouble(field)
+ private def tryParseLong(field: String, options: CSVOptions): DataType = {
+ if ((allCatch opt field.toLong).isDefined) {
+ LongType
+ } else {
+ tryParseDouble(field, options)
+ }
}
- private def tryParseDouble(field: String): DataType = {
+ private def tryParseDouble(field: String, options: CSVOptions): DataType = {
if ((allCatch opt field.toDouble).isDefined) {
DoubleType
} else {
- tryParseTimestamp(field)
+ tryParseTimestamp(field, options)
}
}
- def tryParseTimestamp(field: String): DataType = {
- if ((allCatch opt DateTimeUtils.stringToTime(field)).isDefined) {
- TimestampType
+ private def tryParseTimestamp(field: String, options: CSVOptions): DataType = {
+ if (options.dateFormat != null) {
+ // This case infers a custom `dataFormat` is set.
+ if ((allCatch opt options.dateFormat.parse(field)).isDefined) {
+ TimestampType
+ } else {
+ tryParseBoolean(field, options)
+ }
} else {
- tryParseBoolean(field)
+ // We keep this for backwords competibility.
+ if ((allCatch opt DateTimeUtils.stringToTime(field)).isDefined) {
+ TimestampType
+ } else {
+ tryParseBoolean(field, options)
+ }
}
}
- def tryParseBoolean(field: String): DataType = {
+ private def tryParseBoolean(field: String, options: CSVOptions): DataType = {
if ((allCatch opt field.toBoolean).isDefined) {
BooleanType
} else {
@@ -177,7 +190,8 @@ private[csv] object CSVTypeCast {
datum: String,
castType: DataType,
nullable: Boolean = true,
- nullValue: String = ""): Any = {
+ nullValue: String = "",
+ dateFormat: SimpleDateFormat = null): Any = {
if (datum == nullValue && nullable && (!castType.isInstanceOf[StringType])) {
null
@@ -195,12 +209,16 @@ private[csv] object CSVTypeCast {
case dt: DecimalType =>
val value = new BigDecimal(datum.replaceAll(",", ""))
Decimal(value, dt.precision, dt.scale)
- // TODO(hossein): would be good to support other common timestamp formats
+ case _: TimestampType if dateFormat != null =>
+ // This one will lose microseconds parts.
+ // See https://issues.apache.org/jira/browse/SPARK-10681.
+ dateFormat.parse(datum).getTime * 1000L
case _: TimestampType =>
// This one will lose microseconds parts.
// See https://issues.apache.org/jira/browse/SPARK-10681.
DateTimeUtils.stringToTime(datum).getTime * 1000L
- // TODO(hossein): would be good to support other common date formats
+ case _: DateType if dateFormat != null =>
+ DateTimeUtils.millisToDays(dateFormat.parse(datum).getTime)
case _: DateType =>
DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime)
case _: StringType => UTF8String.fromString(datum)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index 80a0ad7856..b87d19f7cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution.datasources.csv
import java.nio.charset.StandardCharsets
+import java.text.SimpleDateFormat
import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.datasources.{CompressionCodecs, ParseModes}
@@ -94,6 +95,12 @@ private[sql] class CSVOptions(@transient private val parameters: Map[String, Str
name.map(CompressionCodecs.getCodecClassName)
}
+ // Share date format object as it is expensive to parse date pattern.
+ val dateFormat: SimpleDateFormat = {
+ val dateFormat = parameters.get("dateFormat")
+ dateFormat.map(new SimpleDateFormat(_)).orNull
+ }
+
val maxColumns = getInt("maxColumns", 20480)
val maxCharsPerColumn = getInt("maxCharsPerColumn", 1000000)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index ed40cd0c81..9a723630de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -99,7 +99,8 @@ object CSVRelation extends Logging {
indexSafeTokens(index),
field.dataType,
field.nullable,
- params.nullValue)
+ params.nullValue,
+ params.dateFormat)
if (subIndex < requiredSize) {
row(subIndex) = value
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
index 8ca105d923..75143e609a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
@@ -68,7 +68,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
val parsedRdd = tokenRdd(sparkSession, csvOptions, header, paths)
val schema = if (csvOptions.inferSchemaFlag) {
- CSVInferSchema.infer(parsedRdd, header, csvOptions.nullValue)
+ CSVInferSchema.infer(parsedRdd, header, csvOptions)
} else {
// By default fields are assumed to be StringType
val schemaFields = header.map { fieldName =>
diff --git a/sql/core/src/test/resources/dates.csv b/sql/core/src/test/resources/dates.csv
new file mode 100644
index 0000000000..9ee99c31b3
--- /dev/null
+++ b/sql/core/src/test/resources/dates.csv
@@ -0,0 +1,4 @@
+date
+26/08/2015 18:00
+27/10/2014 18:30
+28/01/2016 20:00
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
index 23d422635b..daf85be56f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
@@ -17,45 +17,58 @@
package org.apache.spark.sql.execution.datasources.csv
+import java.text.SimpleDateFormat
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._
class CSVInferSchemaSuite extends SparkFunSuite {
test("String fields types are inferred correctly from null types") {
- assert(CSVInferSchema.inferField(NullType, "") == NullType)
- assert(CSVInferSchema.inferField(NullType, null) == NullType)
- assert(CSVInferSchema.inferField(NullType, "100000000000") == LongType)
- assert(CSVInferSchema.inferField(NullType, "60") == IntegerType)
- assert(CSVInferSchema.inferField(NullType, "3.5") == DoubleType)
- assert(CSVInferSchema.inferField(NullType, "test") == StringType)
- assert(CSVInferSchema.inferField(NullType, "2015-08-20 15:57:00") == TimestampType)
- assert(CSVInferSchema.inferField(NullType, "True") == BooleanType)
- assert(CSVInferSchema.inferField(NullType, "FAlSE") == BooleanType)
+ val options = new CSVOptions(Map.empty[String, String])
+ assert(CSVInferSchema.inferField(NullType, "", options) == NullType)
+ assert(CSVInferSchema.inferField(NullType, null, options) == NullType)
+ assert(CSVInferSchema.inferField(NullType, "100000000000", options) == LongType)
+ assert(CSVInferSchema.inferField(NullType, "60", options) == IntegerType)
+ assert(CSVInferSchema.inferField(NullType, "3.5", options) == DoubleType)
+ assert(CSVInferSchema.inferField(NullType, "test", options) == StringType)
+ assert(CSVInferSchema.inferField(NullType, "2015-08-20 15:57:00", options) == TimestampType)
+ assert(CSVInferSchema.inferField(NullType, "True", options) == BooleanType)
+ assert(CSVInferSchema.inferField(NullType, "FAlSE", options) == BooleanType)
}
test("String fields types are inferred correctly from other types") {
- assert(CSVInferSchema.inferField(LongType, "1.0") == DoubleType)
- assert(CSVInferSchema.inferField(LongType, "test") == StringType)
- assert(CSVInferSchema.inferField(IntegerType, "1.0") == DoubleType)
- assert(CSVInferSchema.inferField(DoubleType, null) == DoubleType)
- assert(CSVInferSchema.inferField(DoubleType, "test") == StringType)
- assert(CSVInferSchema.inferField(LongType, "2015-08-20 14:57:00") == TimestampType)
- assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 15:57:00") == TimestampType)
- assert(CSVInferSchema.inferField(LongType, "True") == BooleanType)
- assert(CSVInferSchema.inferField(IntegerType, "FALSE") == BooleanType)
- assert(CSVInferSchema.inferField(TimestampType, "FALSE") == BooleanType)
+ val options = new CSVOptions(Map.empty[String, String])
+ assert(CSVInferSchema.inferField(LongType, "1.0", options) == DoubleType)
+ assert(CSVInferSchema.inferField(LongType, "test", options) == StringType)
+ assert(CSVInferSchema.inferField(IntegerType, "1.0", options) == DoubleType)
+ assert(CSVInferSchema.inferField(DoubleType, null, options) == DoubleType)
+ assert(CSVInferSchema.inferField(DoubleType, "test", options) == StringType)
+ assert(CSVInferSchema.inferField(LongType, "2015-08-20 14:57:00", options) == TimestampType)
+ assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 15:57:00", options) == TimestampType)
+ assert(CSVInferSchema.inferField(LongType, "True", options) == BooleanType)
+ assert(CSVInferSchema.inferField(IntegerType, "FALSE", options) == BooleanType)
+ assert(CSVInferSchema.inferField(TimestampType, "FALSE", options) == BooleanType)
+ }
+
+ test("Timestamp field types are inferred correctly via custom data format") {
+ var options = new CSVOptions(Map("dateFormat" -> "yyyy-mm"))
+ assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == TimestampType)
+ options = new CSVOptions(Map("dateFormat" -> "yyyy"))
+ assert(CSVInferSchema.inferField(TimestampType, "2015", options) == TimestampType)
}
test("Timestamp field types are inferred correctly from other types") {
- assert(CSVInferSchema.inferField(IntegerType, "2015-08-20 14") == StringType)
- assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 14:10") == StringType)
- assert(CSVInferSchema.inferField(LongType, "2015-08 14:49:00") == StringType)
+ val options = new CSVOptions(Map.empty[String, String])
+ assert(CSVInferSchema.inferField(IntegerType, "2015-08-20 14", options) == StringType)
+ assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 14:10", options) == StringType)
+ assert(CSVInferSchema.inferField(LongType, "2015-08 14:49:00", options) == StringType)
}
test("Boolean fields types are inferred correctly from other types") {
- assert(CSVInferSchema.inferField(LongType, "Fale") == StringType)
- assert(CSVInferSchema.inferField(DoubleType, "TRUEe") == StringType)
+ val options = new CSVOptions(Map.empty[String, String])
+ assert(CSVInferSchema.inferField(LongType, "Fale", options) == StringType)
+ assert(CSVInferSchema.inferField(DoubleType, "TRUEe", options) == StringType)
}
test("Type arrays are merged to highest common type") {
@@ -71,13 +84,16 @@ class CSVInferSchemaSuite extends SparkFunSuite {
}
test("Null fields are handled properly when a nullValue is specified") {
- assert(CSVInferSchema.inferField(NullType, "null", "null") == NullType)
- assert(CSVInferSchema.inferField(StringType, "null", "null") == StringType)
- assert(CSVInferSchema.inferField(LongType, "null", "null") == LongType)
- assert(CSVInferSchema.inferField(IntegerType, "\\N", "\\N") == IntegerType)
- assert(CSVInferSchema.inferField(DoubleType, "\\N", "\\N") == DoubleType)
- assert(CSVInferSchema.inferField(TimestampType, "\\N", "\\N") == TimestampType)
- assert(CSVInferSchema.inferField(BooleanType, "\\N", "\\N") == BooleanType)
+ var options = new CSVOptions(Map("nullValue" -> "null"))
+ assert(CSVInferSchema.inferField(NullType, "null", options) == NullType)
+ assert(CSVInferSchema.inferField(StringType, "null", options) == StringType)
+ assert(CSVInferSchema.inferField(LongType, "null", options) == LongType)
+
+ options = new CSVOptions(Map("nullValue" -> "\\N"))
+ assert(CSVInferSchema.inferField(IntegerType, "\\N", options) == IntegerType)
+ assert(CSVInferSchema.inferField(DoubleType, "\\N", options) == DoubleType)
+ assert(CSVInferSchema.inferField(TimestampType, "\\N", options) == TimestampType)
+ assert(CSVInferSchema.inferField(BooleanType, "\\N", options) == BooleanType)
}
test("Merging Nulltypes should yield Nulltype.") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index ceda920ddc..8847c7632f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -19,7 +19,8 @@ package org.apache.spark.sql.execution.datasources.csv
import java.io.File
import java.nio.charset.UnsupportedCharsetException
-import java.sql.Timestamp
+import java.sql.{Date, Timestamp}
+import java.text.SimpleDateFormat
import scala.collection.JavaConverters._
@@ -45,6 +46,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
private val disableCommentsFile = "disable_comments.csv"
private val boolFile = "bool.csv"
private val simpleSparseFile = "simple_sparse.csv"
+ private val datesFile = "dates.csv"
private val unescapedQuotesFile = "unescaped-quotes.csv"
private def testFile(fileName: String): String = {
@@ -367,6 +369,54 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
assert(results.toSeq.map(_.toSeq) === expected)
}
+ test("inferring timestamp types via custom date format") {
+ val options = Map(
+ "header" -> "true",
+ "inferSchema" -> "true",
+ "dateFormat" -> "dd/MM/yyyy hh:mm")
+ val results = sqlContext.read
+ .format("csv")
+ .options(options)
+ .load(testFile(datesFile))
+ .select("date")
+ .collect()
+
+ val dateFormat = new SimpleDateFormat("dd/MM/yyyy hh:mm")
+ val expected =
+ Seq(Seq(new Timestamp(dateFormat.parse("26/08/2015 18:00").getTime)),
+ Seq(new Timestamp(dateFormat.parse("27/10/2014 18:30").getTime)),
+ Seq(new Timestamp(dateFormat.parse("28/01/2016 20:00").getTime)))
+ assert(results.toSeq.map(_.toSeq) === expected)
+ }
+
+ test("load date types via custom date format") {
+ val customSchema = new StructType(Array(StructField("date", DateType, true)))
+ val options = Map(
+ "header" -> "true",
+ "inferSchema" -> "false",
+ "dateFormat" -> "dd/MM/yyyy hh:mm")
+ val results = sqlContext.read
+ .format("csv")
+ .options(options)
+ .schema(customSchema)
+ .load(testFile(datesFile))
+ .select("date")
+ .collect()
+
+ val dateFormat = new SimpleDateFormat("dd/MM/yyyy hh:mm")
+ val expected = Seq(
+ new Date(dateFormat.parse("26/08/2015 18:00").getTime),
+ new Date(dateFormat.parse("27/10/2014 18:30").getTime),
+ new Date(dateFormat.parse("28/01/2016 20:00").getTime))
+ val dates = results.toSeq.map(_.toSeq.head)
+ expected.zip(dates).foreach {
+ case (expectedDate, date) =>
+ // As it truncates the hours, minutes and etc., we only check
+ // if the dates (days, months and years) are the same via `toString()`.
+ assert(expectedDate.toString === date.toString)
+ }
+ }
+
test("setting comment to null disables comment support") {
val results = sqlContext.read
.format("csv")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
index 5702a1b4ea..8b59bc148f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
@@ -18,6 +18,8 @@
package org.apache.spark.sql.execution.datasources.csv
import java.math.BigDecimal
+import java.sql.{Date, Timestamp}
+import java.text.SimpleDateFormat
import java.util.Locale
import org.apache.spark.SparkFunSuite
@@ -87,6 +89,15 @@ class CSVTypeCastSuite extends SparkFunSuite {
assert(CSVTypeCast.castTo("1.00", FloatType) == 1.0)
assert(CSVTypeCast.castTo("1.00", DoubleType) == 1.0)
assert(CSVTypeCast.castTo("true", BooleanType) == true)
+
+ val dateFormat = new SimpleDateFormat("dd/MM/yyyy hh:mm")
+ val customTimestamp = "31/01/2015 00:00"
+ val expectedTime = dateFormat.parse("31/01/2015 00:00").getTime
+ assert(CSVTypeCast.castTo(customTimestamp, TimestampType, dateFormat = dateFormat)
+ == expectedTime * 1000L)
+ assert(CSVTypeCast.castTo(customTimestamp, DateType, dateFormat = dateFormat) ==
+ DateTimeUtils.millisToDays(expectedTime))
+
val timestamp = "2015-01-01 00:00:00"
assert(CSVTypeCast.castTo(timestamp, TimestampType) ==
DateTimeUtils.stringToTime(timestamp).getTime * 1000L)