author    Hossein <hossein@databricks.com>    2016-01-15 11:46:46 -0800
committer Reynold Xin <rxin@databricks.com>   2016-01-15 11:46:46 -0800
commit    5f83c6991c95616ecbc2878f8860c69b2826f56c (patch)
tree      86dc70e45f1b27b67efec9724632a108d69f2ef0
parent    c5e7076da72657ea35a0aa388f8d2e6411d39280 (diff)
download  spark-5f83c6991c95616ecbc2878f8860c69b2826f56c.tar.gz
          spark-5f83c6991c95616ecbc2878f8860c69b2826f56c.tar.bz2
          spark-5f83c6991c95616ecbc2878f8860c69b2826f56c.zip
[SPARK-12833][SQL] Initial import of spark-csv
CSV is the most common data format in the "small data" world. It is often the first format people want to try when they run Spark on a single node, and having to rely on a third-party component for it makes for a poor experience for new users. This PR merges the popular spark-csv data source package (https://github.com/databricks/spark-csv) into Spark SQL. It is the first PR to bring the functionality to the Spark 2.0 master branch; the items outlined in the design document (see the JIRA attachment) will be completed in follow-up pull requests.

Author: Hossein <hossein@databricks.com>
Author: Reynold Xin <rxin@databricks.com>

Closes #10766 from rxin/csv.
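As a quick orientation (not part of the patch itself), a minimal read through the new data source looks like the sketch below; it mirrors the examples in CSVSuite further down, with the file path being illustrative:

    // Resolve the "csv" short name through DataSourceRegister and read a file
    // with a header row, letting CSVInferSchema infer the column types.
    val cars = sqlContext.read
      .format("csv")
      .option("header", "true")       // first non-comment line supplies column names
      .option("inferSchema", "true")  // fold token types with CSVInferSchema
      .load("cars.csv")

    cars.printSchema()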
-rw-r--r--  .rat-excludes | 2
-rw-r--r--  NOTICE | 38
-rw-r--r--  dev/deps/spark-deps-hadoop-2.2 | 1
-rw-r--r--  dev/deps/spark-deps-hadoop-2.3 | 1
-rw-r--r--  dev/deps/spark-deps-hadoop-2.4 | 1
-rw-r--r--  dev/deps/spark-deps-hadoop-2.6 | 1
-rw-r--r--  sql/core/pom.xml | 6
-rw-r--r--  sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister | 1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala | 227
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala | 107
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala | 243
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala | 298
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala | 48
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala | 13
-rw-r--r--  sql/core/src/test/resources/cars-alternative.csv | 5
-rw-r--r--  sql/core/src/test/resources/cars-null.csv | 6
-rw-r--r--  sql/core/src/test/resources/cars-unbalanced-quotes.csv | 4
-rw-r--r--  sql/core/src/test/resources/cars.csv | 6
-rw-r--r--  sql/core/src/test/resources/cars.tsv | 4
-rw-r--r--  sql/core/src/test/resources/cars_iso-8859-1.csv | 6
-rw-r--r--  sql/core/src/test/resources/comments.csv | 6
-rw-r--r--  sql/core/src/test/resources/disable_comments.csv | 2
-rw-r--r--  sql/core/src/test/resources/empty.csv | 0
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala | 71
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala | 125
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala | 341
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala | 98
27 files changed, 1653 insertions, 8 deletions
diff --git a/.rat-excludes b/.rat-excludes
index bf071eba65..a4f316a4aa 100644
--- a/.rat-excludes
+++ b/.rat-excludes
@@ -86,3 +86,5 @@ org.apache.spark.scheduler.SparkHistoryListenerFactory
.*parquet
LZ4BlockInputStream.java
spark-deps-.*
+.*csv
+.*tsv
diff --git a/NOTICE b/NOTICE
index 571f8c2fff..e416aadce9 100644
--- a/NOTICE
+++ b/NOTICE
@@ -610,7 +610,43 @@ Vis.js uses and redistributes the following third-party libraries:
===============================================================================
-The CSS style for the navigation sidebar of the documentation was originally
+The CSS style for the navigation sidebar of the documentation was originally
submitted by Óscar Nájera for the scikit-learn project. The scikit-learn project
is distributed under the 3-Clause BSD license.
===============================================================================
+
+For CSV functionality:
+
+/*
+ * Copyright 2014 Databricks
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Copyright 2015 Ayasdi Inc
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2
index 53034a25d4..fb2e91e1ee 100644
--- a/dev/deps/spark-deps-hadoop-2.2
+++ b/dev/deps/spark-deps-hadoop-2.2
@@ -184,6 +184,7 @@ tachyon-underfs-hdfs-0.8.2.jar
tachyon-underfs-local-0.8.2.jar
tachyon-underfs-s3-0.8.2.jar
uncommons-maths-1.2.2a.jar
+univocity-parsers-1.5.6.jar
unused-1.0.0.jar
xbean-asm5-shaded-4.4.jar
xmlenc-0.52.jar
diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3
index a23e260641..59e4d4f839 100644
--- a/dev/deps/spark-deps-hadoop-2.3
+++ b/dev/deps/spark-deps-hadoop-2.3
@@ -175,6 +175,7 @@ tachyon-underfs-hdfs-0.8.2.jar
tachyon-underfs-local-0.8.2.jar
tachyon-underfs-s3-0.8.2.jar
uncommons-maths-1.2.2a.jar
+univocity-parsers-1.5.6.jar
unused-1.0.0.jar
xbean-asm5-shaded-4.4.jar
xmlenc-0.52.jar
diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4
index 6bedbed1e3..e4395c872c 100644
--- a/dev/deps/spark-deps-hadoop-2.4
+++ b/dev/deps/spark-deps-hadoop-2.4
@@ -176,6 +176,7 @@ tachyon-underfs-hdfs-0.8.2.jar
tachyon-underfs-local-0.8.2.jar
tachyon-underfs-s3-0.8.2.jar
uncommons-maths-1.2.2a.jar
+univocity-parsers-1.5.6.jar
unused-1.0.0.jar
xbean-asm5-shaded-4.4.jar
xmlenc-0.52.jar
diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6
index 7bfad57b4a..89fd15da7d 100644
--- a/dev/deps/spark-deps-hadoop-2.6
+++ b/dev/deps/spark-deps-hadoop-2.6
@@ -182,6 +182,7 @@ tachyon-underfs-hdfs-0.8.2.jar
tachyon-underfs-local-0.8.2.jar
tachyon-underfs-s3-0.8.2.jar
uncommons-maths-1.2.2a.jar
+univocity-parsers-1.5.6.jar
unused-1.0.0.jar
xbean-asm5-shaded-4.4.jar
xercesImpl-2.9.1.jar
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 6db7a8a2dc..31b364f351 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -37,6 +37,12 @@
<dependencies>
<dependency>
+ <groupId>com.univocity</groupId>
+ <artifactId>univocity-parsers</artifactId>
+ <version>1.5.6</version>
+ <type>jar</type>
+ </dependency>
+ <dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${project.version}</version>
diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
index 1ca2044057..226d59d0ea 100644
--- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
+++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -1,3 +1,4 @@
+org.apache.spark.sql.execution.datasources.csv.DefaultSource
org.apache.spark.sql.execution.datasources.jdbc.DefaultSource
org.apache.spark.sql.execution.datasources.json.DefaultSource
org.apache.spark.sql.execution.datasources.parquet.DefaultSource
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
new file mode 100644
index 0000000000..0aa4539e60
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
@@ -0,0 +1,227 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.csv
+
+import java.math.BigDecimal
+import java.sql.{Date, Timestamp}
+import java.text.NumberFormat
+import java.util.Locale
+
+import scala.util.control.Exception._
+import scala.util.Try
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
+import org.apache.spark.sql.types._
+
+
+private[sql] object CSVInferSchema {
+
+ /**
+ * Similar to the JSON schema inference
+ * 1. Infer type of each row
+ * 2. Merge row types to find common type
+ * 3. Replace any null types with string type
+ * TODO(hossein): Can we reuse JSON schema inference? [SPARK-12670]
+ */
+ def apply(
+ tokenRdd: RDD[Array[String]],
+ header: Array[String],
+ nullValue: String = ""): StructType = {
+
+ val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
+ val rootTypes: Array[DataType] =
+ tokenRdd.aggregate(startType)(inferRowType(nullValue), mergeRowTypes)
+
+ val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) =>
+ StructField(thisHeader, rootType, nullable = true)
+ }
+
+ StructType(structFields)
+ }
+
+ private def inferRowType(nullValue: String)
+ (rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = {
+ var i = 0
+ while (i < math.min(rowSoFar.length, next.length)) { // May have columns on right missing.
+ rowSoFar(i) = inferField(rowSoFar(i), next(i), nullValue)
+ i+=1
+ }
+ rowSoFar
+ }
+
+ private[csv] def mergeRowTypes(
+ first: Array[DataType],
+ second: Array[DataType]): Array[DataType] = {
+
+ first.zipAll(second, NullType, NullType).map { case ((a, b)) =>
+ val tpe = findTightestCommonType(a, b).getOrElse(StringType)
+ tpe match {
+ case _: NullType => StringType
+ case other => other
+ }
+ }
+ }
+
+ /**
+ * Infer type of string field. Given known type Double, and a string "1", there is no
+ * point checking if it is an Int, as the final type must be Double or higher.
+ */
+ private[csv] def inferField(
+ typeSoFar: DataType, field: String, nullValue: String = ""): DataType = {
+ if (field == null || field.isEmpty || field == nullValue) {
+ typeSoFar
+ } else {
+ typeSoFar match {
+ case NullType => tryParseInteger(field)
+ case IntegerType => tryParseInteger(field)
+ case LongType => tryParseLong(field)
+ case DoubleType => tryParseDouble(field)
+ case TimestampType => tryParseTimestamp(field)
+ case StringType => StringType
+ case other: DataType =>
+ throw new UnsupportedOperationException(s"Unexpected data type $other")
+ }
+ }
+ }
+
+ private def tryParseInteger(field: String): DataType = if ((allCatch opt field.toInt).isDefined) {
+ IntegerType
+ } else {
+ tryParseLong(field)
+ }
+
+ private def tryParseLong(field: String): DataType = if ((allCatch opt field.toLong).isDefined) {
+ LongType
+ } else {
+ tryParseDouble(field)
+ }
+
+ private def tryParseDouble(field: String): DataType = {
+ if ((allCatch opt field.toDouble).isDefined) {
+ DoubleType
+ } else {
+ tryParseTimestamp(field)
+ }
+ }
+
+ def tryParseTimestamp(field: String): DataType = {
+ if ((allCatch opt Timestamp.valueOf(field)).isDefined) {
+ TimestampType
+ } else {
+ stringType()
+ }
+ }
+
+ // Defining a function to return the StringType constant is necessary in order to work around
+ // a Scala compiler issue which leads to runtime incompatibilities with certain Spark versions;
+ // see issue #128 for more details.
+ private def stringType(): DataType = {
+ StringType
+ }
+
+ private val numericPrecedence: IndexedSeq[DataType] = HiveTypeCoercion.numericPrecedence
+
+ /**
+ * Copied from internal Spark api
+ * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]]
+ */
+ val findTightestCommonType: (DataType, DataType) => Option[DataType] = {
+ case (t1, t2) if t1 == t2 => Some(t1)
+ case (NullType, t1) => Some(t1)
+ case (t1, NullType) => Some(t1)
+
+ // Promote numeric types to the highest of the two and all numeric types to unlimited decimal
+ case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) =>
+ val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2)
+ Some(numericPrecedence(index))
+
+ case _ => None
+ }
+}
+
+object CSVTypeCast {
+
+ /**
+ * Casts given string datum to specified type.
+ * Currently we do not support complex types (ArrayType, MapType, StructType).
+ *
+ * For string types, this is simply the datum. For other nullable types, this is
+ * null if the string datum is empty or equals the configured null value.
+ *
+ * @param datum string value
+ * @param castType SparkSQL type
+ */
+ private[csv] def castTo(
+ datum: String,
+ castType: DataType,
+ nullable: Boolean = true,
+ nullValue: String = ""): Any = {
+
+ if (datum == nullValue && nullable && (!castType.isInstanceOf[StringType])) {
+ null
+ } else {
+ castType match {
+ case _: ByteType => datum.toByte
+ case _: ShortType => datum.toShort
+ case _: IntegerType => datum.toInt
+ case _: LongType => datum.toLong
+ case _: FloatType => Try(datum.toFloat)
+ .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue())
+ case _: DoubleType => Try(datum.toDouble)
+ .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue())
+ case _: BooleanType => datum.toBoolean
+ case _: DecimalType => new BigDecimal(datum.replaceAll(",", ""))
+ // TODO(hossein): would be good to support other common timestamp formats
+ case _: TimestampType => Timestamp.valueOf(datum)
+ // TODO(hossein): would be good to support other common date formats
+ case _: DateType => Date.valueOf(datum)
+ case _: StringType => datum
+ case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}")
+ }
+ }
+ }
+
+ /**
+ * Helper method that converts string representation of a character to actual character.
+ * It handles some Java escaped strings and throws exception if given string is longer than one
+ * character.
+ *
+ */
+ @throws[IllegalArgumentException]
+ private[csv] def toChar(str: String): Char = {
+ if (str.charAt(0) == '\\') {
+ str.charAt(1) match {
+ case 't' => '\t'
+ case 'r' => '\r'
+ case 'b' => '\b'
+ case 'f' => '\f'
+ case '\"' => '\"' // In case user changes quote char and uses \" as delimiter in options
+ case '\'' => '\''
+ case 'u' if str == """\u0000""" => '\u0000'
+ case _ =>
+ throw new IllegalArgumentException(s"Unsupported special character for delimiter: $str")
+ }
+ } else if (str.length == 1) {
+ str.charAt(0)
+ } else {
+ throw new IllegalArgumentException(s"Delimiter cannot be more than one character: $str")
+ }
+ }
+}
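A small sketch of how the schema inference and casting above behave. It assumes a SparkContext `sc` is in scope and that the calling code lives inside the org.apache.spark.sql package, since both objects are package-private; the token values are illustrative:

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.types._

    // Tokenized rows as produced by the CSV parser. Each row is folded with
    // inferField and partition results are combined with mergeRowTypes.
    val tokens: RDD[Array[String]] = sc.parallelize(Seq(
      Array("2012", "Tesla", "80000.65"),
      Array("1997", "Ford", "35000")))

    val schema = CSVInferSchema(tokens, Array("year", "make", "price"))
    // year: IntegerType, make: StringType, price: DoubleType

    // At parse time CSVTypeCast converts each token to its inferred type.
    CSVTypeCast.castTo("2012", IntegerType)  // 2012
    CSVTypeCast.castTo("", DoubleType)       // null (empty token, nullable non-string column)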
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala
new file mode 100644
index 0000000000..ba44121244
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.csv
+
+import java.nio.charset.Charset
+
+import org.apache.spark.Logging
+
+private[sql] case class CSVParameters(parameters: Map[String, String]) extends Logging {
+
+ private def getChar(paramName: String, default: Char): Char = {
+ val paramValue = parameters.get(paramName)
+ paramValue match {
+ case None => default
+ case Some(value) if value.length == 0 => '\0'
+ case Some(value) if value.length == 1 => value.charAt(0)
+ case _ => throw new RuntimeException(s"$paramName cannot be more than one character")
+ }
+ }
+
+ private def getBool(paramName: String, default: Boolean = false): Boolean = {
+ val param = parameters.getOrElse(paramName, default.toString)
+ if (param.toLowerCase() == "true") {
+ true
+ } else if (param.toLowerCase == "false") {
+ false
+ } else {
+ throw new Exception(s"$paramName flag can be true or false")
+ }
+ }
+
+ val delimiter = CSVTypeCast.toChar(parameters.getOrElse("delimiter", ","))
+ val parseMode = parameters.getOrElse("mode", "PERMISSIVE")
+ val charset = parameters.getOrElse("charset", Charset.forName("UTF-8").name())
+
+ val quote = getChar("quote", '\"')
+ val escape = getChar("escape", '\\')
+ val comment = getChar("comment", '\0')
+
+ val headerFlag = getBool("header")
+ val inferSchemaFlag = getBool("inferSchema")
+ val ignoreLeadingWhiteSpaceFlag = getBool("ignoreLeadingWhiteSpace")
+ val ignoreTrailingWhiteSpaceFlag = getBool("ignoreTrailingWhiteSpace")
+
+ // Limit the number of lines we'll search for a header row that isn't comment-prefixed
+ val MAX_COMMENT_LINES_IN_HEADER = 10
+
+ // Parse mode flags
+ if (!ParseModes.isValidMode(parseMode)) {
+ logWarning(s"$parseMode is not a valid parse mode. Using ${ParseModes.DEFAULT}.")
+ }
+
+ val failFast = ParseModes.isFailFastMode(parseMode)
+ val dropMalformed = ParseModes.isDropMalformedMode(parseMode)
+ val permissive = ParseModes.isPermissiveMode(parseMode)
+
+ val nullValue = parameters.getOrElse("nullValue", "")
+
+ val maxColumns = 20480
+
+ val maxCharsPerColumn = 100000
+
+ val inputBufferSize = 128
+
+ val isCommentSet = this.comment != '\0'
+
+ val rowSeparator = "\n"
+}
+
+private[csv] object ParseModes {
+
+ val PERMISSIVE_MODE = "PERMISSIVE"
+ val DROP_MALFORMED_MODE = "DROPMALFORMED"
+ val FAIL_FAST_MODE = "FAILFAST"
+
+ val DEFAULT = PERMISSIVE_MODE
+
+ def isValidMode(mode: String): Boolean = {
+ mode.toUpperCase match {
+ case PERMISSIVE_MODE | DROP_MALFORMED_MODE | FAIL_FAST_MODE => true
+ case _ => false
+ }
+ }
+
+ def isDropMalformedMode(mode: String): Boolean = mode.toUpperCase == DROP_MALFORMED_MODE
+ def isFailFastMode(mode: String): Boolean = mode.toUpperCase == FAIL_FAST_MODE
+ def isPermissiveMode(mode: String): Boolean = if (isValidMode(mode)) {
+ mode.toUpperCase == PERMISSIVE_MODE
+ } else {
+ true // We default to permissive if the mode string is not valid
+ }
+}
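A brief sketch of how user-supplied options flow into the fields above (CSVParameters is package-private, so this assumes code compiled in the same package; the option values are illustrative):

    // Options arrive as a plain Map[String, String] from the data source API.
    val params = CSVParameters(Map(
      "delimiter" -> "|",
      "quote"     -> "'",
      "header"    -> "true",
      "mode"      -> "DROPMALFORMED"))

    params.delimiter      // '|'   via CSVTypeCast.toChar
    params.quote          // '\''  via getChar
    params.headerFlag     // true  via getBool
    params.dropMalformed  // true  via ParseModes
    params.nullValue      // ""    default when the option is absent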
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
new file mode 100644
index 0000000000..ba1cc42f3e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.csv
+
+import java.io.{ByteArrayOutputStream, OutputStreamWriter, StringReader}
+
+import com.univocity.parsers.csv.{CsvParser, CsvParserSettings, CsvWriter, CsvWriterSettings}
+
+import org.apache.spark.Logging
+
+/**
+ * Read and parse CSV-like input
+ *
+ * @param params Parameters object
+ * @param headers headers for the columns
+ */
+private[sql] abstract class CsvReader(params: CSVParameters, headers: Seq[String]) {
+
+ protected lazy val parser: CsvParser = {
+ val settings = new CsvParserSettings()
+ val format = settings.getFormat
+ format.setDelimiter(params.delimiter)
+ format.setLineSeparator(params.rowSeparator)
+ format.setQuote(params.quote)
+ format.setQuoteEscape(params.escape)
+ format.setComment(params.comment)
+ settings.setIgnoreLeadingWhitespaces(params.ignoreLeadingWhiteSpaceFlag)
+ settings.setIgnoreTrailingWhitespaces(params.ignoreTrailingWhiteSpaceFlag)
+ settings.setReadInputOnSeparateThread(false)
+ settings.setInputBufferSize(params.inputBufferSize)
+ settings.setMaxColumns(params.maxColumns)
+ settings.setNullValue(params.nullValue)
+ settings.setMaxCharsPerColumn(params.maxCharsPerColumn)
+ if (headers != null) settings.setHeaders(headers: _*)
+
+ new CsvParser(settings)
+ }
+}
+
+/**
+ * Converts a sequence of strings to a CSV string
+ *
+ * @param params Parameters object for configuration
+ * @param headers headers for columns
+ */
+private[sql] class LineCsvWriter(params: CSVParameters, headers: Seq[String]) extends Logging {
+ private val writerSettings = new CsvWriterSettings
+ private val format = writerSettings.getFormat
+
+ format.setDelimiter(params.delimiter)
+ format.setLineSeparator(params.rowSeparator)
+ format.setQuote(params.quote)
+ format.setQuoteEscape(params.escape)
+ format.setComment(params.comment)
+
+ writerSettings.setNullValue(params.nullValue)
+ writerSettings.setEmptyValue(params.nullValue)
+ writerSettings.setSkipEmptyLines(true)
+ writerSettings.setQuoteAllFields(false)
+ writerSettings.setHeaders(headers: _*)
+
+ def writeRow(row: Seq[String], includeHeader: Boolean): String = {
+ val buffer = new ByteArrayOutputStream()
+ val outputWriter = new OutputStreamWriter(buffer)
+ val writer = new CsvWriter(outputWriter, writerSettings)
+
+ if (includeHeader) {
+ writer.writeHeaders()
+ }
+ writer.writeRow(row.toArray: _*)
+ writer.close()
+ buffer.toString.stripLineEnd
+ }
+}
+
+/**
+ * Parser for parsing a line at a time. Not efficient for bulk data.
+ *
+ * @param params Parameters object
+ */
+private[sql] class LineCsvReader(params: CSVParameters)
+ extends CsvReader(params, null) {
+ /**
+ * parse a line
+ *
+ * @param line a String with no newline at the end
+ * @return array of strings where each string is a field in the CSV record
+ */
+ def parseLine(line: String): Array[String] = {
+ parser.beginParsing(new StringReader(line))
+ val parsed = parser.parseNext()
+ parser.stopParsing()
+ parsed
+ }
+}
+
+/**
+ * Parser for parsing lines in bulk. Use this when efficiency is desired.
+ *
+ * @param iter iterator over lines in the file
+ * @param params Parameters object
+ * @param headers headers for the columns
+ */
+private[sql] class BulkCsvReader(
+ iter: Iterator[String],
+ params: CSVParameters,
+ headers: Seq[String])
+ extends CsvReader(params, headers) with Iterator[Array[String]] {
+
+ private val reader = new StringIteratorReader(iter)
+ parser.beginParsing(reader)
+ private var nextRecord = parser.parseNext()
+
+ /**
+ * get the next parsed line.
+ * @return array of strings where each string is a field in the CSV record
+ */
+ override def next(): Array[String] = {
+ val curRecord = nextRecord
+ if(curRecord != null) {
+ nextRecord = parser.parseNext()
+ } else {
+ throw new NoSuchElementException("next record is null")
+ }
+ curRecord
+ }
+
+ override def hasNext: Boolean = nextRecord != null
+
+}
+
+/**
+ * A Reader that "reads" from a sequence of lines. Spark's textFile method removes the newline at
+ * the end of each line, but the Univocity parser requires a Reader that provides access to the
+ * data to be parsed and needs the newlines to be present.
+ * @param iter iterator over RDD[String]
+ */
+private class StringIteratorReader(val iter: Iterator[String]) extends java.io.Reader {
+
+ private var next: Long = 0
+ private var length: Long = 0 // length of input so far
+ private var start: Long = 0
+ private var str: String = null // current string from iter
+
+ /**
+ * Fetch the next string from iter once we are done with the current one, and
+ * pretend there is a newline at the end of every string we get from iter.
+ */
+ private def refill(): Unit = {
+ if (length == next) {
+ if (iter.hasNext) {
+ str = iter.next()
+ start = length
+ length += (str.length + 1) // allowance for newline removed by SparkContext.textFile()
+ } else {
+ str = null
+ }
+ }
+ }
+
+ /**
+ * Read the next character; if at the end of the current string, pretend there is a newline.
+ */
+ override def read(): Int = {
+ refill()
+ if (next >= length) {
+ -1
+ } else {
+ val cur = next - start
+ next += 1
+ if (cur == str.length) '\n' else str.charAt(cur.toInt)
+ }
+ }
+
+ /**
+ * read from str into cbuf
+ */
+ override def read(cbuf: Array[Char], off: Int, len: Int): Int = {
+ refill()
+ var n = 0
+ if ((off < 0) || (off > cbuf.length) || (len < 0) ||
+ ((off + len) > cbuf.length) || ((off + len) < 0)) {
+ throw new IndexOutOfBoundsException()
+ } else if (len == 0) {
+ n = 0
+ } else {
+ if (next >= length) { // end of input
+ n = -1
+ } else {
+ n = Math.min(length - next, len).toInt // lesser of amount of input available or buf size
+ if (n == length - next) {
+ str.getChars((next - start).toInt, (next - start + n - 1).toInt, cbuf, off)
+ cbuf(off + n - 1) = '\n'
+ } else {
+ str.getChars((next - start).toInt, (next - start + n).toInt, cbuf, off)
+ }
+ next += n
+ if (n < len) {
+ val m = read(cbuf, off + n, len - n) // have more space, fetch more input from iter
+ if(m != -1) n += m
+ }
+ }
+ }
+
+ n
+ }
+
+ override def skip(ns: Long): Long = {
+ throw new IllegalArgumentException("Skip not implemented")
+ }
+
+ override def ready: Boolean = {
+ refill()
+ true
+ }
+
+ override def markSupported: Boolean = false
+
+ override def mark(readAheadLimit: Int): Unit = {
+ throw new IllegalArgumentException("Mark not implemented")
+ }
+
+ override def reset(): Unit = {
+ throw new IllegalArgumentException("Mark and hence reset not implemented")
+ }
+
+ override def close(): Unit = { }
+}
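A short sketch of the two readers above with default parameters (package-private access assumed; the input lines are illustrative):

    val params = CSVParameters(Map.empty)

    // One-off parsing, as used when sniffing the header line of a file.
    val fields = new LineCsvReader(params).parseLine("2012,Tesla,S")
    // => Array("2012", "Tesla", "S")

    // Bulk parsing over an iterator of lines, as produced by SparkContext.textFile;
    // StringIteratorReader re-inserts the newlines that textFile stripped.
    val lines = Iterator("2012,Tesla,S", "1997,Ford,E350")
    val rows = new BulkCsvReader(lines, params, headers = Seq("year", "make", "model"))
    rows.foreach(tokens => println(tokens.mkString("|")))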
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
new file mode 100644
index 0000000000..9267479755
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -0,0 +1,298 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.csv
+
+import java.nio.charset.Charset
+
+import scala.util.control.NonFatal
+
+import com.google.common.base.Objects
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.io.{LongWritable, NullWritable, Text}
+import org.apache.hadoop.mapred.TextInputFormat
+import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.RecordWriter
+import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
+
+import org.apache.spark.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.sources._
+import org.apache.spark.sql.types._
+
+private[csv] class CSVRelation(
+ private val inputRDD: Option[RDD[String]],
+ override val paths: Array[String],
+ private val maybeDataSchema: Option[StructType],
+ override val userDefinedPartitionColumns: Option[StructType],
+ private val parameters: Map[String, String])
+ (@transient val sqlContext: SQLContext) extends HadoopFsRelation with Serializable {
+
+ override lazy val dataSchema: StructType = maybeDataSchema match {
+ case Some(structType) => structType
+ case None => inferSchema(paths)
+ }
+
+ private val params = new CSVParameters(parameters)
+
+ @transient
+ private var cachedRDD: Option[RDD[String]] = None
+
+ private def readText(location: String): RDD[String] = {
+ if (Charset.forName(params.charset) == Charset.forName("UTF-8")) {
+ sqlContext.sparkContext.textFile(location)
+ } else {
+ sqlContext.sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](location)
+ .mapPartitions { _.map { pair =>
+ new String(pair._2.getBytes, 0, pair._2.getLength, params.charset)
+ }
+ }
+ }
+ }
+
+ private def baseRdd(inputPaths: Array[String]): RDD[String] = {
+ inputRDD.getOrElse {
+ cachedRDD.getOrElse {
+ val rdd = readText(inputPaths.mkString(","))
+ cachedRDD = Some(rdd)
+ rdd
+ }
+ }
+ }
+
+ private def tokenRdd(header: Array[String], inputPaths: Array[String]): RDD[Array[String]] = {
+ val rdd = baseRdd(inputPaths)
+ // Make sure firstLine is materialized before sending to executors
+ val firstLine = if (params.headerFlag) findFirstLine(rdd) else null
+ CSVRelation.univocityTokenizer(rdd, header, firstLine, params)
+ }
+
+ /**
+ * This eliminates unneeded columns before producing an RDD containing all of its
+ * tuples as Row objects. It reads all the tokens of each line and then drops the
+ * unneeded ones, without casting or type-checking, by mapping the indices produced
+ * by `requiredColumns` onto the indices of the tokens.
+ * TODO: Switch to using buildInternalScan
+ */
+ override def buildScan(requiredColumns: Array[String], inputs: Array[FileStatus]): RDD[Row] = {
+ val pathsString = inputs.map(_.getPath.toUri.toString)
+ val header = schema.fields.map(_.name)
+ val tokenizedRdd = tokenRdd(header, pathsString)
+ CSVRelation.parseCsv(tokenizedRdd, schema, requiredColumns, inputs, sqlContext, params)
+ }
+
+ override def prepareJobForWrite(job: Job): OutputWriterFactory = {
+ new CSVOutputWriterFactory(params)
+ }
+
+ override def hashCode(): Int = Objects.hashCode(paths.toSet, dataSchema, schema, partitionColumns)
+
+ override def equals(other: Any): Boolean = other match {
+ case that: CSVRelation => {
+ val equalPath = paths.toSet == that.paths.toSet
+ val equalDataSchema = dataSchema == that.dataSchema
+ val equalSchema = schema == that.schema
+ val equalPartitionColums = partitionColumns == that.partitionColumns
+
+ equalPath && equalDataSchema && equalSchema && equalPartitionColums
+ }
+ case _ => false
+ }
+
+ private def inferSchema(paths: Array[String]): StructType = {
+ val rdd = baseRdd(Array(paths.head))
+ val firstLine = findFirstLine(rdd)
+ val firstRow = new LineCsvReader(params).parseLine(firstLine)
+
+ val header = if (params.headerFlag) {
+ firstRow
+ } else {
+ firstRow.zipWithIndex.map { case (value, index) => s"C$index" }
+ }
+
+ val parsedRdd = tokenRdd(header, paths)
+ if (params.inferSchemaFlag) {
+ CSVInferSchema(parsedRdd, header, params.nullValue)
+ } else {
+ // By default fields are assumed to be StringType
+ val schemaFields = header.map { fieldName =>
+ StructField(fieldName.toString, StringType, nullable = true)
+ }
+ StructType(schemaFields)
+ }
+ }
+
+ /**
+ * Returns the first line of the first non-empty file in path
+ */
+ private def findFirstLine(rdd: RDD[String]): String = {
+ if (params.isCommentSet) {
+ rdd.take(params.MAX_COMMENT_LINES_IN_HEADER)
+ .find(!_.startsWith(params.comment.toString))
+ .getOrElse(sys.error(s"No uncommented header line in " +
+ s"first ${params.MAX_COMMENT_LINES_IN_HEADER} lines"))
+ } else {
+ rdd.first()
+ }
+ }
+}
+
+object CSVRelation extends Logging {
+
+ def univocityTokenizer(
+ file: RDD[String],
+ header: Seq[String],
+ firstLine: String,
+ params: CSVParameters): RDD[Array[String]] = {
+ // If header is set, make sure firstLine is materialized before sending to executors.
+ file.mapPartitionsWithIndex({
+ case (split, iter) => new BulkCsvReader(
+ if (params.headerFlag) iter.filterNot(_ == firstLine) else iter,
+ params,
+ headers = header)
+ }, true)
+ }
+
+ def parseCsv(
+ tokenizedRDD: RDD[Array[String]],
+ schema: StructType,
+ requiredColumns: Array[String],
+ inputs: Array[FileStatus],
+ sqlContext: SQLContext,
+ params: CSVParameters): RDD[Row] = {
+
+ val schemaFields = schema.fields
+ val requiredFields = StructType(requiredColumns.map(schema(_))).fields
+ val safeRequiredFields = if (params.dropMalformed) {
+ // If `dropMalformed` is enabled, then it needs to parse all the values
+ // so that we can decide which row is malformed.
+ requiredFields ++ schemaFields.filterNot(requiredFields.contains(_))
+ } else {
+ requiredFields
+ }
+ if (requiredColumns.isEmpty) {
+ sqlContext.sparkContext.emptyRDD[Row]
+ } else {
+ val safeRequiredIndices = new Array[Int](safeRequiredFields.length)
+ schemaFields.zipWithIndex.filter {
+ case (field, _) => safeRequiredFields.contains(field)
+ }.foreach {
+ case (field, index) => safeRequiredIndices(safeRequiredFields.indexOf(field)) = index
+ }
+ val rowArray = new Array[Any](safeRequiredIndices.length)
+ val requiredSize = requiredFields.length
+ tokenizedRDD.flatMap { tokens =>
+ if (params.dropMalformed && schemaFields.length != tokens.size) {
+ logWarning(s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
+ None
+ } else if (params.failFast && schemaFields.length != tokens.size) {
+ throw new RuntimeException(s"Malformed line in FAILFAST mode: " +
+ s"${tokens.mkString(params.delimiter.toString)}")
+ } else {
+ val indexSafeTokens = if (params.permissive && schemaFields.length > tokens.size) {
+ tokens ++ new Array[String](schemaFields.length - tokens.size)
+ } else if (params.permissive && schemaFields.length < tokens.size) {
+ tokens.take(schemaFields.length)
+ } else {
+ tokens
+ }
+ try {
+ var index: Int = 0
+ var subIndex: Int = 0
+ while (subIndex < safeRequiredIndices.length) {
+ index = safeRequiredIndices(subIndex)
+ val field = schemaFields(index)
+ rowArray(subIndex) = CSVTypeCast.castTo(
+ indexSafeTokens(index),
+ field.dataType,
+ field.nullable,
+ params.nullValue)
+ subIndex = subIndex + 1
+ }
+ Some(Row.fromSeq(rowArray.take(requiredSize)))
+ } catch {
+ case NonFatal(e) if params.dropMalformed =>
+ logWarning("Parse exception. " +
+ s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
+ None
+ }
+ }
+ }
+ }
+ }
+}
+
+private[sql] class CSVOutputWriterFactory(params: CSVParameters) extends OutputWriterFactory {
+ override def newInstance(
+ path: String,
+ dataSchema: StructType,
+ context: TaskAttemptContext): OutputWriter = {
+ new CsvOutputWriter(path, dataSchema, context, params)
+ }
+}
+
+private[sql] class CsvOutputWriter(
+ path: String,
+ dataSchema: StructType,
+ context: TaskAttemptContext,
+ params: CSVParameters) extends OutputWriter with Logging {
+
+ // Reusable Text object; no separator is inserted between two records
+ private[this] val text = new Text()
+
+ private val recordWriter: RecordWriter[NullWritable, Text] = {
+ new TextOutputFormat[NullWritable, Text]() {
+ override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
+ val configuration = context.getConfiguration
+ val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID")
+ val taskAttemptId = context.getTaskAttemptID
+ val split = taskAttemptId.getTaskID.getId
+ new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension")
+ }
+ }.getRecordWriter(context)
+ }
+
+ private var firstRow: Boolean = params.headerFlag
+
+ private val csvWriter = new LineCsvWriter(params, dataSchema.fieldNames.toSeq)
+
+ private def rowToString(row: Seq[Any]): Seq[String] = row.map { field =>
+ if (field != null) {
+ field.toString
+ } else {
+ params.nullValue
+ }
+ }
+
+ override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")
+
+ override protected[sql] def writeInternal(row: InternalRow): Unit = {
+ // TODO: Instead of converting and writing every row, we should use the univocity buffer
+ val resultString = csvWriter.writeRow(rowToString(row.toSeq(dataSchema)), firstRow)
+ if (firstRow) {
+ firstRow = false
+ }
+ text.set(resultString)
+ recordWriter.write(NullWritable.get(), text)
+ }
+
+ override def close(): Unit = {
+ recordWriter.close(context)
+ }
+}
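Putting the relation together: a read goes baseRdd, then tokenRdd, then CSVRelation.parseCsv inside buildScan, and a write hands each task a CsvOutputWriter through CSVOutputWriterFactory. A hedged end-to-end sketch using only the public DataFrame API (paths are illustrative):

    val df = sqlContext.read
      .format("csv")
      .option("header", "true")
      .load("/data/cars.csv")   // buildScan: tokenize lines, cast with CSVTypeCast

    df.write
      .format("csv")
      .option("header", "true") // CsvOutputWriter emits the header before the first row
      .save("/data/cars-out")   // prepareJobForWrite: one CsvOutputWriter per task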
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
new file mode 100644
index 0000000000..2fffae452c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.csv
+
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.sources._
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Provides access to CSV data from pure SQL statements.
+ */
+class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister {
+
+ override def shortName(): String = "csv"
+
+ /**
+ * Creates a new relation for data stored in CSV format, given parameters and a user-supplied schema.
+ */
+ override def createRelation(
+ sqlContext: SQLContext,
+ paths: Array[String],
+ dataSchema: Option[StructType],
+ partitionColumns: Option[StructType],
+ parameters: Map[String, String]): HadoopFsRelation = {
+
+ new CSVRelation(
+ None,
+ paths,
+ dataSchema,
+ partitionColumns,
+ parameters)(sqlContext)
+ }
+}
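Because this DefaultSource is registered under the short name "csv" (see the META-INF service entry earlier in the patch), the source is also reachable from pure SQL, as the DDL tests below exercise; a minimal sketch with an illustrative path:

    sqlContext.sql(
      """CREATE TEMPORARY TABLE carsTable
        |USING csv
        |OPTIONS (path "cars.csv", header "true")
      """.stripMargin)

    sqlContext.table("carsTable").show()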
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
index 59ba4ae2cb..44d5e4ff7e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
@@ -145,7 +145,7 @@ private[json] object InferSchema {
/**
* Convert NullType to StringType and remove StructTypes with no fields
*/
- private def canonicalizeType: DataType => Option[DataType] = {
+ private def canonicalizeType(tpe: DataType): Option[DataType] = tpe match {
case at @ ArrayType(elementType, _) =>
for {
canonicalType <- canonicalizeType(elementType)
@@ -154,15 +154,15 @@ private[json] object InferSchema {
}
case StructType(fields) =>
- val canonicalFields = for {
+ val canonicalFields: Array[StructField] = for {
field <- fields
- if field.name.nonEmpty
+ if field.name.length > 0
canonicalType <- canonicalizeType(field.dataType)
} yield {
field.copy(dataType = canonicalType)
}
- if (canonicalFields.nonEmpty) {
+ if (canonicalFields.length > 0) {
Some(StructType(canonicalFields))
} else {
// per SPARK-8093: empty structs should be deleted
@@ -217,10 +217,9 @@ private[json] object InferSchema {
(t1, t2) match {
// Double support larger range than fixed decimal, DecimalType.Maximum should be enough
// in most case, also have better precision.
- case (DoubleType, t: DecimalType) =>
- DoubleType
- case (t: DecimalType, DoubleType) =>
+ case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) =>
DoubleType
+
case (t1: DecimalType, t2: DecimalType) =>
val scale = math.max(t1.scale, t2.scale)
val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale)
diff --git a/sql/core/src/test/resources/cars-alternative.csv b/sql/core/src/test/resources/cars-alternative.csv
new file mode 100644
index 0000000000..646f7c456c
--- /dev/null
+++ b/sql/core/src/test/resources/cars-alternative.csv
@@ -0,0 +1,5 @@
+year|make|model|comment|blank
+'2012'|'Tesla'|'S'| 'No comment'|
+
+1997|Ford|E350|'Go get one now they are going fast'|
+2015|Chevy|Volt
diff --git a/sql/core/src/test/resources/cars-null.csv b/sql/core/src/test/resources/cars-null.csv
new file mode 100644
index 0000000000..130c0b40bb
--- /dev/null
+++ b/sql/core/src/test/resources/cars-null.csv
@@ -0,0 +1,6 @@
+year,make,model,comment,blank
+"2012","Tesla","S",null,
+
+1997,Ford,E350,"Go get one now they are going fast",
+null,Chevy,Volt
+
diff --git a/sql/core/src/test/resources/cars-unbalanced-quotes.csv b/sql/core/src/test/resources/cars-unbalanced-quotes.csv
new file mode 100644
index 0000000000..5ea39fcbfa
--- /dev/null
+++ b/sql/core/src/test/resources/cars-unbalanced-quotes.csv
@@ -0,0 +1,4 @@
+year,make,model,comment,blank
+"2012,Tesla,S,No comment
+1997,Ford,E350,Go get one now they are going fast"
+"2015,"Chevy",Volt,
diff --git a/sql/core/src/test/resources/cars.csv b/sql/core/src/test/resources/cars.csv
new file mode 100644
index 0000000000..2b9d74ca60
--- /dev/null
+++ b/sql/core/src/test/resources/cars.csv
@@ -0,0 +1,6 @@
+year,make,model,comment,blank
+"2012","Tesla","S","No comment",
+
+1997,Ford,E350,"Go get one now they are going fast",
+2015,Chevy,Volt
+
diff --git a/sql/core/src/test/resources/cars.tsv b/sql/core/src/test/resources/cars.tsv
new file mode 100644
index 0000000000..a7bfa9a91f
--- /dev/null
+++ b/sql/core/src/test/resources/cars.tsv
@@ -0,0 +1,4 @@
+year make model price comment blank
+2012 Tesla S "80,000.65"
+1997 Ford E350 35,000 "Go get one now they are going fast"
+2015 Chevy Volt 5,000.10
diff --git a/sql/core/src/test/resources/cars_iso-8859-1.csv b/sql/core/src/test/resources/cars_iso-8859-1.csv
new file mode 100644
index 0000000000..c51b6c5901
--- /dev/null
+++ b/sql/core/src/test/resources/cars_iso-8859-1.csv
@@ -0,0 +1,6 @@
+yearþmakeþmodelþcommentþblank
+"2012"þ"Tesla"þ"S"þ"No comment"þ
+
+1997þFordþE350þ"Go get one now they are þoing fast"þ
+2015þChevyþVolt
+
diff --git a/sql/core/src/test/resources/comments.csv b/sql/core/src/test/resources/comments.csv
new file mode 100644
index 0000000000..6275be7285
--- /dev/null
+++ b/sql/core/src/test/resources/comments.csv
@@ -0,0 +1,6 @@
+~ Version 1.0
+~ Using a non-standard comment char to test CSV parser defaults are overridden
+1,2,3,4,5.01,2015-08-20 15:57:00
+6,7,8,9,0,2015-08-21 16:58:01
+~0,9,8,7,6,2015-08-22 17:59:02
+1,2,3,4,5,2015-08-23 18:00:42
diff --git a/sql/core/src/test/resources/disable_comments.csv b/sql/core/src/test/resources/disable_comments.csv
new file mode 100644
index 0000000000..304d406e4d
--- /dev/null
+++ b/sql/core/src/test/resources/disable_comments.csv
@@ -0,0 +1,2 @@
+#1,2,3
+4,5,6
diff --git a/sql/core/src/test/resources/empty.csv b/sql/core/src/test/resources/empty.csv
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/sql/core/src/test/resources/empty.csv
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
new file mode 100644
index 0000000000..a1796f1326
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.csv
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types._
+
+class InferSchemaSuite extends SparkFunSuite {
+
+ test("String fields types are inferred correctly from null types") {
+ assert(CSVInferSchema.inferField(NullType, "") == NullType)
+ assert(CSVInferSchema.inferField(NullType, null) == NullType)
+ assert(CSVInferSchema.inferField(NullType, "100000000000") == LongType)
+ assert(CSVInferSchema.inferField(NullType, "60") == IntegerType)
+ assert(CSVInferSchema.inferField(NullType, "3.5") == DoubleType)
+ assert(CSVInferSchema.inferField(NullType, "test") == StringType)
+ assert(CSVInferSchema.inferField(NullType, "2015-08-20 15:57:00") == TimestampType)
+ }
+
+ test("String fields types are inferred correctly from other types") {
+ assert(CSVInferSchema.inferField(LongType, "1.0") == DoubleType)
+ assert(CSVInferSchema.inferField(LongType, "test") == StringType)
+ assert(CSVInferSchema.inferField(IntegerType, "1.0") == DoubleType)
+ assert(CSVInferSchema.inferField(DoubleType, null) == DoubleType)
+ assert(CSVInferSchema.inferField(DoubleType, "test") == StringType)
+ assert(CSVInferSchema.inferField(LongType, "2015-08-20 14:57:00") == TimestampType)
+ assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 15:57:00") == TimestampType)
+ }
+
+ test("Timestamp field types are inferred correctly from other types") {
+ assert(CSVInferSchema.inferField(IntegerType, "2015-08-20 14") == StringType)
+ assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 14:10") == StringType)
+ assert(CSVInferSchema.inferField(LongType, "2015-08 14:49:00") == StringType)
+ }
+
+ test("Type arrays are merged to highest common type") {
+ assert(
+ CSVInferSchema.mergeRowTypes(Array(StringType),
+ Array(DoubleType)).deep == Array(StringType).deep)
+ assert(
+ CSVInferSchema.mergeRowTypes(Array(IntegerType),
+ Array(LongType)).deep == Array(LongType).deep)
+ assert(
+ CSVInferSchema.mergeRowTypes(Array(DoubleType),
+ Array(LongType)).deep == Array(DoubleType).deep)
+ }
+
+ test("Null fields are handled properly when a nullValue is specified") {
+ assert(CSVInferSchema.inferField(NullType, "null", "null") == NullType)
+ assert(CSVInferSchema.inferField(StringType, "null", "null") == StringType)
+ assert(CSVInferSchema.inferField(LongType, "null", "null") == LongType)
+ assert(CSVInferSchema.inferField(IntegerType, "\\N", "\\N") == IntegerType)
+ assert(CSVInferSchema.inferField(DoubleType, "\\N", "\\N") == DoubleType)
+ assert(CSVInferSchema.inferField(TimestampType, "\\N", "\\N") == TimestampType)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala
new file mode 100644
index 0000000000..c0c38c6787
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.csv
+
+import org.apache.spark.SparkFunSuite
+
+/**
+ * test cases for StringIteratorReader
+ */
+class CSVParserSuite extends SparkFunSuite {
+
+ private def readAll(iter: Iterator[String]) = {
+ val reader = new StringIteratorReader(iter)
+ var c: Int = -1
+ val read = new scala.collection.mutable.StringBuilder()
+ do {
+ c = reader.read()
+ read.append(c.toChar)
+ } while (c != -1)
+
+ read.dropRight(1).toString
+ }
+
+ private def readBufAll(iter: Iterator[String], bufSize: Int) = {
+ val reader = new StringIteratorReader(iter)
+ val cbuf = new Array[Char](bufSize)
+ val read = new scala.collection.mutable.StringBuilder()
+
+ var done = false
+ do { // read all input one cbuf at a time
+ var numRead = 0
+ var n = 0
+ do { // try to fill cbuf
+ var off = 0
+ var len = cbuf.length
+ n = reader.read(cbuf, off, len)
+
+ if (n != -1) {
+ off += n
+ len -= n
+ }
+
+ assert(len >= 0 && len <= cbuf.length)
+ assert(off >= 0 && off <= cbuf.length)
+ read.appendAll(cbuf.take(n))
+ } while (n > 0)
+ if(n != -1) {
+ numRead += n
+ } else {
+ done = true
+ }
+ } while (!done)
+
+ read.toString
+ }
+
+ test("Hygiene") {
+ val reader = new StringIteratorReader(List("").toIterator)
+ assert(reader.ready === true)
+ assert(reader.markSupported === false)
+ intercept[IllegalArgumentException] { reader.skip(1) }
+ intercept[IllegalArgumentException] { reader.mark(1) }
+ intercept[IllegalArgumentException] { reader.reset() }
+ }
+
+ test("Regular case") {
+ val input = List("This is a string", "This is another string", "Small", "", "\"quoted\"")
+ val read = readAll(input.toIterator)
+ assert(read === input.mkString("\n") ++ ("\n"))
+ }
+
+ test("Empty iter") {
+ val input = List[String]()
+ val read = readAll(input.toIterator)
+ assert(read === "")
+ }
+
+ test("Embedded new line") {
+ val input = List("This is a string", "This is another string", "Small\n", "", "\"quoted\"")
+ val read = readAll(input.toIterator)
+ assert(read === input.mkString("\n") ++ ("\n"))
+ }
+
+ test("Buffer Regular case") {
+ val input = List("This is a string", "This is another string", "Small", "", "\"quoted\"")
+ val output = input.mkString("\n") ++ ("\n")
+ for(i <- 1 to output.length + 5) {
+ val read = readBufAll(input.toIterator, i)
+ assert(read === output)
+ }
+ }
+
+ test("Buffer Empty iter") {
+ val input = List[String]()
+ val output = ""
+ for(i <- 1 to output.length + 5) {
+ val read = readBufAll(input.toIterator, 1)
+ assert(read === "")
+ }
+ }
+
+ test("Buffer Embedded new line") {
+ val input = List("This is a string", "This is another string", "Small\n", "", "\"quoted\"")
+ val output = input.mkString("\n") ++ ("\n")
+ for(i <- 1 to output.length + 5) {
+ val read = readBufAll(input.toIterator, 1)
+ assert(read === output)
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
new file mode 100644
index 0000000000..8fdd31aa43
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -0,0 +1,341 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.csv
+
+import java.io.File
+import java.nio.charset.UnsupportedCharsetException
+import java.sql.Timestamp
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.{DataFrame, QueryTest, Row}
+import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
+import org.apache.spark.sql.types._
+
+class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
+ private val carsFile = "cars.csv"
+ private val carsFile8859 = "cars_iso-8859-1.csv"
+ private val carsTsvFile = "cars.tsv"
+ private val carsAltFile = "cars-alternative.csv"
+ private val carsUnbalancedQuotesFile = "cars-unbalanced-quotes.csv"
+ private val carsNullFile = "cars-null.csv"
+ private val emptyFile = "empty.csv"
+ private val commentsFile = "comments.csv"
+ private val disableCommentsFile = "disable_comments.csv"
+
+ private def testFile(fileName: String): String = {
+ Thread.currentThread().getContextClassLoader.getResource(fileName).toString
+ }
+
+ /** Verifies data and schema. */
+ private def verifyCars(
+ df: DataFrame,
+ withHeader: Boolean,
+ numCars: Int = 3,
+ numFields: Int = 5,
+ checkHeader: Boolean = true,
+ checkValues: Boolean = true,
+ checkTypes: Boolean = false): Unit = {
+
+ val numColumns = numFields
+ val numRows = if (withHeader) numCars else numCars + 1
+ // schema
+ assert(df.schema.fieldNames.length === numColumns)
+ assert(df.collect().length === numRows)
+
+ if (checkHeader) {
+ if (withHeader) {
+ assert(df.schema.fieldNames === Array("year", "make", "model", "comment", "blank"))
+ } else {
+ assert(df.schema.fieldNames === Array("C0", "C1", "C2", "C3", "C4"))
+ }
+ }
+
+ if (checkValues) {
+ val yearValues = List("2012", "1997", "2015")
+ val actualYears = if (!withHeader) "year" :: yearValues else yearValues
+ val years = if (withHeader) df.select("year").collect() else df.select("C0").collect()
+
+ years.zipWithIndex.foreach { case (year, index) =>
+ if (checkTypes) {
+ assert(year === Row(actualYears(index).toInt))
+ } else {
+ assert(year === Row(actualYears(index)))
+ }
+ }
+ }
+ }
+
+ test("simple csv test") {
+ val cars = sqlContext
+ .read
+ .format("csv")
+ .option("header", "false")
+ .load(testFile(carsFile))
+
+ verifyCars(cars, withHeader = false, checkTypes = false)
+ }
+
+ test("simple csv test with type inference") {
+ val cars = sqlContext
+ .read
+ .format("csv")
+ .option("header", "true")
+ .option("inferSchema", "true")
+ .load(testFile(carsFile))
+
+ verifyCars(cars, withHeader = true, checkTypes = true)
+ }
+
+ test("test with alternative delimiter and quote") {
+ val cars = sqlContext.read
+ .format("csv")
+ .options(Map("quote" -> "\'", "delimiter" -> "|", "header" -> "true"))
+ .load(testFile(carsAltFile))
+
+ verifyCars(cars, withHeader = true)
+ }
+
+ test("bad encoding name") {
+ val exception = intercept[UnsupportedCharsetException] {
+ sqlContext
+ .read
+ .format("csv")
+ .option("charset", "1-9588-osi")
+ .load(testFile(carsFile8859))
+ }
+
+ assert(exception.getMessage.contains("1-9588-osi"))
+ }
+
+ ignore("test different encoding") {
+ // scalastyle:off
+ sqlContext.sql(
+ s"""
+ |CREATE TEMPORARY TABLE carsTable USING csv
+ |OPTIONS (path "${testFile(carsFile8859)}", header "true",
+ |charset "iso-8859-1", delimiter "þ")
+ """.stripMargin.replaceAll("\n", " "))
+    // scalastyle:on
+
+ verifyCars(sqlContext.table("carsTable"), withHeader = true)
+ }
+
+ test("DDL test with tab separated file") {
+ sqlContext.sql(
+ s"""
+ |CREATE TEMPORARY TABLE carsTable USING csv
+ |OPTIONS (path "${testFile(carsTsvFile)}", header "true", delimiter "\t")
+ """.stripMargin.replaceAll("\n", " "))
+
+ verifyCars(sqlContext.table("carsTable"), numFields = 6, withHeader = true, checkHeader = false)
+ }
+
+ test("DDL test parsing decimal type") {
+ sqlContext.sql(
+ s"""
+ |CREATE TEMPORARY TABLE carsTable
+ |(yearMade double, makeName string, modelName string, priceTag decimal,
+ | comments string, grp string)
+ |USING csv
+ |OPTIONS (path "${testFile(carsTsvFile)}", header "true", delimiter "\t")
+ """.stripMargin.replaceAll("\n", " "))
+
+ assert(
+ sqlContext.sql("SELECT makeName FROM carsTable where priceTag > 60000").collect().size === 1)
+ }
+
+ test("test for DROPMALFORMED parsing mode") {
+ val cars = sqlContext.read
+ .format("csv")
+ .options(Map("header" -> "true", "mode" -> "dropmalformed"))
+ .load(testFile(carsFile))
+
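+    // DROPMALFORMED drops rows that fail to parse, so only the well-formed rows of cars.csv remain.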
+ assert(cars.select("year").collect().size === 2)
+ }
+
+ test("test for FAILFAST parsing mode") {
+    val exception = intercept[SparkException] {
+ sqlContext.read
+ .format("csv")
+ .options(Map("header" -> "true", "mode" -> "failfast"))
+ .load(testFile(carsFile)).collect()
+ }
+
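+    // FAILFAST surfaces the first malformed line in the exception message.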
+ assert(exception.getMessage.contains("Malformed line in FAILFAST mode: 2015,Chevy,Volt"))
+ }
+
+ test("test with null quote character") {
+ val cars = sqlContext.read
+ .format("csv")
+ .option("header", "true")
+ .option("quote", "")
+ .load(testFile(carsUnbalancedQuotesFile))
+
+ verifyCars(cars, withHeader = true, checkValues = false)
+  }
+
+ test("test with empty file and known schema") {
+ val result = sqlContext.read
+ .format("csv")
+ .schema(StructType(List(StructField("column", StringType, false))))
+ .load(testFile(emptyFile))
+
+    assert(result.collect().length === 0)
+    assert(result.schema.fieldNames.length === 1)
+ }
+
+ test("DDL test with empty file") {
+ sqlContext.sql(s"""
+ |CREATE TEMPORARY TABLE carsTable
+ |(yearMade double, makeName string, modelName string, comments string, grp string)
+ |USING csv
+ |OPTIONS (path "${testFile(emptyFile)}", header "false")
+ """.stripMargin.replaceAll("\n", " "))
+
+ assert(sqlContext.sql("SELECT count(*) FROM carsTable").collect().head(0) === 0)
+ }
+
+ test("DDL test with schema") {
+ sqlContext.sql(s"""
+ |CREATE TEMPORARY TABLE carsTable
+ |(yearMade double, makeName string, modelName string, comments string, blank string)
+ |USING csv
+ |OPTIONS (path "${testFile(carsFile)}", header "true")
+ """.stripMargin.replaceAll("\n", " "))
+
+ val cars = sqlContext.table("carsTable")
+ verifyCars(cars, withHeader = true, checkHeader = false, checkValues = false)
+ assert(
+ cars.schema.fieldNames === Array("yearMade", "makeName", "modelName", "comments", "blank"))
+ }
+
+ test("save csv") {
+ withTempDir { dir =>
+ val csvDir = new File(dir, "csv").getCanonicalPath
+ val cars = sqlContext.read
+ .format("csv")
+ .option("header", "true")
+ .load(testFile(carsFile))
+
+ cars.coalesce(1).write
+ .format("csv")
+ .option("header", "true")
+ .save(csvDir)
+
+ val carsCopy = sqlContext.read
+ .format("csv")
+ .option("header", "true")
+ .load(csvDir)
+
+ verifyCars(carsCopy, withHeader = true)
+ }
+ }
+
+ test("save csv with quote") {
+ withTempDir { dir =>
+ val csvDir = new File(dir, "csv").getCanonicalPath
+ val cars = sqlContext.read
+ .format("csv")
+ .option("header", "true")
+ .load(testFile(carsFile))
+
+ cars.coalesce(1).write
+ .format("csv")
+ .option("header", "true")
+ .option("quote", "\"")
+ .save(csvDir)
+
+ val carsCopy = sqlContext.read
+ .format("csv")
+ .option("header", "true")
+ .option("quote", "\"")
+ .load(csvDir)
+
+ verifyCars(carsCopy, withHeader = true)
+ }
+ }
+
+ test("commented lines in CSV data") {
+ val results = sqlContext.read
+ .format("csv")
+ .options(Map("comment" -> "~", "header" -> "false"))
+ .load(testFile(commentsFile))
+ .collect()
+
+ val expected =
+ Seq(Seq("1", "2", "3", "4", "5.01", "2015-08-20 15:57:00"),
+ Seq("6", "7", "8", "9", "0", "2015-08-21 16:58:01"),
+ Seq("1", "2", "3", "4", "5", "2015-08-23 18:00:42"))
+
+ assert(results.toSeq.map(_.toSeq) === expected)
+ }
+
+ test("inferring schema with commented lines in CSV data") {
+ val results = sqlContext.read
+ .format("csv")
+ .options(Map("comment" -> "~", "header" -> "false", "inferSchema" -> "true"))
+ .load(testFile(commentsFile))
+ .collect()
+
+ val expected =
+ Seq(Seq(1, 2, 3, 4, 5.01D, Timestamp.valueOf("2015-08-20 15:57:00")),
+ Seq(6, 7, 8, 9, 0, Timestamp.valueOf("2015-08-21 16:58:01")),
+ Seq(1, 2, 3, 4, 5, Timestamp.valueOf("2015-08-23 18:00:42")))
+
+ assert(results.toSeq.map(_.toSeq) === expected)
+ }
+
+ test("setting comment to null disables comment support") {
+ val results = sqlContext.read
+ .format("csv")
+ .options(Map("comment" -> "", "header" -> "false"))
+ .load(testFile(disableCommentsFile))
+ .collect()
+
+ val expected =
+ Seq(
+ Seq("#1", "2", "3"),
+ Seq("4", "5", "6"))
+
+ assert(results.toSeq.map(_.toSeq) === expected)
+ }
+
+ test("nullable fields with user defined null value of \"null\"") {
+
+ // year,make,model,comment,blank
+ val dataSchema = StructType(List(
+ StructField("year", IntegerType, nullable = true),
+ StructField("make", StringType, nullable = false),
+ StructField("model", StringType, nullable = false),
+ StructField("comment", StringType, nullable = true),
+ StructField("blank", StringType, nullable = true)))
+ val cars = sqlContext.read
+ .format("csv")
+ .schema(dataSchema)
+ .options(Map("header" -> "true", "nullValue" -> "null"))
+ .load(testFile(carsNullFile))
+
+ verifyCars(cars, withHeader = true, checkValues = false)
+ val results = cars.collect()
+ assert(results(0).toSeq === Array(2012, "Tesla", "S", "null", "null"))
+ assert(results(2).toSeq === Array(null, "Chevy", "Volt", null, null))
+ }
+
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
new file mode 100644
index 0000000000..40c5ccd0f7
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.csv
+
+import java.math.BigDecimal
+import java.sql.{Date, Timestamp}
+import java.util.Locale
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types._
+
+class CSVTypeCastSuite extends SparkFunSuite {
+
+ test("Can parse decimal type values") {
+ val stringValues = Seq("10.05", "1,000.01", "158,058,049.001")
+ val decimalValues = Seq(10.05, 1000.01, 158058049.001)
+ val decimalType = new DecimalType()
+
+ stringValues.zip(decimalValues).foreach { case (strVal, decimalVal) =>
+ assert(CSVTypeCast.castTo(strVal, decimalType) === new BigDecimal(decimalVal.toString))
+ }
+ }
+
+ test("Can parse escaped characters") {
+ assert(CSVTypeCast.toChar("""\t""") === '\t')
+ assert(CSVTypeCast.toChar("""\r""") === '\r')
+ assert(CSVTypeCast.toChar("""\b""") === '\b')
+ assert(CSVTypeCast.toChar("""\f""") === '\f')
+ assert(CSVTypeCast.toChar("""\"""") === '\"')
+ assert(CSVTypeCast.toChar("""\'""") === '\'')
+ assert(CSVTypeCast.toChar("""\u0000""") === '\u0000')
+ }
+
+ test("Does not accept delimiter larger than one character") {
+    val exception = intercept[IllegalArgumentException] {
+ CSVTypeCast.toChar("ab")
+ }
+ assert(exception.getMessage.contains("cannot be more than one character"))
+ }
+
+ test("Throws exception for unsupported escaped characters") {
+    val exception = intercept[IllegalArgumentException] {
+ CSVTypeCast.toChar("""\1""")
+ }
+ assert(exception.getMessage.contains("Unsupported special character for delimiter"))
+ }
+
+ test("Nullable types are handled") {
+ assert(CSVTypeCast.castTo("", IntegerType, nullable = true) == null)
+ }
+
+ test("String type should always return the same as the input") {
+ assert(CSVTypeCast.castTo("", StringType, nullable = true) == "")
+ assert(CSVTypeCast.castTo("", StringType, nullable = false) == "")
+ }
+
+ test("Throws exception for empty string with non null type") {
+    val exception = intercept[NumberFormatException] {
+ CSVTypeCast.castTo("", IntegerType, nullable = false)
+ }
+ assert(exception.getMessage.contains("For input string: \"\""))
+ }
+
+ test("Types are cast correctly") {
+ assert(CSVTypeCast.castTo("10", ByteType) == 10)
+ assert(CSVTypeCast.castTo("10", ShortType) == 10)
+ assert(CSVTypeCast.castTo("10", IntegerType) == 10)
+ assert(CSVTypeCast.castTo("10", LongType) == 10)
+ assert(CSVTypeCast.castTo("1.00", FloatType) == 1.0)
+ assert(CSVTypeCast.castTo("1.00", DoubleType) == 1.0)
+ assert(CSVTypeCast.castTo("true", BooleanType) == true)
+ val timestamp = "2015-01-01 00:00:00"
+ assert(CSVTypeCast.castTo(timestamp, TimestampType) == Timestamp.valueOf(timestamp))
+ assert(CSVTypeCast.castTo("2015-01-01", DateType) == Date.valueOf("2015-01-01"))
+ }
+
+ test("Float and Double Types are cast correctly with Locale") {
+    val originalLocale = Locale.getDefault
+    Locale.setDefault(new Locale("fr", "FR"))
+    try {
+      assert(CSVTypeCast.castTo("1,00", FloatType) == 1.0)
+      assert(CSVTypeCast.castTo("1,00", DoubleType) == 1.0)
+    } finally {
+      Locale.setDefault(originalLocale)
+    }
+ }
+}