diff options
author | hyukjinkwon <gurwls223@gmail.com> | 2016-03-07 14:32:01 -0800 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-03-07 14:32:01 -0800 |
commit | 8577260abdc908ac08d28ddd3f07a2411fdc82b7 (patch) | |
tree | 5832a628589096ef79bb6706a185ab26a5e416a4 /sql | |
parent | e1fb857992074164dcaa02498c5a9604fac6f57e (diff) | |
download | spark-8577260abdc908ac08d28ddd3f07a2411fdc82b7.tar.gz spark-8577260abdc908ac08d28ddd3f07a2411fdc82b7.tar.bz2 spark-8577260abdc908ac08d28ddd3f07a2411fdc82b7.zip |
[SPARK-13442][SQL] Make type inference recognize boolean types
## What changes were proposed in this pull request?
https://issues.apache.org/jira/browse/SPARK-13442
This PR adds the support for inferring `BooleanType` for schema.
It supports to infer case-insensitive `true` / `false` as `BooleanType`.
Unittests were added for `CSVInferSchemaSuite` and `CSVSuite` for end-to-end test.
## How was the this patch tested?
This was tested with unittests and with `dev/run_tests` for coding style
Author: hyukjinkwon <gurwls223@gmail.com>
Closes #11315 from HyukjinKwon/SPARK-13442.
Diffstat (limited to 'sql')
4 files changed, 38 insertions, 0 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index 7f1ed28046..edead9b21b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -87,6 +87,7 @@ private[csv] object CSVInferSchema { case LongType => tryParseLong(field) case DoubleType => tryParseDouble(field) case TimestampType => tryParseTimestamp(field) + case BooleanType => tryParseBoolean(field) case StringType => StringType case other: DataType => throw new UnsupportedOperationException(s"Unexpected data type $other") @@ -118,6 +119,14 @@ private[csv] object CSVInferSchema { if ((allCatch opt Timestamp.valueOf(field)).isDefined) { TimestampType } else { + tryParseBoolean(field) + } + } + + def tryParseBoolean(field: String): DataType = { + if ((allCatch opt field.toBoolean).isDefined) { + BooleanType + } else { stringType() } } diff --git a/sql/core/src/test/resources/bool.csv b/sql/core/src/test/resources/bool.csv new file mode 100644 index 0000000000..94b2d49506 --- /dev/null +++ b/sql/core/src/test/resources/bool.csv @@ -0,0 +1,5 @@ +bool +"True" +"False" + +"true" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala index 412f1b89be..7af3f94aef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala @@ -30,6 +30,8 @@ class InferSchemaSuite extends SparkFunSuite { assert(CSVInferSchema.inferField(NullType, "3.5") == DoubleType) assert(CSVInferSchema.inferField(NullType, "test") == StringType) assert(CSVInferSchema.inferField(NullType, "2015-08-20 15:57:00") == TimestampType) + assert(CSVInferSchema.inferField(NullType, "True") == BooleanType) + assert(CSVInferSchema.inferField(NullType, "FAlSE") == BooleanType) } test("String fields types are inferred correctly from other types") { @@ -40,6 +42,9 @@ class InferSchemaSuite extends SparkFunSuite { assert(CSVInferSchema.inferField(DoubleType, "test") == StringType) assert(CSVInferSchema.inferField(LongType, "2015-08-20 14:57:00") == TimestampType) assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 15:57:00") == TimestampType) + assert(CSVInferSchema.inferField(LongType, "True") == BooleanType) + assert(CSVInferSchema.inferField(IntegerType, "FALSE") == BooleanType) + assert(CSVInferSchema.inferField(TimestampType, "FALSE") == BooleanType) } test("Timestamp field types are inferred correctly from other types") { @@ -48,6 +53,11 @@ class InferSchemaSuite extends SparkFunSuite { assert(CSVInferSchema.inferField(LongType, "2015-08 14:49:00") == StringType) } + test("Boolean fields types are inferred correctly from other types") { + assert(CSVInferSchema.inferField(LongType, "Fale") == StringType) + assert(CSVInferSchema.inferField(DoubleType, "TRUEe") == StringType) + } + test("Type arrays are merged to highest common type") { assert( CSVInferSchema.mergeRowTypes(Array(StringType), @@ -67,6 +77,7 @@ class InferSchemaSuite extends SparkFunSuite { assert(CSVInferSchema.inferField(IntegerType, "\\N", "\\N") == IntegerType) assert(CSVInferSchema.inferField(DoubleType, "\\N", "\\N") == DoubleType) assert(CSVInferSchema.inferField(TimestampType, "\\N", "\\N") == TimestampType) + assert(CSVInferSchema.inferField(BooleanType, "\\N", "\\N") == BooleanType) } test("Merging Nulltypes should yeild Nulltype.") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 9cd3a9ab95..53027bb698 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -43,6 +43,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { private val emptyFile = "empty.csv" private val commentsFile = "comments.csv" private val disableCommentsFile = "disable_comments.csv" + private val boolFile = "bool.csv" private val simpleSparseFile = "simple_sparse.csv" private def testFile(fileName: String): String = { @@ -118,6 +119,18 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { verifyCars(cars, withHeader = true, checkTypes = true) } + test("test inferring booleans") { + val result = sqlContext.read + .format("csv") + .option("header", "true") + .option("inferSchema", "true") + .load(testFile(boolFile)) + + val expectedSchema = StructType(List( + StructField("bool", BooleanType, nullable = true))) + assert(result.schema === expectedSchema) + } + test("test with alternative delimiter and quote") { val cars = sqlContext.read .format("csv") |