author    Rahul Tanwani <rahul@Rahuls-MacBook-Pro.local>  2016-02-28 23:16:34 -0800
committer Reynold Xin <rxin@databricks.com>  2016-02-28 23:16:34 -0800
commit    dd3b5455c61bddce96a94c2ce8f5d76ed4948ea1 (patch)
tree      0e68b30bd141b31197a9abd8e90b398454087b71 /sql
parent    6dfc4a764c8bcfc24d951239835015da3ed7c29e (diff)
[SPARK-13309][SQL] Fix type inference issue with CSV data
Fix type inference issue for sparse CSV data - https://issues.apache.org/jira/browse/SPARK-13309

Author: Rahul Tanwani <rahul@Rahuls-MacBook-Pro.local>

Closes #11194 from tanwanirahul/master.
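In short: schema inference folds per-row type arrays together, and the old mergeRowTypes eagerly replaced any still-unresolved NullType with StringType on every merge. Once a column had been promoted to StringType that way, a concrete type seen later (e.g. IntegerType) could no longer win, so sparse columns came back as strings. The patch keeps NullType as the neutral element throughout the merge and applies the StringType fallback exactly once, when the final StructField is built. A minimal self-contained sketch of the fixed merge logic (a toy DataType model for illustration, not the actual Spark classes):

// Toy stand-ins for Spark's DataType hierarchy, for illustration only.
sealed trait DataType
case object NullType extends DataType
case object IntegerType extends DataType
case object StringType extends DataType

// Simplified coercion rule as it stands after this patch:
// NullType is neutral, StringType absorbs everything else.
def findTightestCommonType(a: DataType, b: DataType): Option[DataType] = (a, b) match {
  case (t1, t2) if t1 == t2 => Some(t1)
  case (NullType, t) => Some(t)
  case (t, NullType) => Some(t)
  case (StringType, _) | (_, StringType) => Some(StringType)
  case _ => None
}

// Post-fix mergeRowTypes: unresolved slots stay NullType instead of
// being promoted to StringType mid-merge.
def mergeRowTypes(first: Array[DataType], second: Array[DataType]): Array[DataType] =
  first.zipAll(second, NullType, NullType).map { case (a, b) =>
    findTightestCommonType(a, b).getOrElse(NullType)
  }

With this version, merging Array(NullType) with Array(NullType) and then with Array(IntegerType) yields IntegerType; the old code locked in StringType at the first step and could never recover.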
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala       18
-rw-r--r--  sql/core/src/test/resources/simple_sparse.csv                                                      5
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala   5
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala             14
4 files changed, 32 insertions(+), 10 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
index ace8cd7ad8..7f1ed28046 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
@@ -29,7 +29,6 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
import org.apache.spark.sql.types._
-
private[csv] object CSVInferSchema {
/**
@@ -48,7 +47,11 @@ private[csv] object CSVInferSchema {
tokenRdd.aggregate(startType)(inferRowType(nullValue), mergeRowTypes)
val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) =>
- StructField(thisHeader, rootType, nullable = true)
+ val dType = rootType match {
+ case _: NullType => StringType
+ case other => other
+ }
+ StructField(thisHeader, dType, nullable = true)
}
StructType(structFields)
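This hunk is where the deferred fallback lands: after all rows have been aggregated, any column whose merged type is still NullType (i.e. it never contained a non-empty token) is surfaced as a nullable StringType field. A hedged illustration using the toy types from the sketch above:

// Suppose a two-column file where column "B" is empty in every row.
val header    = Array("A", "B")
val rootTypes = Array[DataType](IntegerType, NullType)

val fieldTypes = header.zip(rootTypes).map { case (name, rootType) =>
  val dType = rootType match {
    case NullType => StringType // nothing better was ever observed
    case other    => other
  }
  (name, dType) // stands in for StructField(name, dType, nullable = true)
}
// fieldTypes: Array(("A", IntegerType), ("B", StringType))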
@@ -65,12 +68,8 @@ private[csv] object CSVInferSchema {
}
def mergeRowTypes(first: Array[DataType], second: Array[DataType]): Array[DataType] = {
- first.zipAll(second, NullType, NullType).map { case ((a, b)) =>
- val tpe = findTightestCommonType(a, b).getOrElse(StringType)
- tpe match {
- case _: NullType => StringType
- case other => other
- }
+ first.zipAll(second, NullType, NullType).map { case (a, b) =>
+ findTightestCommonType(a, b).getOrElse(NullType)
}
}
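Note the zipAll: if one side has seen fewer columns than the other (ragged rows), the missing slots are padded with NullType, which after this patch is a safe neutral value. With the toy sketch above:

mergeRowTypes(Array(IntegerType, NullType), Array(NullType))
// -> Array(IntegerType, NullType): the short side is padded with NullType,
//    and nothing is prematurely forced to StringType.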
@@ -140,6 +139,8 @@ private[csv] object CSVInferSchema {
case (t1, t2) if t1 == t2 => Some(t1)
case (NullType, t1) => Some(t1)
case (t1, NullType) => Some(t1)
+ case (StringType, t2) => Some(StringType)
+ case (t1, StringType) => Some(StringType)
// Promote numeric types to the highest of the two and all numeric types to unlimited decimal
case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) =>
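The two new StringType cases close the gap that made the old getOrElse(StringType) default load-bearing: previously findTightestCommonType returned None whenever one side was StringType, and the caller's fallback silently papered over it. With these cases, StringType explicitly absorbs any other type, so mergeRowTypes can safely use getOrElse(NullType) without losing string columns. In terms of the toy sketch above:

findTightestCommonType(StringType, IntegerType) // -> Some(StringType)
findTightestCommonType(NullType, StringType)    // -> Some(StringType)
findTightestCommonType(NullType, NullType)      // -> Some(NullType)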
@@ -150,7 +151,6 @@ private[csv] object CSVInferSchema {
}
}
-
private[csv] object CSVTypeCast {
/**
diff --git a/sql/core/src/test/resources/simple_sparse.csv b/sql/core/src/test/resources/simple_sparse.csv
new file mode 100644
index 0000000000..02d29cabf9
--- /dev/null
+++ b/sql/core/src/test/resources/simple_sparse.csv
@@ -0,0 +1,5 @@
+A,B,C,D
+1,,,
+,1,,
+,,1,
+,,,1
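Each row of this fixture populates exactly one column, so every per-row type array contains a single IntegerType and three NullTypes, in rotating positions. Roughly, using the toy model above:

// Per-row inferred types for simple_sparse.csv (illustrative):
val rows = Seq(
  Array[DataType](IntegerType, NullType, NullType, NullType), // "1,,,"
  Array[DataType](NullType, IntegerType, NullType, NullType), // ",1,,"
  Array[DataType](NullType, NullType, IntegerType, NullType), // ",,1,"
  Array[DataType](NullType, NullType, NullType, IntegerType)  // ",,,1"
)
rows.reduce(mergeRowTypes)
// -> Array(IntegerType, IntegerType, IntegerType, IntegerType) after the fix;
//    the old merge could collapse these slots to StringType.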
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
index a1796f1326..412f1b89be 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
@@ -68,4 +68,9 @@ class InferSchemaSuite extends SparkFunSuite {
assert(CSVInferSchema.inferField(DoubleType, "\\N", "\\N") == DoubleType)
assert(CSVInferSchema.inferField(TimestampType, "\\N", "\\N") == TimestampType)
}
+
+ test("Merging Nulltypes should yeild Nulltype.") {
+ val mergedNullTypes = CSVInferSchema.mergeRowTypes(Array(NullType), Array(NullType))
+ assert(mergedNullTypes.deep == Array(NullType).deep)
+ }
}
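The .deep comparison is needed because Scala arrays are JVM arrays under the hood and == compares them by reference; .deep (available on pre-2.13 Scala arrays) yields an IndexedSeq view that compares element-wise. For instance:

Array(NullType) == Array(NullType)           // false: reference equality
Array(NullType).deep == Array(NullType).deep // true: structural equality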
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 7671bc1066..5d57d77ab0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -37,6 +37,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
private val emptyFile = "empty.csv"
private val commentsFile = "comments.csv"
private val disableCommentsFile = "disable_comments.csv"
+ private val simpleSparseFile = "simple_sparse.csv"
private def testFile(fileName: String): String = {
Thread.currentThread().getContextClassLoader.getResource(fileName).toString
@@ -233,7 +234,6 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
assert(result.schema.fieldNames.size === 1)
}
-
test("DDL test with empty file") {
sqlContext.sql(s"""
|CREATE TEMPORARY TABLE carsTable
@@ -396,4 +396,16 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
verifyCars(carsCopy, withHeader = true)
}
}
+
+ test("Schema inference correctly identifies the datatype when data is sparse.") {
+ val df = sqlContext.read
+ .format("csv")
+ .option("header", "true")
+ .option("inferSchema", "true")
+ .load(testFile(simpleSparseFile))
+
+ assert(
+ df.schema.fields.map(field => field.dataType).deep ==
+ Array(IntegerType, IntegerType, IntegerType, IntegerType).deep)
+ }
}
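For reference, this end-to-end test exercises the same read path an application would use; with the patch applied, a sparse file like simple_sparse.csv infers integer columns instead of strings. A hedged usage sketch against the API used in this suite (the file path is hypothetical):

val df = sqlContext.read
  .format("csv")
  .option("header", "true")
  .option("inferSchema", "true")
  .load("/path/to/simple_sparse.csv") // hypothetical location

df.printSchema()
// root
//  |-- A: integer (nullable = true)
//  ... and likewise for B, C, D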