path: root/sql/core/src/main/scala/org/apache
author     hyukjinkwon <gurwls223@gmail.com>     2017-03-08 13:43:09 -0800
committer  Wenchen Fan <wenchen@databricks.com>  2017-03-08 13:43:09 -0800
commit     455129020ca7f6a162f6f2486a87cc43512cfd2c (patch)
tree       83262a1811ff2665ff03613eb7e2ad1de356a6c2 /sql/core/src/main/scala/org/apache
parent     6570cfd7abe349dc6d2151f2ac9dc662e7465a79 (diff)
[SPARK-15463][SQL] Add an API to load DataFrame from Dataset[String] storing CSV
## What changes were proposed in this pull request?

This PR proposes to add an API that loads a `DataFrame` from a `Dataset[String]` storing CSV. It allows pre-processing before the data is parsed as CSV, which enables workarounds for many narrow cases, for example:

- Case 1 - pre-processing

  ```scala
  val df = spark.read.text("...")
  // Pre-processing with this.
  spark.read.csv(df.as[String])
  ```

- Case 2 - use other input formats

  ```scala
  val rdd = spark.sparkContext.newAPIHadoopFile("/file.csv.lzo",
    classOf[com.hadoop.mapreduce.LzoTextInputFormat],
    classOf[org.apache.hadoop.io.LongWritable],
    classOf[org.apache.hadoop.io.Text])
  val stringRdd = rdd.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength))

  spark.read.csv(stringRdd.toDS)
  ```

## How was this patch tested?

Added tests in `CSVSuite` and built with Scala 2.10:

```
./dev/change-scala-version.sh 2.10
./build/mvn -Pyarn -Phadoop-2.4 -Dscala-2.10 -DskipTests clean package
```

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #16854 from HyukjinKwon/SPARK-15463.
Diffstat (limited to 'sql/core/src/main/scala/org/apache')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala                            71
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala    49
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala        2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala   2
4 files changed, 94 insertions(+), 30 deletions(-)
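Before the diff itself, here is a minimal usage sketch of the `Dataset[String]` overload this patch adds. It is a hedged illustration only: the input path, the intermediate dataset names, and the `header`/`inferSchema` options are assumptions chosen to match the behaviour described in the new Scaladoc, not code taken from the patch.

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("csv-from-dataset-sketch").getOrCreate()
import spark.implicits._

// Read raw text, pre-process the lines, then hand them to the new CSV overload.
val lines = spark.read.text("/tmp/example.csv").as[String]
val cleaned = lines.filter(_.trim.nonEmpty)

val df = spark.read
  .option("header", "true")       // first line supplies the column names
  .option("inferSchema", "true")  // extra pass over the data to infer column types
  .csv(cleaned)                   // the Dataset[String] overload added below
```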
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 41470ae6aa..a5e38e25b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -29,6 +29,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.execution.command.DDLUtils
+import org.apache.spark.sql.execution.datasources.csv._
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.datasources.jdbc._
import org.apache.spark.sql.execution.datasources.json.JsonInferSchema
@@ -368,14 +369,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
createParser)
}
- // Check a field requirement for corrupt records here to throw an exception in a driver side
- schema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
- val f = schema(corruptFieldIndex)
- if (f.dataType != StringType || !f.nullable) {
- throw new AnalysisException(
- "The field for corrupt records must be string type and nullable")
- }
- }
+ verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
val parsed = jsonDataset.rdd.mapPartitions { iter =>
val parser = new JacksonParser(schema, parsedOptions)
@@ -399,6 +393,51 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
}
/**
+ * Loads a `Dataset[String]` storing CSV rows and returns the result as a `DataFrame`.
+ *
+ * If the schema is not specified using the `schema` function and the `inferSchema` option is
+ * enabled, this function goes through the input once to determine the input schema.
+ *
+ * If the schema is not specified using the `schema` function and the `inferSchema` option is
+ * disabled, it determines the columns as string types and reads only the first line to
+ * determine the names and the number of fields.
+ *
+ * @param csvDataset input Dataset with one CSV row per record
+ * @since 2.2.0
+ */
+ def csv(csvDataset: Dataset[String]): DataFrame = {
+ val parsedOptions: CSVOptions = new CSVOptions(
+ extraOptions.toMap,
+ sparkSession.sessionState.conf.sessionLocalTimeZone)
+ val filteredLines: Dataset[String] =
+ CSVUtils.filterCommentAndEmpty(csvDataset, parsedOptions)
+ val maybeFirstLine: Option[String] = filteredLines.take(1).headOption
+
+ val schema = userSpecifiedSchema.getOrElse {
+ TextInputCSVDataSource.inferFromDataset(
+ sparkSession,
+ csvDataset,
+ maybeFirstLine,
+ parsedOptions)
+ }
+
+ verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
+
+ val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine =>
+ filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions))
+ }.getOrElse(filteredLines.rdd)
+
+ val parsed = linesWithoutHeader.mapPartitions { iter =>
+ val parser = new UnivocityParser(schema, parsedOptions)
+ iter.flatMap(line => parser.parse(line))
+ }
+
+ Dataset.ofRows(
+ sparkSession,
+ LogicalRDD(schema.toAttributes, parsed)(sparkSession))
+ }
+
+ /**
* Loads a CSV file and returns the result as a `DataFrame`.
*
* This function will go through the input once to determine the input schema if `inferSchema`
@@ -604,6 +643,22 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
}
}
+ /**
+ * A convenience method for schema validation in data sources that support
+ * `columnNameOfCorruptRecord` as an option.
+ */
+ private def verifyColumnNameOfCorruptRecord(
+ schema: StructType,
+ columnNameOfCorruptRecord: String): Unit = {
+ schema.getFieldIndex(columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
+ val f = schema(corruptFieldIndex)
+ if (f.dataType != StringType || !f.nullable) {
+ throw new AnalysisException(
+ "The field for corrupt records must be string type and nullable")
+ }
+ }
+ }
+
///////////////////////////////////////////////////////////////////////////////////////
// Builder pattern config options
///////////////////////////////////////////////////////////////////////////////////////
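To make the driver-side check in `verifyColumnNameOfCorruptRecord` concrete, here is a hedged sketch of a user-specified schema that would be rejected. The default corrupt-record column name `_corrupt_record` and the `lines` dataset are assumptions for illustration; the patch itself only factors the existing JSON check into a helper shared with the new CSV path.

```scala
import org.apache.spark.sql.types._

// A schema whose corrupt-record column is not a nullable StringType.
val badSchema = StructType(Seq(
  StructField("a", IntegerType, nullable = true),
  StructField("_corrupt_record", IntegerType, nullable = true)  // wrong data type
))

// Assuming `spark` is an active SparkSession and `lines` is a Dataset[String]:
// spark.read.schema(badSchema).csv(lines)
// now fails fast on the driver with:
//   AnalysisException: The field for corrupt records must be string type and nullable
```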
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index 47567032b0..35ff924f27 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -17,12 +17,11 @@
package org.apache.spark.sql.execution.datasources.csv
-import java.io.InputStream
import java.nio.charset.{Charset, StandardCharsets}
-import com.univocity.parsers.csv.{CsvParser, CsvParserSettings}
+import com.univocity.parsers.csv.CsvParser
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.fs.FileStatus
import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapred.TextInputFormat
import org.apache.hadoop.mapreduce.Job
@@ -134,23 +133,33 @@ object TextInputCSVDataSource extends CSVDataSource {
inputPaths: Seq[FileStatus],
parsedOptions: CSVOptions): Option[StructType] = {
val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions)
- CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption match {
- case Some(firstLine) =>
- val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine)
- val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
- val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
- val tokenRDD = csv.rdd.mapPartitions { iter =>
- val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions)
- val linesWithoutHeader =
- CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions)
- val parser = new CsvParser(parsedOptions.asParserSettings)
- linesWithoutHeader.map(parser.parseLine)
- }
- Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions))
- case None =>
- // If the first line could not be read, just return the empty schema.
- Some(StructType(Nil))
- }
+ val maybeFirstLine = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption
+ Some(inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions))
+ }
+
+ /**
+ * Infers the schema from a `Dataset` that stores CSV string records.
+ */
+ def inferFromDataset(
+ sparkSession: SparkSession,
+ csv: Dataset[String],
+ maybeFirstLine: Option[String],
+ parsedOptions: CSVOptions): StructType = maybeFirstLine match {
+ case Some(firstLine) =>
+ val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine)
+ val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+ val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
+ val tokenRDD = csv.rdd.mapPartitions { iter =>
+ val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions)
+ val linesWithoutHeader =
+ CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions)
+ val parser = new CsvParser(parsedOptions.asParserSettings)
+ linesWithoutHeader.map(parser.parseLine)
+ }
+ CSVInferSchema.infer(tokenRDD, header, parsedOptions)
+ case None =>
+ // If the first line could not be read, just return the empty schema.
+ StructType(Nil)
}
private def createBaseDataset(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index 50503385ad..0b1e5dac2d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -26,7 +26,7 @@ import org.apache.commons.lang3.time.FastDateFormat
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs, ParseModes}
-private[csv] class CSVOptions(
+class CSVOptions(
@transient private val parameters: CaseInsensitiveMap[String],
defaultTimeZoneId: String,
defaultColumnNameOfCorruptRecord: String)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
index 3b3b87e435..e42ea3fa39 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
-private[csv] class UnivocityParser(
+class UnivocityParser(
schema: StructType,
requiredSchema: StructType,
private val options: CSVOptions) extends Logging {
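The two `private[csv]` removals above exist so that `DataFrameReader`, which lives in `org.apache.spark.sql` rather than in the csv package, can construct these classes directly. Below is a hedged sketch of that pattern, mirroring the `mapPartitions` block in the new `csv(Dataset[String])` overload; the schema, options, time zone, and sample lines are made up for illustration, and the constructors shown are the ones used in the diff above.

```scala
import org.apache.spark.sql.execution.datasources.csv.{CSVOptions, UnivocityParser}
import org.apache.spark.sql.types._

// Illustrative schema and options; "UTC" stands in for the session-local time zone.
val schema = StructType(Seq(
  StructField("id", StringType, nullable = true),
  StructField("name", StringType, nullable = true)
))
val options = new CSVOptions(Map("header" -> "false"), "UTC")

// One parser per partition, as in the new csv(Dataset[String]) overload.
val parser = new UnivocityParser(schema, options)

// parser.parse(line) yields zero or more InternalRow values per input line.
val rows = Seq("1,alice", "2,bob").flatMap(line => parser.parse(line))
```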