| | | |
|---|---|---|
| author | Hossein <hossein@databricks.com> | 2016-01-15 11:46:46 -0800 |
| committer | Reynold Xin <rxin@databricks.com> | 2016-01-15 11:46:46 -0800 |
| commit | 5f83c6991c95616ecbc2878f8860c69b2826f56c (patch) | |
| tree | 86dc70e45f1b27b67efec9724632a108d69f2ef0 /sql | |
| parent | c5e7076da72657ea35a0aa388f8d2e6411d39280 (diff) | |
| download | spark-5f83c6991c95616ecbc2878f8860c69b2826f56c.tar.gz spark-5f83c6991c95616ecbc2878f8860c69b2826f56c.tar.bz2 spark-5f83c6991c95616ecbc2878f8860c69b2826f56c.zip | |
[SPARK-12833][SQL] Initial import of spark-csv
CSV is the most common data format in the "small data" world. It is often the first format people want to try when they first run Spark on a single node. Having to rely on a third-party package for this leads to a poor experience for new users. This PR merges the popular spark-csv data source package (https://github.com/databricks/spark-csv) into Spark SQL; a minimal usage sketch follows below.
This is a first PR to bring the functionality into the Spark 2.0 master branch. We will complete the items outlined in the design document (see the JIRA attachment) in follow-up pull requests.
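A minimal sketch of how the new built-in source is exercised, mirroring the options used by the tests added in this patch (`header`, `inferSchema`); the input path and the local SparkContext/SQLContext setup are illustrative assumptions, not part of the patch:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Local setup assumed for illustration only.
val sc = new SparkContext(new SparkConf().setAppName("csv-example").setMaster("local[*]"))
val sqlContext = new SQLContext(sc)

// Read a CSV file that has a header row and let the source infer column types,
// as done in the "simple csv test with type inference" test below.
val cars = sqlContext.read
  .format("csv")
  .option("header", "true")
  .option("inferSchema", "true")
  .load("cars.csv") // hypothetical path

cars.printSchema()
cars.show()
```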
Author: Hossein <hossein@databricks.com>
Author: Reynold Xin <rxin@databricks.com>
Closes #10766 from rxin/csv.
Diffstat (limited to 'sql')
21 files changed, 1610 insertions, 7 deletions
diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 6db7a8a2dc..31b364f351 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -37,6 +37,12 @@ <dependencies> <dependency> + <groupId>com.univocity</groupId> + <artifactId>univocity-parsers</artifactId> + <version>1.5.6</version> + <type>jar</type> + </dependency> + <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-core_${scala.binary.version}</artifactId> <version>${project.version}</version> diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 1ca2044057..226d59d0ea 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1,3 +1,4 @@ +org.apache.spark.sql.execution.datasources.csv.DefaultSource org.apache.spark.sql.execution.datasources.jdbc.DefaultSource org.apache.spark.sql.execution.datasources.json.DefaultSource org.apache.spark.sql.execution.datasources.parquet.DefaultSource diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala new file mode 100644 index 0000000000..0aa4539e60 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import java.math.BigDecimal +import java.sql.{Date, Timestamp} +import java.text.NumberFormat +import java.util.Locale + +import scala.util.control.Exception._ +import scala.util.Try + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion +import org.apache.spark.sql.types._ + + +private[sql] object CSVInferSchema { + + /** + * Similar to the JSON schema inference + * 1. Infer type of each row + * 2. Merge row types to find common type + * 3. Replace any null types with string type + * TODO(hossein): Can we reuse JSON schema inference? 
[SPARK-12670] + */ + def apply( + tokenRdd: RDD[Array[String]], + header: Array[String], + nullValue: String = ""): StructType = { + + val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType) + val rootTypes: Array[DataType] = + tokenRdd.aggregate(startType)(inferRowType(nullValue), mergeRowTypes) + + val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) => + StructField(thisHeader, rootType, nullable = true) + } + + StructType(structFields) + } + + private def inferRowType(nullValue: String) + (rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = { + var i = 0 + while (i < math.min(rowSoFar.length, next.length)) { // May have columns on right missing. + rowSoFar(i) = inferField(rowSoFar(i), next(i), nullValue) + i+=1 + } + rowSoFar + } + + private[csv] def mergeRowTypes( + first: Array[DataType], + second: Array[DataType]): Array[DataType] = { + + first.zipAll(second, NullType, NullType).map { case ((a, b)) => + val tpe = findTightestCommonType(a, b).getOrElse(StringType) + tpe match { + case _: NullType => StringType + case other => other + } + } + } + + /** + * Infer type of string field. Given known type Double, and a string "1", there is no + * point checking if it is an Int, as the final type must be Double or higher. + */ + private[csv] def inferField( + typeSoFar: DataType, field: String, nullValue: String = ""): DataType = { + if (field == null || field.isEmpty || field == nullValue) { + typeSoFar + } else { + typeSoFar match { + case NullType => tryParseInteger(field) + case IntegerType => tryParseInteger(field) + case LongType => tryParseLong(field) + case DoubleType => tryParseDouble(field) + case TimestampType => tryParseTimestamp(field) + case StringType => StringType + case other: DataType => + throw new UnsupportedOperationException(s"Unexpected data type $other") + } + } + } + + private def tryParseInteger(field: String): DataType = if ((allCatch opt field.toInt).isDefined) { + IntegerType + } else { + tryParseLong(field) + } + + private def tryParseLong(field: String): DataType = if ((allCatch opt field.toLong).isDefined) { + LongType + } else { + tryParseDouble(field) + } + + private def tryParseDouble(field: String): DataType = { + if ((allCatch opt field.toDouble).isDefined) { + DoubleType + } else { + tryParseTimestamp(field) + } + } + + def tryParseTimestamp(field: String): DataType = { + if ((allCatch opt Timestamp.valueOf(field)).isDefined) { + TimestampType + } else { + stringType() + } + } + + // Defining a function to return the StringType constant is necessary in order to work around + // a Scala compiler issue which leads to runtime incompatibilities with certain Spark versions; + // see issue #128 for more details. 
+ private def stringType(): DataType = { + StringType + } + + private val numericPrecedence: IndexedSeq[DataType] = HiveTypeCoercion.numericPrecedence + + /** + * Copied from internal Spark api + * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]] + */ + val findTightestCommonType: (DataType, DataType) => Option[DataType] = { + case (t1, t2) if t1 == t2 => Some(t1) + case (NullType, t1) => Some(t1) + case (t1, NullType) => Some(t1) + + // Promote numeric types to the highest of the two and all numeric types to unlimited decimal + case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) => + val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2) + Some(numericPrecedence(index)) + + case _ => None + } +} + +object CSVTypeCast { + + /** + * Casts given string datum to specified type. + * Currently we do not support complex types (ArrayType, MapType, StructType). + * + * For string types, this is simply the datum. For other types. + * For other nullable types, this is null if the string datum is empty. + * + * @param datum string value + * @param castType SparkSQL type + */ + private[csv] def castTo( + datum: String, + castType: DataType, + nullable: Boolean = true, + nullValue: String = ""): Any = { + + if (datum == nullValue && nullable && (!castType.isInstanceOf[StringType])) { + null + } else { + castType match { + case _: ByteType => datum.toByte + case _: ShortType => datum.toShort + case _: IntegerType => datum.toInt + case _: LongType => datum.toLong + case _: FloatType => Try(datum.toFloat) + .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue()) + case _: DoubleType => Try(datum.toDouble) + .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue()) + case _: BooleanType => datum.toBoolean + case _: DecimalType => new BigDecimal(datum.replaceAll(",", "")) + // TODO(hossein): would be good to support other common timestamp formats + case _: TimestampType => Timestamp.valueOf(datum) + // TODO(hossein): would be good to support other common date formats + case _: DateType => Date.valueOf(datum) + case _: StringType => datum + case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}") + } + } + } + + /** + * Helper method that converts string representation of a character to actual character. + * It handles some Java escaped strings and throws exception if given string is longer than one + * character. 
+ * + */ + @throws[IllegalArgumentException] + private[csv] def toChar(str: String): Char = { + if (str.charAt(0) == '\\') { + str.charAt(1) + match { + case 't' => '\t' + case 'r' => '\r' + case 'b' => '\b' + case 'f' => '\f' + case '\"' => '\"' // In case user changes quote char and uses \" as delimiter in options + case '\'' => '\'' + case 'u' if str == """\u0000""" => '\u0000' + case _ => + throw new IllegalArgumentException(s"Unsupported special character for delimiter: $str") + } + } else if (str.length == 1) { + str.charAt(0) + } else { + throw new IllegalArgumentException(s"Delimiter cannot be more than one character: $str") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala new file mode 100644 index 0000000000..ba44121244 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.csv + +import java.nio.charset.Charset + +import org.apache.spark.Logging + +private[sql] case class CSVParameters(parameters: Map[String, String]) extends Logging { + + private def getChar(paramName: String, default: Char): Char = { + val paramValue = parameters.get(paramName) + paramValue match { + case None => default + case Some(value) if value.length == 0 => '\0' + case Some(value) if value.length == 1 => value.charAt(0) + case _ => throw new RuntimeException(s"$paramName cannot be more than one character") + } + } + + private def getBool(paramName: String, default: Boolean = false): Boolean = { + val param = parameters.getOrElse(paramName, default.toString) + if (param.toLowerCase() == "true") { + true + } else if (param.toLowerCase == "false") { + false + } else { + throw new Exception(s"$paramName flag can be true or false") + } + } + + val delimiter = CSVTypeCast.toChar(parameters.getOrElse("delimiter", ",")) + val parseMode = parameters.getOrElse("mode", "PERMISSIVE") + val charset = parameters.getOrElse("charset", Charset.forName("UTF-8").name()) + + val quote = getChar("quote", '\"') + val escape = getChar("escape", '\\') + val comment = getChar("comment", '\0') + + val headerFlag = getBool("header") + val inferSchemaFlag = getBool("inferSchema") + val ignoreLeadingWhiteSpaceFlag = getBool("ignoreLeadingWhiteSpace") + val ignoreTrailingWhiteSpaceFlag = getBool("ignoreTrailingWhiteSpace") + + // Limit the number of lines we'll search for a header row that isn't comment-prefixed + val MAX_COMMENT_LINES_IN_HEADER = 10 + + // Parse mode flags + if (!ParseModes.isValidMode(parseMode)) { + logWarning(s"$parseMode is not a valid parse mode. Using ${ParseModes.DEFAULT}.") + } + + val failFast = ParseModes.isFailFastMode(parseMode) + val dropMalformed = ParseModes.isDropMalformedMode(parseMode) + val permissive = ParseModes.isPermissiveMode(parseMode) + + val nullValue = parameters.getOrElse("nullValue", "") + + val maxColumns = 20480 + + val maxCharsPerColumn = 100000 + + val inputBufferSize = 128 + + val isCommentSet = this.comment != '\0' + + val rowSeparator = "\n" +} + +private[csv] object ParseModes { + + val PERMISSIVE_MODE = "PERMISSIVE" + val DROP_MALFORMED_MODE = "DROPMALFORMED" + val FAIL_FAST_MODE = "FAILFAST" + + val DEFAULT = PERMISSIVE_MODE + + def isValidMode(mode: String): Boolean = { + mode.toUpperCase match { + case PERMISSIVE_MODE | DROP_MALFORMED_MODE | FAIL_FAST_MODE => true + case _ => false + } + } + + def isDropMalformedMode(mode: String): Boolean = mode.toUpperCase == DROP_MALFORMED_MODE + def isFailFastMode(mode: String): Boolean = mode.toUpperCase == FAIL_FAST_MODE + def isPermissiveMode(mode: String): Boolean = if (isValidMode(mode)) { + mode.toUpperCase == PERMISSIVE_MODE + } else { + true // We default to permissive is the mode string is not valid + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala new file mode 100644 index 0000000000..ba1cc42f3e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import java.io.{ByteArrayOutputStream, OutputStreamWriter, StringReader} + +import com.univocity.parsers.csv.{CsvParser, CsvParserSettings, CsvWriter, CsvWriterSettings} + +import org.apache.spark.Logging + +/** + * Read and parse CSV-like input + * + * @param params Parameters object + * @param headers headers for the columns + */ +private[sql] abstract class CsvReader(params: CSVParameters, headers: Seq[String]) { + + protected lazy val parser: CsvParser = { + val settings = new CsvParserSettings() + val format = settings.getFormat + format.setDelimiter(params.delimiter) + format.setLineSeparator(params.rowSeparator) + format.setQuote(params.quote) + format.setQuoteEscape(params.escape) + format.setComment(params.comment) + settings.setIgnoreLeadingWhitespaces(params.ignoreLeadingWhiteSpaceFlag) + settings.setIgnoreTrailingWhitespaces(params.ignoreTrailingWhiteSpaceFlag) + settings.setReadInputOnSeparateThread(false) + settings.setInputBufferSize(params.inputBufferSize) + settings.setMaxColumns(params.maxColumns) + settings.setNullValue(params.nullValue) + settings.setMaxCharsPerColumn(params.maxCharsPerColumn) + if (headers != null) settings.setHeaders(headers: _*) + + new CsvParser(settings) + } +} + +/** + * Converts a sequence of string to CSV string + * + * @param params Parameters object for configuration + * @param headers headers for columns + */ +private[sql] class LineCsvWriter(params: CSVParameters, headers: Seq[String]) extends Logging { + private val writerSettings = new CsvWriterSettings + private val format = writerSettings.getFormat + + format.setDelimiter(params.delimiter) + format.setLineSeparator(params.rowSeparator) + format.setQuote(params.quote) + format.setQuoteEscape(params.escape) + format.setComment(params.comment) + + writerSettings.setNullValue(params.nullValue) + writerSettings.setEmptyValue(params.nullValue) + writerSettings.setSkipEmptyLines(true) + writerSettings.setQuoteAllFields(false) + writerSettings.setHeaders(headers: _*) + + def writeRow(row: Seq[String], includeHeader: Boolean): String = { + val buffer = new ByteArrayOutputStream() + val outputWriter = new OutputStreamWriter(buffer) + val writer = new CsvWriter(outputWriter, writerSettings) + + if (includeHeader) { + writer.writeHeaders() + } + writer.writeRow(row.toArray: _*) + writer.close() + buffer.toString.stripLineEnd + } +} + +/** + * Parser for parsing a line at a time. Not efficient for bulk data. 
+ * + * @param params Parameters object + */ +private[sql] class LineCsvReader(params: CSVParameters) + extends CsvReader(params, null) { + /** + * parse a line + * + * @param line a String with no newline at the end + * @return array of strings where each string is a field in the CSV record + */ + def parseLine(line: String): Array[String] = { + parser.beginParsing(new StringReader(line)) + val parsed = parser.parseNext() + parser.stopParsing() + parsed + } +} + +/** + * Parser for parsing lines in bulk. Use this when efficiency is desired. + * + * @param iter iterator over lines in the file + * @param params Parameters object + * @param headers headers for the columns + */ +private[sql] class BulkCsvReader( + iter: Iterator[String], + params: CSVParameters, + headers: Seq[String]) + extends CsvReader(params, headers) with Iterator[Array[String]] { + + private val reader = new StringIteratorReader(iter) + parser.beginParsing(reader) + private var nextRecord = parser.parseNext() + + /** + * get the next parsed line. + * @return array of strings where each string is a field in the CSV record + */ + override def next(): Array[String] = { + val curRecord = nextRecord + if(curRecord != null) { + nextRecord = parser.parseNext() + } else { + throw new NoSuchElementException("next record is null") + } + curRecord + } + + override def hasNext: Boolean = nextRecord != null + +} + +/** + * A Reader that "reads" from a sequence of lines. Spark's textFile method removes newlines at + * end of each line Univocity parser requires a Reader that provides access to the data to be + * parsed and needs the newlines to be present + * @param iter iterator over RDD[String] + */ +private class StringIteratorReader(val iter: Iterator[String]) extends java.io.Reader { + + private var next: Long = 0 + private var length: Long = 0 // length of input so far + private var start: Long = 0 + private var str: String = null // current string from iter + + /** + * fetch next string from iter, if done with current one + * pretend there is a new line at the end of every string we get from from iter + */ + private def refill(): Unit = { + if (length == next) { + if (iter.hasNext) { + str = iter.next() + start = length + length += (str.length + 1) // allowance for newline removed by SparkContext.textFile() + } else { + str = null + } + } + } + + /** + * read the next character, if at end of string pretend there is a new line + */ + override def read(): Int = { + refill() + if (next >= length) { + -1 + } else { + val cur = next - start + next += 1 + if (cur == str.length) '\n' else str.charAt(cur.toInt) + } + } + + /** + * read from str into cbuf + */ + override def read(cbuf: Array[Char], off: Int, len: Int): Int = { + refill() + var n = 0 + if ((off < 0) || (off > cbuf.length) || (len < 0) || + ((off + len) > cbuf.length) || ((off + len) < 0)) { + throw new IndexOutOfBoundsException() + } else if (len == 0) { + n = 0 + } else { + if (next >= length) { // end of input + n = -1 + } else { + n = Math.min(length - next, len).toInt // lesser of amount of input available or buf size + if (n == length - next) { + str.getChars((next - start).toInt, (next - start + n - 1).toInt, cbuf, off) + cbuf(off + n - 1) = '\n' + } else { + str.getChars((next - start).toInt, (next - start + n).toInt, cbuf, off) + } + next += n + if (n < len) { + val m = read(cbuf, off + n, len - n) // have more space, fetch more input from iter + if(m != -1) n += m + } + } + } + + n + } + + override def skip(ns: Long): Long = { + throw new 
IllegalArgumentException("Skip not implemented") + } + + override def ready: Boolean = { + refill() + true + } + + override def markSupported: Boolean = false + + override def mark(readAheadLimit: Int): Unit = { + throw new IllegalArgumentException("Mark not implemented") + } + + override def reset(): Unit = { + throw new IllegalArgumentException("Mark and hence reset not implemented") + } + + override def close(): Unit = { } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala new file mode 100644 index 0000000000..9267479755 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import java.nio.charset.Charset + +import scala.util.control.NonFatal + +import com.google.common.base.Objects +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.io.{LongWritable, NullWritable, Text} +import org.apache.hadoop.mapred.TextInputFormat +import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} +import org.apache.hadoop.mapreduce.RecordWriter +import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat + +import org.apache.spark.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ + +private[csv] class CSVRelation( + private val inputRDD: Option[RDD[String]], + override val paths: Array[String], + private val maybeDataSchema: Option[StructType], + override val userDefinedPartitionColumns: Option[StructType], + private val parameters: Map[String, String]) + (@transient val sqlContext: SQLContext) extends HadoopFsRelation with Serializable { + + override lazy val dataSchema: StructType = maybeDataSchema match { + case Some(structType) => structType + case None => inferSchema(paths) + } + + private val params = new CSVParameters(parameters) + + @transient + private var cachedRDD: Option[RDD[String]] = None + + private def readText(location: String): RDD[String] = { + if (Charset.forName(params.charset) == Charset.forName("UTF-8")) { + sqlContext.sparkContext.textFile(location) + } else { + sqlContext.sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](location) + .mapPartitions { _.map { pair => + new String(pair._2.getBytes, 0, pair._2.getLength, params.charset) + } + } + } + } + + private def baseRdd(inputPaths: Array[String]): RDD[String] = { + inputRDD.getOrElse { + cachedRDD.getOrElse { + val rdd = readText(inputPaths.mkString(",")) + cachedRDD = 
Some(rdd) + rdd + } + } + } + + private def tokenRdd(header: Array[String], inputPaths: Array[String]): RDD[Array[String]] = { + val rdd = baseRdd(inputPaths) + // Make sure firstLine is materialized before sending to executors + val firstLine = if (params.headerFlag) findFirstLine(rdd) else null + CSVRelation.univocityTokenizer(rdd, header, firstLine, params) + } + + /** + * This supports to eliminate unneeded columns before producing an RDD + * containing all of its tuples as Row objects. This reads all the tokens of each line + * and then drop unneeded tokens without casting and type-checking by mapping + * both the indices produced by `requiredColumns` and the ones of tokens. + * TODO: Switch to using buildInternalScan + */ + override def buildScan(requiredColumns: Array[String], inputs: Array[FileStatus]): RDD[Row] = { + val pathsString = inputs.map(_.getPath.toUri.toString) + val header = schema.fields.map(_.name) + val tokenizedRdd = tokenRdd(header, pathsString) + CSVRelation.parseCsv(tokenizedRdd, schema, requiredColumns, inputs, sqlContext, params) + } + + override def prepareJobForWrite(job: Job): OutputWriterFactory = { + new CSVOutputWriterFactory(params) + } + + override def hashCode(): Int = Objects.hashCode(paths.toSet, dataSchema, schema, partitionColumns) + + override def equals(other: Any): Boolean = other match { + case that: CSVRelation => { + val equalPath = paths.toSet == that.paths.toSet + val equalDataSchema = dataSchema == that.dataSchema + val equalSchema = schema == that.schema + val equalPartitionColums = partitionColumns == that.partitionColumns + + equalPath && equalDataSchema && equalSchema && equalPartitionColums + } + case _ => false + } + + private def inferSchema(paths: Array[String]): StructType = { + val rdd = baseRdd(Array(paths.head)) + val firstLine = findFirstLine(rdd) + val firstRow = new LineCsvReader(params).parseLine(firstLine) + + val header = if (params.headerFlag) { + firstRow + } else { + firstRow.zipWithIndex.map { case (value, index) => s"C$index" } + } + + val parsedRdd = tokenRdd(header, paths) + if (params.inferSchemaFlag) { + CSVInferSchema(parsedRdd, header, params.nullValue) + } else { + // By default fields are assumed to be StringType + val schemaFields = header.map { fieldName => + StructField(fieldName.toString, StringType, nullable = true) + } + StructType(schemaFields) + } + } + + /** + * Returns the first line of the first non-empty file in path + */ + private def findFirstLine(rdd: RDD[String]): String = { + if (params.isCommentSet) { + rdd.take(params.MAX_COMMENT_LINES_IN_HEADER) + .find(!_.startsWith(params.comment.toString)) + .getOrElse(sys.error(s"No uncommented header line in " + + s"first ${params.MAX_COMMENT_LINES_IN_HEADER} lines")) + } else { + rdd.first() + } + } +} + +object CSVRelation extends Logging { + + def univocityTokenizer( + file: RDD[String], + header: Seq[String], + firstLine: String, + params: CSVParameters): RDD[Array[String]] = { + // If header is set, make sure firstLine is materialized before sending to executors. 
+ file.mapPartitionsWithIndex({ + case (split, iter) => new BulkCsvReader( + if (params.headerFlag) iter.filterNot(_ == firstLine) else iter, + params, + headers = header) + }, true) + } + + def parseCsv( + tokenizedRDD: RDD[Array[String]], + schema: StructType, + requiredColumns: Array[String], + inputs: Array[FileStatus], + sqlContext: SQLContext, + params: CSVParameters): RDD[Row] = { + + val schemaFields = schema.fields + val requiredFields = StructType(requiredColumns.map(schema(_))).fields + val safeRequiredFields = if (params.dropMalformed) { + // If `dropMalformed` is enabled, then it needs to parse all the values + // so that we can decide which row is malformed. + requiredFields ++ schemaFields.filterNot(requiredFields.contains(_)) + } else { + requiredFields + } + if (requiredColumns.isEmpty) { + sqlContext.sparkContext.emptyRDD[Row] + } else { + val safeRequiredIndices = new Array[Int](safeRequiredFields.length) + schemaFields.zipWithIndex.filter { + case (field, _) => safeRequiredFields.contains(field) + }.foreach { + case (field, index) => safeRequiredIndices(safeRequiredFields.indexOf(field)) = index + } + val rowArray = new Array[Any](safeRequiredIndices.length) + val requiredSize = requiredFields.length + tokenizedRDD.flatMap { tokens => + if (params.dropMalformed && schemaFields.length != tokens.size) { + logWarning(s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}") + None + } else if (params.failFast && schemaFields.length != tokens.size) { + throw new RuntimeException(s"Malformed line in FAILFAST mode: " + + s"${tokens.mkString(params.delimiter.toString)}") + } else { + val indexSafeTokens = if (params.permissive && schemaFields.length > tokens.size) { + tokens ++ new Array[String](schemaFields.length - tokens.size) + } else if (params.permissive && schemaFields.length < tokens.size) { + tokens.take(schemaFields.length) + } else { + tokens + } + try { + var index: Int = 0 + var subIndex: Int = 0 + while (subIndex < safeRequiredIndices.length) { + index = safeRequiredIndices(subIndex) + val field = schemaFields(index) + rowArray(subIndex) = CSVTypeCast.castTo( + indexSafeTokens(index), + field.dataType, + field.nullable, + params.nullValue) + subIndex = subIndex + 1 + } + Some(Row.fromSeq(rowArray.take(requiredSize))) + } catch { + case NonFatal(e) if params.dropMalformed => + logWarning("Parse exception. 
" + + s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}") + None + } + } + } + } + } +} + +private[sql] class CSVOutputWriterFactory(params: CSVParameters) extends OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new CsvOutputWriter(path, dataSchema, context, params) + } +} + +private[sql] class CsvOutputWriter( + path: String, + dataSchema: StructType, + context: TaskAttemptContext, + params: CSVParameters) extends OutputWriter with Logging { + + // create the Generator without separator inserted between 2 records + private[this] val text = new Text() + + private val recordWriter: RecordWriter[NullWritable, Text] = { + new TextOutputFormat[NullWritable, Text]() { + override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + val configuration = context.getConfiguration + val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") + val taskAttemptId = context.getTaskAttemptID + val split = taskAttemptId.getTaskID.getId + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") + } + }.getRecordWriter(context) + } + + private var firstRow: Boolean = params.headerFlag + + private val csvWriter = new LineCsvWriter(params, dataSchema.fieldNames.toSeq) + + private def rowToString(row: Seq[Any]): Seq[String] = row.map { field => + if (field != null) { + field.toString + } else { + params.nullValue + } + } + + override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") + + override protected[sql] def writeInternal(row: InternalRow): Unit = { + // TODO: Instead of converting and writing every row, we should use the univocity buffer + val resultString = csvWriter.writeRow(rowToString(row.toSeq(dataSchema)), firstRow) + if (firstRow) { + firstRow = false + } + text.set(resultString) + recordWriter.write(NullWritable.get(), text) + } + + override def close(): Unit = { + recordWriter.close(context) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala new file mode 100644 index 0000000000..2fffae452c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.StructType + +/** + * Provides access to CSV data from pure SQL statements. 
+ */ +class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { + + override def shortName(): String = "csv" + + /** + * Creates a new relation for data store in CSV given parameters and user supported schema. + */ + override def createRelation( + sqlContext: SQLContext, + paths: Array[String], + dataSchema: Option[StructType], + partitionColumns: Option[StructType], + parameters: Map[String, String]): HadoopFsRelation = { + + new CSVRelation( + None, + paths, + dataSchema, + partitionColumns, + parameters)(sqlContext) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala index 59ba4ae2cb..44d5e4ff7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala @@ -145,7 +145,7 @@ private[json] object InferSchema { /** * Convert NullType to StringType and remove StructTypes with no fields */ - private def canonicalizeType: DataType => Option[DataType] = { + private def canonicalizeType(tpe: DataType): Option[DataType] = tpe match { case at @ ArrayType(elementType, _) => for { canonicalType <- canonicalizeType(elementType) @@ -154,15 +154,15 @@ private[json] object InferSchema { } case StructType(fields) => - val canonicalFields = for { + val canonicalFields: Array[StructField] = for { field <- fields - if field.name.nonEmpty + if field.name.length > 0 canonicalType <- canonicalizeType(field.dataType) } yield { field.copy(dataType = canonicalType) } - if (canonicalFields.nonEmpty) { + if (canonicalFields.length > 0) { Some(StructType(canonicalFields)) } else { // per SPARK-8093: empty structs should be deleted @@ -217,10 +217,9 @@ private[json] object InferSchema { (t1, t2) match { // Double support larger range than fixed decimal, DecimalType.Maximum should be enough // in most case, also have better precision. 
- case (DoubleType, t: DecimalType) => - DoubleType - case (t: DecimalType, DoubleType) => + case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) => DoubleType + case (t1: DecimalType, t2: DecimalType) => val scale = math.max(t1.scale, t2.scale) val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale) diff --git a/sql/core/src/test/resources/cars-alternative.csv b/sql/core/src/test/resources/cars-alternative.csv new file mode 100644 index 0000000000..646f7c456c --- /dev/null +++ b/sql/core/src/test/resources/cars-alternative.csv @@ -0,0 +1,5 @@ +year|make|model|comment|blank +'2012'|'Tesla'|'S'| 'No comment'| + +1997|Ford|E350|'Go get one now they are going fast'| +2015|Chevy|Volt diff --git a/sql/core/src/test/resources/cars-null.csv b/sql/core/src/test/resources/cars-null.csv new file mode 100644 index 0000000000..130c0b40bb --- /dev/null +++ b/sql/core/src/test/resources/cars-null.csv @@ -0,0 +1,6 @@ +year,make,model,comment,blank +"2012","Tesla","S",null, + +1997,Ford,E350,"Go get one now they are going fast", +null,Chevy,Volt + diff --git a/sql/core/src/test/resources/cars-unbalanced-quotes.csv b/sql/core/src/test/resources/cars-unbalanced-quotes.csv new file mode 100644 index 0000000000..5ea39fcbfa --- /dev/null +++ b/sql/core/src/test/resources/cars-unbalanced-quotes.csv @@ -0,0 +1,4 @@ +year,make,model,comment,blank +"2012,Tesla,S,No comment +1997,Ford,E350,Go get one now they are going fast" +"2015,"Chevy",Volt, diff --git a/sql/core/src/test/resources/cars.csv b/sql/core/src/test/resources/cars.csv new file mode 100644 index 0000000000..2b9d74ca60 --- /dev/null +++ b/sql/core/src/test/resources/cars.csv @@ -0,0 +1,6 @@ +year,make,model,comment,blank +"2012","Tesla","S","No comment", + +1997,Ford,E350,"Go get one now they are going fast", +2015,Chevy,Volt + diff --git a/sql/core/src/test/resources/cars.tsv b/sql/core/src/test/resources/cars.tsv new file mode 100644 index 0000000000..a7bfa9a91f --- /dev/null +++ b/sql/core/src/test/resources/cars.tsv @@ -0,0 +1,4 @@ +year make model price comment blank +2012 Tesla S "80,000.65" +1997 Ford E350 35,000 "Go get one now they are going fast" +2015 Chevy Volt 5,000.10 diff --git a/sql/core/src/test/resources/cars_iso-8859-1.csv b/sql/core/src/test/resources/cars_iso-8859-1.csv new file mode 100644 index 0000000000..c51b6c5901 --- /dev/null +++ b/sql/core/src/test/resources/cars_iso-8859-1.csv @@ -0,0 +1,6 @@ +yearþmakeþmodelþcommentþblank +"2012"þ"Tesla"þ"S"þ"No comment"þ + +1997þFordþE350þ"Go get one now they are þoing fast"þ +2015þChevyþVolt + diff --git a/sql/core/src/test/resources/comments.csv b/sql/core/src/test/resources/comments.csv new file mode 100644 index 0000000000..6275be7285 --- /dev/null +++ b/sql/core/src/test/resources/comments.csv @@ -0,0 +1,6 @@ +~ Version 1.0 +~ Using a non-standard comment char to test CSV parser defaults are overridden +1,2,3,4,5.01,2015-08-20 15:57:00 +6,7,8,9,0,2015-08-21 16:58:01 +~0,9,8,7,6,2015-08-22 17:59:02 +1,2,3,4,5,2015-08-23 18:00:42 diff --git a/sql/core/src/test/resources/disable_comments.csv b/sql/core/src/test/resources/disable_comments.csv new file mode 100644 index 0000000000..304d406e4d --- /dev/null +++ b/sql/core/src/test/resources/disable_comments.csv @@ -0,0 +1,2 @@ +#1,2,3 +4,5,6 diff --git a/sql/core/src/test/resources/empty.csv b/sql/core/src/test/resources/empty.csv new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/sql/core/src/test/resources/empty.csv diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala new file mode 100644 index 0000000000..a1796f1326 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ + +class InferSchemaSuite extends SparkFunSuite { + + test("String fields types are inferred correctly from null types") { + assert(CSVInferSchema.inferField(NullType, "") == NullType) + assert(CSVInferSchema.inferField(NullType, null) == NullType) + assert(CSVInferSchema.inferField(NullType, "100000000000") == LongType) + assert(CSVInferSchema.inferField(NullType, "60") == IntegerType) + assert(CSVInferSchema.inferField(NullType, "3.5") == DoubleType) + assert(CSVInferSchema.inferField(NullType, "test") == StringType) + assert(CSVInferSchema.inferField(NullType, "2015-08-20 15:57:00") == TimestampType) + } + + test("String fields types are inferred correctly from other types") { + assert(CSVInferSchema.inferField(LongType, "1.0") == DoubleType) + assert(CSVInferSchema.inferField(LongType, "test") == StringType) + assert(CSVInferSchema.inferField(IntegerType, "1.0") == DoubleType) + assert(CSVInferSchema.inferField(DoubleType, null) == DoubleType) + assert(CSVInferSchema.inferField(DoubleType, "test") == StringType) + assert(CSVInferSchema.inferField(LongType, "2015-08-20 14:57:00") == TimestampType) + assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 15:57:00") == TimestampType) + } + + test("Timestamp field types are inferred correctly from other types") { + assert(CSVInferSchema.inferField(IntegerType, "2015-08-20 14") == StringType) + assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 14:10") == StringType) + assert(CSVInferSchema.inferField(LongType, "2015-08 14:49:00") == StringType) + } + + test("Type arrays are merged to highest common type") { + assert( + CSVInferSchema.mergeRowTypes(Array(StringType), + Array(DoubleType)).deep == Array(StringType).deep) + assert( + CSVInferSchema.mergeRowTypes(Array(IntegerType), + Array(LongType)).deep == Array(LongType).deep) + assert( + CSVInferSchema.mergeRowTypes(Array(DoubleType), + Array(LongType)).deep == Array(DoubleType).deep) + } + + test("Null fields are handled properly when a nullValue is specified") { + assert(CSVInferSchema.inferField(NullType, "null", "null") == NullType) + assert(CSVInferSchema.inferField(StringType, "null", "null") == StringType) + assert(CSVInferSchema.inferField(LongType, "null", "null") == 
LongType) + assert(CSVInferSchema.inferField(IntegerType, "\\N", "\\N") == IntegerType) + assert(CSVInferSchema.inferField(DoubleType, "\\N", "\\N") == DoubleType) + assert(CSVInferSchema.inferField(TimestampType, "\\N", "\\N") == TimestampType) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala new file mode 100644 index 0000000000..c0c38c6787 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import org.apache.spark.SparkFunSuite + +/** + * test cases for StringIteratorReader + */ +class CSVParserSuite extends SparkFunSuite { + + private def readAll(iter: Iterator[String]) = { + val reader = new StringIteratorReader(iter) + var c: Int = -1 + val read = new scala.collection.mutable.StringBuilder() + do { + c = reader.read() + read.append(c.toChar) + } while (c != -1) + + read.dropRight(1).toString + } + + private def readBufAll(iter: Iterator[String], bufSize: Int) = { + val reader = new StringIteratorReader(iter) + val cbuf = new Array[Char](bufSize) + val read = new scala.collection.mutable.StringBuilder() + + var done = false + do { // read all input one cbuf at a time + var numRead = 0 + var n = 0 + do { // try to fill cbuf + var off = 0 + var len = cbuf.length + n = reader.read(cbuf, off, len) + + if (n != -1) { + off += n + len -= n + } + + assert(len >= 0 && len <= cbuf.length) + assert(off >= 0 && off <= cbuf.length) + read.appendAll(cbuf.take(n)) + } while (n > 0) + if(n != -1) { + numRead += n + } else { + done = true + } + } while (!done) + + read.toString + } + + test("Hygiene") { + val reader = new StringIteratorReader(List("").toIterator) + assert(reader.ready === true) + assert(reader.markSupported === false) + intercept[IllegalArgumentException] { reader.skip(1) } + intercept[IllegalArgumentException] { reader.mark(1) } + intercept[IllegalArgumentException] { reader.reset() } + } + + test("Regular case") { + val input = List("This is a string", "This is another string", "Small", "", "\"quoted\"") + val read = readAll(input.toIterator) + assert(read === input.mkString("\n") ++ ("\n")) + } + + test("Empty iter") { + val input = List[String]() + val read = readAll(input.toIterator) + assert(read === "") + } + + test("Embedded new line") { + val input = List("This is a string", "This is another string", "Small\n", "", "\"quoted\"") + val read = readAll(input.toIterator) + assert(read === input.mkString("\n") ++ ("\n")) + } + + test("Buffer Regular case") { + val input = List("This is a 
string", "This is another string", "Small", "", "\"quoted\"") + val output = input.mkString("\n") ++ ("\n") + for(i <- 1 to output.length + 5) { + val read = readBufAll(input.toIterator, i) + assert(read === output) + } + } + + test("Buffer Empty iter") { + val input = List[String]() + val output = "" + for(i <- 1 to output.length + 5) { + val read = readBufAll(input.toIterator, 1) + assert(read === "") + } + } + + test("Buffer Embedded new line") { + val input = List("This is a string", "This is another string", "Small\n", "", "\"quoted\"") + val output = input.mkString("\n") ++ ("\n") + for(i <- 1 to output.length + 5) { + val read = readBufAll(input.toIterator, 1) + assert(read === output) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala new file mode 100644 index 0000000000..8fdd31aa43 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -0,0 +1,341 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import java.io.File +import java.nio.charset.UnsupportedCharsetException +import java.sql.Timestamp + +import org.apache.spark.SparkException +import org.apache.spark.sql.{DataFrame, QueryTest, Row} +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} +import org.apache.spark.sql.types._ + +class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { + private val carsFile = "cars.csv" + private val carsFile8859 = "cars_iso-8859-1.csv" + private val carsTsvFile = "cars.tsv" + private val carsAltFile = "cars-alternative.csv" + private val carsUnbalancedQuotesFile = "cars-unbalanced-quotes.csv" + private val carsNullFile = "cars-null.csv" + private val emptyFile = "empty.csv" + private val commentsFile = "comments.csv" + private val disableCommentsFile = "disable_comments.csv" + + private def testFile(fileName: String): String = { + Thread.currentThread().getContextClassLoader.getResource(fileName).toString + } + + /** Verifies data and schema. 
*/ + private def verifyCars( + df: DataFrame, + withHeader: Boolean, + numCars: Int = 3, + numFields: Int = 5, + checkHeader: Boolean = true, + checkValues: Boolean = true, + checkTypes: Boolean = false): Unit = { + + val numColumns = numFields + val numRows = if (withHeader) numCars else numCars + 1 + // schema + assert(df.schema.fieldNames.length === numColumns) + assert(df.collect().length === numRows) + + if (checkHeader) { + if (withHeader) { + assert(df.schema.fieldNames === Array("year", "make", "model", "comment", "blank")) + } else { + assert(df.schema.fieldNames === Array("C0", "C1", "C2", "C3", "C4")) + } + } + + if (checkValues) { + val yearValues = List("2012", "1997", "2015") + val actualYears = if (!withHeader) "year" :: yearValues else yearValues + val years = if (withHeader) df.select("year").collect() else df.select("C0").collect() + + years.zipWithIndex.foreach { case (year, index) => + if (checkTypes) { + assert(year === Row(actualYears(index).toInt)) + } else { + assert(year === Row(actualYears(index))) + } + } + } + } + + test("simple csv test") { + val cars = sqlContext + .read + .format("csv") + .option("header", "false") + .load(testFile(carsFile)) + + verifyCars(cars, withHeader = false, checkTypes = false) + } + + test("simple csv test with type inference") { + val cars = sqlContext + .read + .format("csv") + .option("header", "true") + .option("inferSchema", "true") + .load(testFile(carsFile)) + + verifyCars(cars, withHeader = true, checkTypes = true) + } + + test("test with alternative delimiter and quote") { + val cars = sqlContext.read + .format("csv") + .options(Map("quote" -> "\'", "delimiter" -> "|", "header" -> "true")) + .load(testFile(carsAltFile)) + + verifyCars(cars, withHeader = true) + } + + test("bad encoding name") { + val exception = intercept[UnsupportedCharsetException] { + sqlContext + .read + .format("csv") + .option("charset", "1-9588-osi") + .load(testFile(carsFile8859)) + } + + assert(exception.getMessage.contains("1-9588-osi")) + } + + ignore("test different encoding") { + // scalastyle:off + sqlContext.sql( + s""" + |CREATE TEMPORARY TABLE carsTable USING csv + |OPTIONS (path "${testFile(carsFile8859)}", header "true", + |charset "iso-8859-1", delimiter "þ") + """.stripMargin.replaceAll("\n", " ")) + //scalstyle:on + + verifyCars(sqlContext.table("carsTable"), withHeader = true) + } + + test("DDL test with tab separated file") { + sqlContext.sql( + s""" + |CREATE TEMPORARY TABLE carsTable USING csv + |OPTIONS (path "${testFile(carsTsvFile)}", header "true", delimiter "\t") + """.stripMargin.replaceAll("\n", " ")) + + verifyCars(sqlContext.table("carsTable"), numFields = 6, withHeader = true, checkHeader = false) + } + + test("DDL test parsing decimal type") { + sqlContext.sql( + s""" + |CREATE TEMPORARY TABLE carsTable + |(yearMade double, makeName string, modelName string, priceTag decimal, + | comments string, grp string) + |USING csv + |OPTIONS (path "${testFile(carsTsvFile)}", header "true", delimiter "\t") + """.stripMargin.replaceAll("\n", " ")) + + assert( + sqlContext.sql("SELECT makeName FROM carsTable where priceTag > 60000").collect().size === 1) + } + + test("test for DROPMALFORMED parsing mode") { + val cars = sqlContext.read + .format("csv") + .options(Map("header" -> "true", "mode" -> "dropmalformed")) + .load(testFile(carsFile)) + + assert(cars.select("year").collect().size === 2) + } + + test("test for FAILFAST parsing mode") { + val exception = intercept[SparkException]{ + sqlContext.read + .format("csv") + 
.options(Map("header" -> "true", "mode" -> "failfast")) + .load(testFile(carsFile)).collect() + } + + assert(exception.getMessage.contains("Malformed line in FAILFAST mode: 2015,Chevy,Volt")) + } + + test("test with null quote character") { + val cars = sqlContext.read + .format("csv") + .option("header", "true") + .option("quote", "") + .load(testFile(carsUnbalancedQuotesFile)) + + verifyCars(cars, withHeader = true, checkValues = false) + + } + + test("test with empty file and known schema") { + val result = sqlContext.read + .format("csv") + .schema(StructType(List(StructField("column", StringType, false)))) + .load(testFile(emptyFile)) + + assert(result.collect.size === 0) + assert(result.schema.fieldNames.size === 1) + } + + + test("DDL test with empty file") { + sqlContext.sql(s""" + |CREATE TEMPORARY TABLE carsTable + |(yearMade double, makeName string, modelName string, comments string, grp string) + |USING csv + |OPTIONS (path "${testFile(emptyFile)}", header "false") + """.stripMargin.replaceAll("\n", " ")) + + assert(sqlContext.sql("SELECT count(*) FROM carsTable").collect().head(0) === 0) + } + + test("DDL test with schema") { + sqlContext.sql(s""" + |CREATE TEMPORARY TABLE carsTable + |(yearMade double, makeName string, modelName string, comments string, blank string) + |USING csv + |OPTIONS (path "${testFile(carsFile)}", header "true") + """.stripMargin.replaceAll("\n", " ")) + + val cars = sqlContext.table("carsTable") + verifyCars(cars, withHeader = true, checkHeader = false, checkValues = false) + assert( + cars.schema.fieldNames === Array("yearMade", "makeName", "modelName", "comments", "blank")) + } + + test("save csv") { + withTempDir { dir => + val csvDir = new File(dir, "csv").getCanonicalPath + val cars = sqlContext.read + .format("csv") + .option("header", "true") + .load(testFile(carsFile)) + + cars.coalesce(1).write + .format("csv") + .option("header", "true") + .save(csvDir) + + val carsCopy = sqlContext.read + .format("csv") + .option("header", "true") + .load(csvDir) + + verifyCars(carsCopy, withHeader = true) + } + } + + test("save csv with quote") { + withTempDir { dir => + val csvDir = new File(dir, "csv").getCanonicalPath + val cars = sqlContext.read + .format("csv") + .option("header", "true") + .load(testFile(carsFile)) + + cars.coalesce(1).write + .format("csv") + .option("header", "true") + .option("quote", "\"") + .save(csvDir) + + val carsCopy = sqlContext.read + .format("csv") + .option("header", "true") + .option("quote", "\"") + .load(csvDir) + + verifyCars(carsCopy, withHeader = true) + } + } + + test("commented lines in CSV data") { + val results = sqlContext.read + .format("csv") + .options(Map("comment" -> "~", "header" -> "false")) + .load(testFile(commentsFile)) + .collect() + + val expected = + Seq(Seq("1", "2", "3", "4", "5.01", "2015-08-20 15:57:00"), + Seq("6", "7", "8", "9", "0", "2015-08-21 16:58:01"), + Seq("1", "2", "3", "4", "5", "2015-08-23 18:00:42")) + + assert(results.toSeq.map(_.toSeq) === expected) + } + + test("inferring schema with commented lines in CSV data") { + val results = sqlContext.read + .format("csv") + .options(Map("comment" -> "~", "header" -> "false", "inferSchema" -> "true")) + .load(testFile(commentsFile)) + .collect() + + val expected = + Seq(Seq(1, 2, 3, 4, 5.01D, Timestamp.valueOf("2015-08-20 15:57:00")), + Seq(6, 7, 8, 9, 0, Timestamp.valueOf("2015-08-21 16:58:01")), + Seq(1, 2, 3, 4, 5, Timestamp.valueOf("2015-08-23 18:00:42"))) + + assert(results.toSeq.map(_.toSeq) === expected) + } + + test("setting 
comment to null disables comment support") { + val results = sqlContext.read + .format("csv") + .options(Map("comment" -> "", "header" -> "false")) + .load(testFile(disableCommentsFile)) + .collect() + + val expected = + Seq( + Seq("#1", "2", "3"), + Seq("4", "5", "6")) + + assert(results.toSeq.map(_.toSeq) === expected) + } + + test("nullable fields with user defined null value of \"null\"") { + + // year,make,model,comment,blank + val dataSchema = StructType(List( + StructField("year", IntegerType, nullable = true), + StructField("make", StringType, nullable = false), + StructField("model", StringType, nullable = false), + StructField("comment", StringType, nullable = true), + StructField("blank", StringType, nullable = true))) + val cars = sqlContext.read + .format("csv") + .schema(dataSchema) + .options(Map("header" -> "true", "nullValue" -> "null")) + .load(testFile(carsNullFile)) + + verifyCars(cars, withHeader = true, checkValues = false) + val results = cars.collect() + assert(results(0).toSeq === Array(2012, "Tesla", "S", "null", "null")) + assert(results(2).toSeq === Array(null, "Chevy", "Volt", null, null)) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala new file mode 100644 index 0000000000..40c5ccd0f7 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.csv + +import java.math.BigDecimal +import java.sql.{Date, Timestamp} +import java.util.Locale + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ + +class CSVTypeCastSuite extends SparkFunSuite { + + test("Can parse decimal type values") { + val stringValues = Seq("10.05", "1,000.01", "158,058,049.001") + val decimalValues = Seq(10.05, 1000.01, 158058049.001) + val decimalType = new DecimalType() + + stringValues.zip(decimalValues).foreach { case (strVal, decimalVal) => + assert(CSVTypeCast.castTo(strVal, decimalType) === new BigDecimal(decimalVal.toString)) + } + } + + test("Can parse escaped characters") { + assert(CSVTypeCast.toChar("""\t""") === '\t') + assert(CSVTypeCast.toChar("""\r""") === '\r') + assert(CSVTypeCast.toChar("""\b""") === '\b') + assert(CSVTypeCast.toChar("""\f""") === '\f') + assert(CSVTypeCast.toChar("""\"""") === '\"') + assert(CSVTypeCast.toChar("""\'""") === '\'') + assert(CSVTypeCast.toChar("""\u0000""") === '\u0000') + } + + test("Does not accept delimiter larger than one character") { + val exception = intercept[IllegalArgumentException]{ + CSVTypeCast.toChar("ab") + } + assert(exception.getMessage.contains("cannot be more than one character")) + } + + test("Throws exception for unsupported escaped characters") { + val exception = intercept[IllegalArgumentException]{ + CSVTypeCast.toChar("""\1""") + } + assert(exception.getMessage.contains("Unsupported special character for delimiter")) + } + + test("Nullable types are handled") { + assert(CSVTypeCast.castTo("", IntegerType, nullable = true) == null) + } + + test("String type should always return the same as the input") { + assert(CSVTypeCast.castTo("", StringType, nullable = true) == "") + assert(CSVTypeCast.castTo("", StringType, nullable = false) == "") + } + + test("Throws exception for empty string with non null type") { + val exception = intercept[NumberFormatException]{ + CSVTypeCast.castTo("", IntegerType, nullable = false) + } + assert(exception.getMessage.contains("For input string: \"\"")) + } + + test("Types are cast correctly") { + assert(CSVTypeCast.castTo("10", ByteType) == 10) + assert(CSVTypeCast.castTo("10", ShortType) == 10) + assert(CSVTypeCast.castTo("10", IntegerType) == 10) + assert(CSVTypeCast.castTo("10", LongType) == 10) + assert(CSVTypeCast.castTo("1.00", FloatType) == 1.0) + assert(CSVTypeCast.castTo("1.00", DoubleType) == 1.0) + assert(CSVTypeCast.castTo("true", BooleanType) == true) + val timestamp = "2015-01-01 00:00:00" + assert(CSVTypeCast.castTo(timestamp, TimestampType) == Timestamp.valueOf(timestamp)) + assert(CSVTypeCast.castTo("2015-01-01", DateType) == Date.valueOf("2015-01-01")) + } + + test("Float and Double Types are cast correctly with Locale") { + val locale : Locale = new Locale("fr", "FR") + Locale.setDefault(locale) + assert(CSVTypeCast.castTo("1,00", FloatType) == 1.0) + assert(CSVTypeCast.castTo("1,00", DoubleType) == 1.0) + } +} |