From dd3b5455c61bddce96a94c2ce8f5d76ed4948ea1 Mon Sep 17 00:00:00 2001 From: Rahul Tanwani Date: Sun, 28 Feb 2016 23:16:34 -0800 Subject: [SPARK-13309][SQL] Fix type inference issue with CSV data Fix type inference issue for sparse CSV data - https://issues.apache.org/jira/browse/SPARK-13309 Author: Rahul Tanwani Closes #11194 from tanwanirahul/master. --- .../sql/execution/datasources/csv/CSVInferSchema.scala | 18 +++++++++--------- sql/core/src/test/resources/simple_sparse.csv | 5 +++++ .../datasources/csv/CSVInferSchemaSuite.scala | 5 +++++ .../spark/sql/execution/datasources/csv/CSVSuite.scala | 14 +++++++++++++- 4 files changed, 32 insertions(+), 10 deletions(-) create mode 100644 sql/core/src/test/resources/simple_sparse.csv (limited to 'sql') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index ace8cd7ad8..7f1ed28046 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -29,7 +29,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion import org.apache.spark.sql.types._ - private[csv] object CSVInferSchema { /** @@ -48,7 +47,11 @@ private[csv] object CSVInferSchema { tokenRdd.aggregate(startType)(inferRowType(nullValue), mergeRowTypes) val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) => - StructField(thisHeader, rootType, nullable = true) + val dType = rootType match { + case _: NullType => StringType + case other => other + } + StructField(thisHeader, dType, nullable = true) } StructType(structFields) @@ -65,12 +68,8 @@ private[csv] object CSVInferSchema { } def mergeRowTypes(first: Array[DataType], second: Array[DataType]): Array[DataType] = { - first.zipAll(second, NullType, NullType).map { case ((a, b)) => - val tpe = findTightestCommonType(a, b).getOrElse(StringType) - tpe match { - case _: NullType => StringType - case other => other - } + first.zipAll(second, NullType, NullType).map { case (a, b) => + findTightestCommonType(a, b).getOrElse(NullType) } } @@ -140,6 +139,8 @@ private[csv] object CSVInferSchema { case (t1, t2) if t1 == t2 => Some(t1) case (NullType, t1) => Some(t1) case (t1, NullType) => Some(t1) + case (StringType, t2) => Some(StringType) + case (t1, StringType) => Some(StringType) // Promote numeric types to the highest of the two and all numeric types to unlimited decimal case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) => @@ -150,7 +151,6 @@ private[csv] object CSVInferSchema { } } - private[csv] object CSVTypeCast { /** diff --git a/sql/core/src/test/resources/simple_sparse.csv b/sql/core/src/test/resources/simple_sparse.csv new file mode 100644 index 0000000000..02d29cabf9 --- /dev/null +++ b/sql/core/src/test/resources/simple_sparse.csv @@ -0,0 +1,5 @@ +A,B,C,D +1,,, +,1,, +,,1, +,,,1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala index a1796f1326..412f1b89be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala @@ -68,4 +68,9 @@ class InferSchemaSuite extends SparkFunSuite { assert(CSVInferSchema.inferField(DoubleType, "\\N", "\\N") == DoubleType) assert(CSVInferSchema.inferField(TimestampType, "\\N", "\\N") == TimestampType) } + + test("Merging Nulltypes should yeild Nulltype.") { + val mergedNullTypes = CSVInferSchema.mergeRowTypes(Array(NullType), Array(NullType)) + assert(mergedNullTypes.deep == Array(NullType).deep) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 7671bc1066..5d57d77ab0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -37,6 +37,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { private val emptyFile = "empty.csv" private val commentsFile = "comments.csv" private val disableCommentsFile = "disable_comments.csv" + private val simpleSparseFile = "simple_sparse.csv" private def testFile(fileName: String): String = { Thread.currentThread().getContextClassLoader.getResource(fileName).toString @@ -233,7 +234,6 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { assert(result.schema.fieldNames.size === 1) } - test("DDL test with empty file") { sqlContext.sql(s""" |CREATE TEMPORARY TABLE carsTable @@ -396,4 +396,16 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { verifyCars(carsCopy, withHeader = true) } } + + test("Schema inference correctly identifies the datatype when data is sparse.") { + val df = sqlContext.read + .format("csv") + .option("header", "true") + .option("inferSchema", "true") + .load(testFile(simpleSparseFile)) + + assert( + df.schema.fields.map(field => field.dataType).deep == + Array(IntegerType, IntegerType, IntegerType, IntegerType).deep) + } } -- cgit v1.2.3