From 3adebfc9a37fdee5b7a4e891c4ee597b85f824c3 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 25 Jan 2016 00:57:56 -0800 Subject: [SPARK-12901][SQL] Refactor options for JSON and CSV datasource (not case class and same format). https://issues.apache.org/jira/browse/SPARK-12901 This PR refactors the options in the JSON and CSV datasources. In more detail: 1. `JSONOptions` now uses the same format as `CSVOptions`. 2. Both are plain classes rather than case classes. 3. `CSVRelation` no longer mixes in `Serializable`, since it does not need to be serializable. A short usage sketch of the new option classes is appended after the diff. Author: hyukjinkwon Closes #10895 from HyukjinKwon/SPARK-12901. --- .../sql/execution/datasources/csv/CSVOptions.scala | 116 ++++++++++++++++++++ .../execution/datasources/csv/CSVParameters.scala | 117 --------------------- .../sql/execution/datasources/csv/CSVParser.scala | 8 +- .../execution/datasources/csv/CSVRelation.scala | 12 +-- .../execution/datasources/json/JSONOptions.scala | 59 +++++------ .../execution/datasources/json/JSONRelation.scala | 2 +- .../sql/execution/datasources/json/JsonSuite.scala | 4 +- 7 files changed, 153 insertions(+), 165 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala new file mode 100644 index 0000000000..5d0e99d760 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.sql.execution.datasources.csv + +import java.nio.charset.Charset + +import org.apache.spark.Logging +import org.apache.spark.sql.execution.datasources.CompressionCodecs + +private[sql] class CSVOptions( + @transient parameters: Map[String, String]) + extends Logging with Serializable { + + private def getChar(paramName: String, default: Char): Char = { + val paramValue = parameters.get(paramName) + paramValue match { + case None => default + case Some(value) if value.length == 0 => '\u0000' + case Some(value) if value.length == 1 => value.charAt(0) + case _ => throw new RuntimeException(s"$paramName cannot be more than one character") + } + } + + private def getBool(paramName: String, default: Boolean = false): Boolean = { + val param = parameters.getOrElse(paramName, default.toString) + if (param.toLowerCase == "true") { + true + } else if (param.toLowerCase == "false") { + false + } else { + throw new Exception(s"$paramName flag can be true or false") + } + } + + val delimiter = CSVTypeCast.toChar( + parameters.getOrElse("sep", parameters.getOrElse("delimiter", ","))) + val parseMode = parameters.getOrElse("mode", "PERMISSIVE") + val charset = parameters.getOrElse("encoding", + parameters.getOrElse("charset", Charset.forName("UTF-8").name())) + + val quote = getChar("quote", '\"') + val escape = getChar("escape", '\\') + val comment = getChar("comment", '\u0000') + + val headerFlag = getBool("header") + val inferSchemaFlag = getBool("inferSchema") + val ignoreLeadingWhiteSpaceFlag = getBool("ignoreLeadingWhiteSpace") + val ignoreTrailingWhiteSpaceFlag = getBool("ignoreTrailingWhiteSpace") + + // Limit the number of lines we'll search for a header row that isn't comment-prefixed + val MAX_COMMENT_LINES_IN_HEADER = 10 + + // Parse mode flags + if (!ParseModes.isValidMode(parseMode)) { + logWarning(s"$parseMode is not a valid parse mode. 
Using ${ParseModes.DEFAULT}.") + } + + val failFast = ParseModes.isFailFastMode(parseMode) + val dropMalformed = ParseModes.isDropMalformedMode(parseMode) + val permissive = ParseModes.isPermissiveMode(parseMode) + + val nullValue = parameters.getOrElse("nullValue", "") + + val compressionCodec: Option[String] = { + val name = parameters.get("compression").orElse(parameters.get("codec")) + name.map(CompressionCodecs.getCodecClassName) + } + + val maxColumns = 20480 + + val maxCharsPerColumn = 100000 + + val inputBufferSize = 128 + + val isCommentSet = this.comment != '\u0000' + + val rowSeparator = "\n" +} + +private[csv] object ParseModes { + val PERMISSIVE_MODE = "PERMISSIVE" + val DROP_MALFORMED_MODE = "DROPMALFORMED" + val FAIL_FAST_MODE = "FAILFAST" + + val DEFAULT = PERMISSIVE_MODE + + def isValidMode(mode: String): Boolean = { + mode.toUpperCase match { + case PERMISSIVE_MODE | DROP_MALFORMED_MODE | FAIL_FAST_MODE => true + case _ => false + } + } + + def isDropMalformedMode(mode: String): Boolean = mode.toUpperCase == DROP_MALFORMED_MODE + def isFailFastMode(mode: String): Boolean = mode.toUpperCase == FAIL_FAST_MODE + def isPermissiveMode(mode: String): Boolean = if (isValidMode(mode)) { + mode.toUpperCase == PERMISSIVE_MODE + } else { + true // We default to permissive if the mode string is not valid + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala deleted file mode 100644 index 0278675aa6..0000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala +++ /dev/null @@ -1,117 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ - -package org.apache.spark.sql.execution.datasources.csv - -import java.nio.charset.Charset - -import org.apache.hadoop.io.compress._ - -import org.apache.spark.Logging -import org.apache.spark.sql.execution.datasources.CompressionCodecs -import org.apache.spark.util.Utils - -private[sql] case class CSVParameters(@transient parameters: Map[String, String]) extends Logging { - - private def getChar(paramName: String, default: Char): Char = { - val paramValue = parameters.get(paramName) - paramValue match { - case None => default - case Some(value) if value.length == 0 => '\u0000' - case Some(value) if value.length == 1 => value.charAt(0) - case _ => throw new RuntimeException(s"$paramName cannot be more than one character") - } - } - - private def getBool(paramName: String, default: Boolean = false): Boolean = { - val param = parameters.getOrElse(paramName, default.toString) - if (param.toLowerCase == "true") { - true - } else if (param.toLowerCase == "false") { - false - } else { - throw new Exception(s"$paramName flag can be true or false") - } - } - - val delimiter = CSVTypeCast.toChar( - parameters.getOrElse("sep", parameters.getOrElse("delimiter", ","))) - val parseMode = parameters.getOrElse("mode", "PERMISSIVE") - val charset = parameters.getOrElse("encoding", - parameters.getOrElse("charset", Charset.forName("UTF-8").name())) - - val quote = getChar("quote", '\"') - val escape = getChar("escape", '\\') - val comment = getChar("comment", '\u0000') - - val headerFlag = getBool("header") - val inferSchemaFlag = getBool("inferSchema") - val ignoreLeadingWhiteSpaceFlag = getBool("ignoreLeadingWhiteSpace") - val ignoreTrailingWhiteSpaceFlag = getBool("ignoreTrailingWhiteSpace") - - // Limit the number of lines we'll search for a header row that isn't comment-prefixed - val MAX_COMMENT_LINES_IN_HEADER = 10 - - // Parse mode flags - if (!ParseModes.isValidMode(parseMode)) { - logWarning(s"$parseMode is not a valid parse mode. 
Using ${ParseModes.DEFAULT}.") - } - - val failFast = ParseModes.isFailFastMode(parseMode) - val dropMalformed = ParseModes.isDropMalformedMode(parseMode) - val permissive = ParseModes.isPermissiveMode(parseMode) - - val nullValue = parameters.getOrElse("nullValue", "") - - val compressionCodec: Option[String] = { - val name = parameters.get("compression").orElse(parameters.get("codec")) - name.map(CompressionCodecs.getCodecClassName) - } - - val maxColumns = 20480 - - val maxCharsPerColumn = 100000 - - val inputBufferSize = 128 - - val isCommentSet = this.comment != '\u0000' - - val rowSeparator = "\n" -} - -private[csv] object ParseModes { - val PERMISSIVE_MODE = "PERMISSIVE" - val DROP_MALFORMED_MODE = "DROPMALFORMED" - val FAIL_FAST_MODE = "FAILFAST" - - val DEFAULT = PERMISSIVE_MODE - - def isValidMode(mode: String): Boolean = { - mode.toUpperCase match { - case PERMISSIVE_MODE | DROP_MALFORMED_MODE | FAIL_FAST_MODE => true - case _ => false - } - } - - def isDropMalformedMode(mode: String): Boolean = mode.toUpperCase == DROP_MALFORMED_MODE - def isFailFastMode(mode: String): Boolean = mode.toUpperCase == FAIL_FAST_MODE - def isPermissiveMode(mode: String): Boolean = if (isValidMode(mode)) { - mode.toUpperCase == PERMISSIVE_MODE - } else { - true // We default to permissive is the mode string is not valid - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala index ba1cc42f3e..8f1421844c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala @@ -29,7 +29,7 @@ import org.apache.spark.Logging * @param params Parameters object * @param headers headers for the columns */ -private[sql] abstract class CsvReader(params: CSVParameters, headers: Seq[String]) { +private[sql] abstract class CsvReader(params: CSVOptions, headers: Seq[String]) { protected lazy val parser: CsvParser = { val settings = new CsvParserSettings() @@ -58,7 +58,7 @@ private[sql] abstract class CsvReader(params: CSVParameters, headers: Seq[String * @param params Parameters object for configuration * @param headers headers for columns */ -private[sql] class LineCsvWriter(params: CSVParameters, headers: Seq[String]) extends Logging { +private[sql] class LineCsvWriter(params: CSVOptions, headers: Seq[String]) extends Logging { private val writerSettings = new CsvWriterSettings private val format = writerSettings.getFormat @@ -93,7 +93,7 @@ private[sql] class LineCsvWriter(params: CSVParameters, headers: Seq[String]) ex * * @param params Parameters object */ -private[sql] class LineCsvReader(params: CSVParameters) +private[sql] class LineCsvReader(params: CSVOptions) extends CsvReader(params, null) { /** * parse a line @@ -118,7 +118,7 @@ private[sql] class LineCsvReader(params: CSVParameters) */ private[sql] class BulkCsvReader( iter: Iterator[String], - params: CSVParameters, + params: CSVOptions, headers: Seq[String]) extends CsvReader(params, headers) with Iterator[Array[String]] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala index 1502501c3b..5959f7cc50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -43,14 +43,14 @@ private[csv] class CSVRelation( private val maybeDataSchema: Option[StructType], override val userDefinedPartitionColumns: Option[StructType], private val parameters: Map[String, String]) - (@transient val sqlContext: SQLContext) extends HadoopFsRelation with Serializable { + (@transient val sqlContext: SQLContext) extends HadoopFsRelation { override lazy val dataSchema: StructType = maybeDataSchema match { case Some(structType) => structType case None => inferSchema(paths) } - private val params = new CSVParameters(parameters) + private val params = new CSVOptions(parameters) @transient private var cachedRDD: Option[RDD[String]] = None @@ -170,7 +170,7 @@ object CSVRelation extends Logging { file: RDD[String], header: Seq[String], firstLine: String, - params: CSVParameters): RDD[Array[String]] = { + params: CSVOptions): RDD[Array[String]] = { // If header is set, make sure firstLine is materialized before sending to executors. file.mapPartitionsWithIndex({ case (split, iter) => new BulkCsvReader( @@ -186,7 +186,7 @@ object CSVRelation extends Logging { requiredColumns: Array[String], inputs: Array[FileStatus], sqlContext: SQLContext, - params: CSVParameters): RDD[Row] = { + params: CSVOptions): RDD[Row] = { val schemaFields = schema.fields val requiredFields = StructType(requiredColumns.map(schema(_))).fields @@ -249,7 +249,7 @@ object CSVRelation extends Logging { } } -private[sql] class CSVOutputWriterFactory(params: CSVParameters) extends OutputWriterFactory { +private[sql] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWriterFactory { override def newInstance( path: String, dataSchema: StructType, @@ -262,7 +262,7 @@ private[sql] class CsvOutputWriter( path: String, dataSchema: StructType, context: TaskAttemptContext, - params: CSVParameters) extends OutputWriter with Logging { + params: CSVOptions) extends OutputWriter with Logging { // create the Generator without separator inserted between 2 records private[this] val text = new Text() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala index e74a76c532..0a083b5e35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala @@ -26,16 +26,30 @@ import org.apache.spark.sql.execution.datasources.CompressionCodecs * * Most of these map directly to Jackson's internal options, specified in [[JsonParser.Feature]]. 
*/ -case class JSONOptions( - samplingRatio: Double = 1.0, - primitivesAsString: Boolean = false, - allowComments: Boolean = false, - allowUnquotedFieldNames: Boolean = false, - allowSingleQuotes: Boolean = true, - allowNumericLeadingZeros: Boolean = false, - allowNonNumericNumbers: Boolean = false, - allowBackslashEscapingAnyCharacter: Boolean = false, - compressionCodec: Option[String] = None) { +private[sql] class JSONOptions( + @transient parameters: Map[String, String]) + extends Serializable { + + val samplingRatio = + parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) + val primitivesAsString = + parameters.get("primitivesAsString").map(_.toBoolean).getOrElse(false) + val allowComments = + parameters.get("allowComments").map(_.toBoolean).getOrElse(false) + val allowUnquotedFieldNames = + parameters.get("allowUnquotedFieldNames").map(_.toBoolean).getOrElse(false) + val allowSingleQuotes = + parameters.get("allowSingleQuotes").map(_.toBoolean).getOrElse(true) + val allowNumericLeadingZeros = + parameters.get("allowNumericLeadingZeros").map(_.toBoolean).getOrElse(false) + val allowNonNumericNumbers = + parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true) + val allowBackslashEscapingAnyCharacter = + parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false) + val compressionCodec = { + val name = parameters.get("compression").orElse(parameters.get("codec")) + name.map(CompressionCodecs.getCodecClassName) + } /** Sets config options on a Jackson [[JsonFactory]]. */ def setJacksonOptions(factory: JsonFactory): Unit = { @@ -48,28 +62,3 @@ case class JSONOptions( allowBackslashEscapingAnyCharacter) } } - -object JSONOptions { - def createFromConfigMap(parameters: Map[String, String]): JSONOptions = JSONOptions( - samplingRatio = - parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0), - primitivesAsString = - parameters.get("primitivesAsString").map(_.toBoolean).getOrElse(false), - allowComments = - parameters.get("allowComments").map(_.toBoolean).getOrElse(false), - allowUnquotedFieldNames = - parameters.get("allowUnquotedFieldNames").map(_.toBoolean).getOrElse(false), - allowSingleQuotes = - parameters.get("allowSingleQuotes").map(_.toBoolean).getOrElse(true), - allowNumericLeadingZeros = - parameters.get("allowNumericLeadingZeros").map(_.toBoolean).getOrElse(false), - allowNonNumericNumbers = - parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true), - allowBackslashEscapingAnyCharacter = - parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false), - compressionCodec = { - val name = parameters.get("compression").orElse(parameters.get("codec")) - name.map(CompressionCodecs.getCodecClassName) - } - ) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 93727abcc7..c893558136 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -75,7 +75,7 @@ private[sql] class JSONRelation( (@transient val sqlContext: SQLContext) extends HadoopFsRelation(maybePartitionSpec, parameters) { - val options: JSONOptions = JSONOptions.createFromConfigMap(parameters) + val options: JSONOptions = new JSONOptions(parameters) /** Constraints to be imposed on schema to be stored. 
*/ private def checkConstraints(schema: StructType): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index d22fa7905a..00eaeb0d34 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1240,7 +1240,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("SPARK-6245 JsonRDD.inferSchema on empty RDD") { // This is really a test that it doesn't throw an exception - val emptySchema = InferSchema.infer(empty, "", JSONOptions()) + val emptySchema = InferSchema.infer(empty, "", new JSONOptions(Map())) assert(StructType(Seq()) === emptySchema) } @@ -1264,7 +1264,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-8093 Erase empty structs") { - val emptySchema = InferSchema.infer(emptyRecords, "", JSONOptions()) + val emptySchema = InferSchema.infer(emptyRecords, "", new JSONOptions(Map())) assert(StructType(Seq()) === emptySchema) } -- cgit v1.2.3
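Usage sketch (not part of the patch): a minimal illustration of how the refactored option classes are constructed after this change. Both classes now take a plain Map[String, String] of data source options (the former JSONOptions.createFromConfigMap and the per-field case-class constructor are gone). The option keys and field names below come from the classes in the diff above; the specific values and the asserts are illustrative only, and since both classes are private[sql] this would only compile inside the org.apache.spark.sql packages.

  import org.apache.spark.sql.execution.datasources.csv.CSVOptions
  import org.apache.spark.sql.execution.datasources.json.JSONOptions

  // CSV options are parsed eagerly from the map; getBool, getChar and
  // ParseModes produce the typed fields.
  val csvOptions = new CSVOptions(Map(
    "header" -> "true",
    "delimiter" -> ";",
    "mode" -> "DROPMALFORMED"))
  assert(csvOptions.headerFlag)
  assert(csvOptions.dropMalformed)

  // JSON options follow the same Map-based pattern.
  val jsonOptions = new JSONOptions(Map(
    "samplingRatio" -> "0.5",
    "allowComments" -> "true"))
  assert(jsonOptions.samplingRatio == 0.5)
  assert(jsonOptions.allowComments)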