From 455129020ca7f6a162f6f2486a87cc43512cfd2c Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Wed, 8 Mar 2017 13:43:09 -0800
Subject: [SPARK-15463][SQL] Add an API to load DataFrame from Dataset[String] storing CSV

## What changes were proposed in this pull request?

This PR proposes to add an API that loads a `DataFrame` from a `Dataset[String]` storing CSV.

It allows pre-processing the data before it is parsed as CSV, which enables workarounds for many narrow cases, for example, as below:

- Case 1 - pre-processing

  ```scala
  val df = spark.read.text("...")
  // Pre-processing with this.
  spark.read.csv(df.as[String])
  ```

- Case 2 - use other input formats

  ```scala
  val rdd = spark.sparkContext.newAPIHadoopFile("/file.csv.lzo",
    classOf[com.hadoop.mapreduce.LzoTextInputFormat],
    classOf[org.apache.hadoop.io.LongWritable],
    classOf[org.apache.hadoop.io.Text])
  val stringRdd = rdd.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength))

  spark.read.csv(stringRdd.toDS)
  ```

## How was this patch tested?

Added tests in `CSVSuite` and built with Scala 2.10:

```
./dev/change-scala-version.sh 2.10
./build/mvn -Pyarn -Phadoop-2.4 -Dscala-2.10 -DskipTests clean package
```

Author: hyukjinkwon

Closes #16854 from HyukjinKwon/SPARK-15463.
---
 .../org/apache/spark/sql/DataFrameReader.scala     | 71 +++++++++++++++++++---
 .../execution/datasources/csv/CSVDataSource.scala  | 49 +++++++++------
 .../sql/execution/datasources/csv/CSVOptions.scala |  2 +-
 .../datasources/csv/UnivocityParser.scala          |  2 +-
 4 files changed, 94 insertions(+), 30 deletions(-)

(limited to 'sql/core/src/main/scala/org/apache')

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 41470ae6aa..a5e38e25b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -29,6 +29,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
 import org.apache.spark.sql.execution.LogicalRDD
 import org.apache.spark.sql.execution.command.DDLUtils
+import org.apache.spark.sql.execution.datasources.csv._
 import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.datasources.jdbc._
 import org.apache.spark.sql.execution.datasources.json.JsonInferSchema
@@ -368,14 +369,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
         createParser)
     }
 
-    // Check a field requirement for corrupt records here to throw an exception in a driver side
-    schema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
-      val f = schema(corruptFieldIndex)
-      if (f.dataType != StringType || !f.nullable) {
-        throw new AnalysisException(
-          "The field for corrupt records must be string type and nullable")
-      }
-    }
+    verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
 
     val parsed = jsonDataset.rdd.mapPartitions { iter =>
       val parser = new JacksonParser(schema, parsedOptions)
@@ -398,6 +392,51 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
     csv(Seq(path): _*)
   }
 
+  /**
+   * Loads a `Dataset[String]` storing CSV rows and returns the result as a `DataFrame`.
+   *
+   * If the schema is not specified using `schema` function and `inferSchema` option is enabled,
+   * this function goes through the input once to determine the input schema.
+   *
+   * If the schema is not specified using `schema` function and `inferSchema` option is disabled,
+   * it determines the columns as string types and it reads only the first line to determine the
+   * names and the number of fields.
+   *
+   * @param csvDataset input Dataset with one CSV row per record
+   * @since 2.2.0
+   */
+  def csv(csvDataset: Dataset[String]): DataFrame = {
+    val parsedOptions: CSVOptions = new CSVOptions(
+      extraOptions.toMap,
+      sparkSession.sessionState.conf.sessionLocalTimeZone)
+    val filteredLines: Dataset[String] =
+      CSVUtils.filterCommentAndEmpty(csvDataset, parsedOptions)
+    val maybeFirstLine: Option[String] = filteredLines.take(1).headOption
+
+    val schema = userSpecifiedSchema.getOrElse {
+      TextInputCSVDataSource.inferFromDataset(
+        sparkSession,
+        csvDataset,
+        maybeFirstLine,
+        parsedOptions)
+    }
+
+    verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
+
+    val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine =>
+      filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions))
+    }.getOrElse(filteredLines.rdd)
+
+    val parsed = linesWithoutHeader.mapPartitions { iter =>
+      val parser = new UnivocityParser(schema, parsedOptions)
+      iter.flatMap(line => parser.parse(line))
+    }
+
+    Dataset.ofRows(
+      sparkSession,
+      LogicalRDD(schema.toAttributes, parsed)(sparkSession))
+  }
+
   /**
    * Loads a CSV file and returns the result as a `DataFrame`.
    *
@@ -604,6 +643,22 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
     }
   }
 
+  /**
+   * A convenient function for schema validation in datasources supporting
+   * `columnNameOfCorruptRecord` as an option.
+   */
+  private def verifyColumnNameOfCorruptRecord(
+      schema: StructType,
+      columnNameOfCorruptRecord: String): Unit = {
+    schema.getFieldIndex(columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
+      val f = schema(corruptFieldIndex)
+      if (f.dataType != StringType || !f.nullable) {
+        throw new AnalysisException(
+          "The field for corrupt records must be string type and nullable")
+      }
+    }
+  }
+
   ///////////////////////////////////////////////////////////////////////////////////////
   // Builder pattern config options
   ///////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index 47567032b0..35ff924f27 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -17,12 +17,11 @@
 
 package org.apache.spark.sql.execution.datasources.csv
 
-import java.io.InputStream
 import java.nio.charset.{Charset, StandardCharsets}
 
-import com.univocity.parsers.csv.{CsvParser, CsvParserSettings}
+import com.univocity.parsers.csv.CsvParser
 import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.fs.FileStatus
 import org.apache.hadoop.io.{LongWritable, Text}
 import org.apache.hadoop.mapred.TextInputFormat
 import org.apache.hadoop.mapreduce.Job
@@ -134,23 +133,33 @@ object TextInputCSVDataSource extends CSVDataSource {
       inputPaths: Seq[FileStatus],
       parsedOptions: CSVOptions): Option[StructType] = {
     val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions)
-    CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption match {
-      case Some(firstLine) =>
-        val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine)
-        val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
-        val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
-        val tokenRDD = csv.rdd.mapPartitions { iter =>
-          val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions)
-          val linesWithoutHeader =
-            CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions)
-          val parser = new CsvParser(parsedOptions.asParserSettings)
-          linesWithoutHeader.map(parser.parseLine)
-        }
-        Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions))
-      case None =>
-        // If the first line could not be read, just return the empty schema.
-        Some(StructType(Nil))
-    }
+    val maybeFirstLine = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption
+    Some(inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions))
+  }
+
+  /**
+   * Infers the schema from `Dataset` that stores CSV string records.
+   */
+  def inferFromDataset(
+      sparkSession: SparkSession,
+      csv: Dataset[String],
+      maybeFirstLine: Option[String],
+      parsedOptions: CSVOptions): StructType = maybeFirstLine match {
+    case Some(firstLine) =>
+      val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine)
+      val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+      val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
+      val tokenRDD = csv.rdd.mapPartitions { iter =>
+        val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions)
+        val linesWithoutHeader =
+          CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions)
+        val parser = new CsvParser(parsedOptions.asParserSettings)
+        linesWithoutHeader.map(parser.parseLine)
+      }
+      CSVInferSchema.infer(tokenRDD, header, parsedOptions)
+    case None =>
+      // If the first line could not be read, just return the empty schema.
+      StructType(Nil)
   }
 
   private def createBaseDataset(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index 50503385ad..0b1e5dac2d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -26,7 +26,7 @@ import org.apache.commons.lang3.time.FastDateFormat
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs, ParseModes}
 
-private[csv] class CSVOptions(
+class CSVOptions(
     @transient private val parameters: CaseInsensitiveMap[String],
     defaultTimeZoneId: String,
     defaultColumnNameOfCorruptRecord: String)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
index 3b3b87e435..e42ea3fa39 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
-private[csv] class UnivocityParser(
+class UnivocityParser(
     schema: StructType,
     requiredSchema: StructType,
     private val options: CSVOptions) extends Logging {
-- 
cgit v1.2.3
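
For a concrete picture of how the new overload is meant to be used, here is a minimal usage sketch in Scala. The sample data, option values, and session setup are illustrative assumptions only; just the `csv(Dataset[String])` entry point and its schema-inference behavior come from the patch above.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object CsvFromDatasetSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("csv-from-dataset-sketch")  // hypothetical app name
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // CSV lines that could come from any pre-processing step
    // (e.g. spark.read.text(...).as[String] or an RDD converted with toDS).
    val csvLines = Seq("id,name", "1,Alice", "2,Bob").toDS()

    // With inferSchema enabled, the reader passes over the input once to determine column types.
    val inferred = spark.read
      .option("header", "true")
      .option("inferSchema", "true")
      .csv(csvLines)
    inferred.printSchema()

    // With an explicit schema, no extra pass over the data is needed to infer types.
    val schema = StructType(Seq(
      StructField("id", IntegerType, nullable = true),
      StructField("name", StringType, nullable = true)))
    val withSchema = spark.read
      .schema(schema)
      .option("header", "true")
      .csv(csvLines)
    withSchema.show()

    spark.stop()
  }
}
```

Combined with Case 1 and Case 2 in the description, this is the path that lets arbitrary text sources (LZO files, pre-filtered text, and so on) be parsed by the CSV reader without first writing them back out as files.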