From 8577260abdc908ac08d28ddd3f07a2411fdc82b7 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 7 Mar 2016 14:32:01 -0800 Subject: [SPARK-13442][SQL] Make type inference recognize boolean types ## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-13442 This PR adds the support for inferring `BooleanType` for schema. It supports to infer case-insensitive `true` / `false` as `BooleanType`. Unittests were added for `CSVInferSchemaSuite` and `CSVSuite` for end-to-end test. ## How was the this patch tested? This was tested with unittests and with `dev/run_tests` for coding style Author: hyukjinkwon Closes #11315 from HyukjinKwon/SPARK-13442. --- .../sql/execution/datasources/csv/CSVInferSchema.scala | 9 +++++++++ sql/core/src/test/resources/bool.csv | 5 +++++ .../sql/execution/datasources/csv/CSVInferSchemaSuite.scala | 11 +++++++++++ .../spark/sql/execution/datasources/csv/CSVSuite.scala | 13 +++++++++++++ 4 files changed, 38 insertions(+) create mode 100644 sql/core/src/test/resources/bool.csv (limited to 'sql') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index 7f1ed28046..edead9b21b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -87,6 +87,7 @@ private[csv] object CSVInferSchema { case LongType => tryParseLong(field) case DoubleType => tryParseDouble(field) case TimestampType => tryParseTimestamp(field) + case BooleanType => tryParseBoolean(field) case StringType => StringType case other: DataType => throw new UnsupportedOperationException(s"Unexpected data type $other") @@ -117,6 +118,14 @@ private[csv] object CSVInferSchema { def tryParseTimestamp(field: String): DataType = { if ((allCatch opt Timestamp.valueOf(field)).isDefined) { TimestampType + } else { + tryParseBoolean(field) + } + } + + def tryParseBoolean(field: String): DataType = { + if ((allCatch opt field.toBoolean).isDefined) { + BooleanType } else { stringType() } diff --git a/sql/core/src/test/resources/bool.csv b/sql/core/src/test/resources/bool.csv new file mode 100644 index 0000000000..94b2d49506 --- /dev/null +++ b/sql/core/src/test/resources/bool.csv @@ -0,0 +1,5 @@ +bool +"True" +"False" + +"true" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala index 412f1b89be..7af3f94aef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala @@ -30,6 +30,8 @@ class InferSchemaSuite extends SparkFunSuite { assert(CSVInferSchema.inferField(NullType, "3.5") == DoubleType) assert(CSVInferSchema.inferField(NullType, "test") == StringType) assert(CSVInferSchema.inferField(NullType, "2015-08-20 15:57:00") == TimestampType) + assert(CSVInferSchema.inferField(NullType, "True") == BooleanType) + assert(CSVInferSchema.inferField(NullType, "FAlSE") == BooleanType) } test("String fields types are inferred correctly from other types") { @@ -40,6 +42,9 @@ class InferSchemaSuite extends SparkFunSuite { assert(CSVInferSchema.inferField(DoubleType, "test") == StringType) assert(CSVInferSchema.inferField(LongType, "2015-08-20 14:57:00") == TimestampType) assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 15:57:00") == TimestampType) + assert(CSVInferSchema.inferField(LongType, "True") == BooleanType) + assert(CSVInferSchema.inferField(IntegerType, "FALSE") == BooleanType) + assert(CSVInferSchema.inferField(TimestampType, "FALSE") == BooleanType) } test("Timestamp field types are inferred correctly from other types") { @@ -48,6 +53,11 @@ class InferSchemaSuite extends SparkFunSuite { assert(CSVInferSchema.inferField(LongType, "2015-08 14:49:00") == StringType) } + test("Boolean fields types are inferred correctly from other types") { + assert(CSVInferSchema.inferField(LongType, "Fale") == StringType) + assert(CSVInferSchema.inferField(DoubleType, "TRUEe") == StringType) + } + test("Type arrays are merged to highest common type") { assert( CSVInferSchema.mergeRowTypes(Array(StringType), @@ -67,6 +77,7 @@ class InferSchemaSuite extends SparkFunSuite { assert(CSVInferSchema.inferField(IntegerType, "\\N", "\\N") == IntegerType) assert(CSVInferSchema.inferField(DoubleType, "\\N", "\\N") == DoubleType) assert(CSVInferSchema.inferField(TimestampType, "\\N", "\\N") == TimestampType) + assert(CSVInferSchema.inferField(BooleanType, "\\N", "\\N") == BooleanType) } test("Merging Nulltypes should yeild Nulltype.") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 9cd3a9ab95..53027bb698 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -43,6 +43,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { private val emptyFile = "empty.csv" private val commentsFile = "comments.csv" private val disableCommentsFile = "disable_comments.csv" + private val boolFile = "bool.csv" private val simpleSparseFile = "simple_sparse.csv" private def testFile(fileName: String): String = { @@ -118,6 +119,18 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { verifyCars(cars, withHeader = true, checkTypes = true) } + test("test inferring booleans") { + val result = sqlContext.read + .format("csv") + .option("header", "true") + .option("inferSchema", "true") + .load(testFile(boolFile)) + + val expectedSchema = StructType(List( + StructField("bool", BooleanType, nullable = true))) + assert(result.schema === expectedSchema) + } + test("test with alternative delimiter and quote") { val cars = sqlContext.read .format("csv") -- cgit v1.2.3