From 73b56a3c6c5c590219b42884c8bbe88b0a236987 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 8 Apr 2016 00:30:26 -0700 Subject: [SPARK-14189][SQL] JSON data sources find compatible types even if inferred decimal type is not capable of the others ## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-14189 When inferred types in the same field during finding compatible `DataType`, are `IntegralType` and `DecimalType` but `DecimalType` is not capable of the given `IntegralType`, JSON data source simply fails to find a compatible type resulting in `StringType`. This can be observed when `prefersDecimal` is enabled. ```scala def mixedIntegerAndDoubleRecords: RDD[String] = sqlContext.sparkContext.parallelize( """{"a": 3, "b": 1.1}""" :: """{"a": 3.1, "b": 1}""" :: Nil) val jsonDF = sqlContext.read .option("prefersDecimal", "true") .json(mixedIntegerAndDoubleRecords) .printSchema() ``` - **Before** ``` root |-- a: string (nullable = true) |-- b: string (nullable = true) ``` - **After** ``` root |-- a: decimal(21, 1) (nullable = true) |-- b: decimal(21, 1) (nullable = true) ``` (Note that integer is inferred as `LongType` which becomes `DecimalType(20, 0)`) ## How was this patch tested? unit tests were used and style tests by `dev/run_tests`. Author: hyukjinkwon Closes #11993 from HyukjinKwon/SPARK-14189. --- .../execution/datasources/json/InferSchema.scala | 8 ++++++++ .../sql/execution/datasources/json/JsonSuite.scala | 22 ++++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala index 4a34f365e4..8e8238a594 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala @@ -256,6 +256,14 @@ private[sql] object InferSchema { case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) => ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) + // The case that given `DecimalType` is capable of given `IntegralType` is handled in + // `findTightestCommonTypeOfTwo`. Both cases below will be executed only when + // the given `DecimalType` is not capable of the given `IntegralType`. + case (t1: IntegralType, t2: DecimalType) => + compatibleType(DecimalType.forType(t1), t2) + case (t1: DecimalType, t2: IntegralType) => + compatibleType(t1, DecimalType.forType(t2)) + // strings and every string is a Json object. case (_, _) => StringType } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 421862c394..2a18acb95b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -773,6 +773,28 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } + test("Find compatible types even if inferred DecimalType is not capable of other IntegralType") { + val mixedIntegerAndDoubleRecords = sparkContext.parallelize( + """{"a": 3, "b": 1.1}""" :: + s"""{"a": 3.1, "b": 0.${"0" * 38}1}""" :: Nil) + val jsonDF = sqlContext.read + .option("prefersDecimal", "true") + .json(mixedIntegerAndDoubleRecords) + + // The values in `a` field will be decimals as they fit in decimal. For `b` field, + // they will be doubles as `1.0E-39D` does not fit. + val expectedSchema = StructType( + StructField("a", DecimalType(21, 1), true) :: + StructField("b", DoubleType, true) :: Nil) + + assert(expectedSchema === jsonDF.schema) + checkAnswer( + jsonDF, + Row(BigDecimal("3"), 1.1D) :: + Row(BigDecimal("3.1"), 1.0E-39D) :: Nil + ) + } + test("Infer big integers correctly even when it does not fit in decimal") { val jsonDF = sqlContext.read .json(bigIntegerRecords) -- cgit v1.2.3