aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStephen De Gennaro <stepheng@realitymine.com>2015-10-26 19:55:10 -0700
committerYin Huai <yhuai@databricks.com>2015-10-26 19:55:10 -0700
commit82464fb2e02ca4e4d425017815090497b79dc93f (patch)
tree38092b8f33e55405a41fcc00512e57b08f5fc0d8
parentd4c397a64af4cec899fdaa3e617ed20333cc567d (diff)
downloadspark-82464fb2e02ca4e4d425017815090497b79dc93f.tar.gz
spark-82464fb2e02ca4e4d425017815090497b79dc93f.tar.bz2
spark-82464fb2e02ca4e4d425017815090497b79dc93f.zip
[SPARK-10947] [SQL] With schema inference from JSON into a Dataframe, add option to infer all primitive object types as strings
Currently, when a schema is inferred from a JSON file using sqlContext.read.json, the primitive object types are inferred as string, long, boolean, etc. However, if the inferred type is too specific (JSON obviously does not enforce types itself), this can cause issues with merging dataframe schemas. This pull request adds the option "primitivesAsString" to the JSON DataFrameReader which when true (defaults to false if not set) will infer all primitives as strings. Below is an example usage of this new functionality. ``` val jsonDf = sqlContext.read.option("primitivesAsString", "true").json(sampleJsonFile) scala> jsonDf.printSchema() root |-- bigInteger: string (nullable = true) |-- boolean: string (nullable = true) |-- double: string (nullable = true) |-- integer: string (nullable = true) |-- long: string (nullable = true) |-- null: string (nullable = true) |-- string: string (nullable = true) ``` Author: Stephen De Gennaro <stepheng@realitymine.com> Closes #9249 from stephend-realitymine/stephend-primitives.
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala10
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala20
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala14
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala138
4 files changed, 171 insertions, 11 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 824220d85e..6a194a443a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -256,8 +256,16 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
*/
def json(jsonRDD: RDD[String]): DataFrame = {
val samplingRatio = extraOptions.getOrElse("samplingRatio", "1.0").toDouble
+ val primitivesAsString = extraOptions.getOrElse("primitivesAsString", "false").toBoolean
sqlContext.baseRelationToDataFrame(
- new JSONRelation(Some(jsonRDD), samplingRatio, userSpecifiedSchema, None, None)(sqlContext))
+ new JSONRelation(
+ Some(jsonRDD),
+ samplingRatio,
+ primitivesAsString,
+ userSpecifiedSchema,
+ None,
+ None)(sqlContext)
+ )
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
index d0780028da..b9914c581a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
@@ -35,7 +35,8 @@ private[sql] object InferSchema {
def apply(
json: RDD[String],
samplingRatio: Double = 1.0,
- columnNameOfCorruptRecords: String): StructType = {
+ columnNameOfCorruptRecords: String,
+ primitivesAsString: Boolean = false): StructType = {
require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0")
val schemaData = if (samplingRatio > 0.99) {
json
@@ -50,7 +51,7 @@ private[sql] object InferSchema {
try {
Utils.tryWithResource(factory.createParser(row)) { parser =>
parser.nextToken()
- inferField(parser)
+ inferField(parser, primitivesAsString)
}
} catch {
case _: JsonParseException =>
@@ -70,14 +71,14 @@ private[sql] object InferSchema {
/**
* Infer the type of a json document from the parser's token stream
*/
- private def inferField(parser: JsonParser): DataType = {
+ private def inferField(parser: JsonParser, primitivesAsString: Boolean): DataType = {
import com.fasterxml.jackson.core.JsonToken._
parser.getCurrentToken match {
case null | VALUE_NULL => NullType
case FIELD_NAME =>
parser.nextToken()
- inferField(parser)
+ inferField(parser, primitivesAsString)
case VALUE_STRING if parser.getTextLength < 1 =>
// Zero length strings and nulls have special handling to deal
@@ -92,7 +93,10 @@ private[sql] object InferSchema {
case START_OBJECT =>
val builder = Seq.newBuilder[StructField]
while (nextUntil(parser, END_OBJECT)) {
- builder += StructField(parser.getCurrentName, inferField(parser), nullable = true)
+ builder += StructField(
+ parser.getCurrentName,
+ inferField(parser, primitivesAsString),
+ nullable = true)
}
StructType(builder.result().sortBy(_.name))
@@ -103,11 +107,15 @@ private[sql] object InferSchema {
// the type as we pass through all JSON objects.
var elementType: DataType = NullType
while (nextUntil(parser, END_ARRAY)) {
- elementType = compatibleType(elementType, inferField(parser))
+ elementType = compatibleType(elementType, inferField(parser, primitivesAsString))
}
ArrayType(elementType)
+ case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) if primitivesAsString => StringType
+
+ case (VALUE_TRUE | VALUE_FALSE) if primitivesAsString => StringType
+
case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT =>
import JsonParser.NumberType._
parser.getNumberType match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
index 794b889a93..5f104fca7d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
@@ -52,14 +52,23 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister {
partitionColumns: Option[StructType],
parameters: Map[String, String]): HadoopFsRelation = {
val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
+ val primitivesAsString = parameters.get("primitivesAsString").map(_.toBoolean).getOrElse(false)
- new JSONRelation(None, samplingRatio, dataSchema, None, partitionColumns, paths)(sqlContext)
+ new JSONRelation(
+ None,
+ samplingRatio,
+ primitivesAsString,
+ dataSchema,
+ None,
+ partitionColumns,
+ paths)(sqlContext)
}
}
private[sql] class JSONRelation(
val inputRDD: Option[RDD[String]],
val samplingRatio: Double,
+ val primitivesAsString: Boolean,
val maybeDataSchema: Option[StructType],
val maybePartitionSpec: Option[PartitionSpec],
override val userDefinedPartitionColumns: Option[StructType],
@@ -105,7 +114,8 @@ private[sql] class JSONRelation(
InferSchema(
inputRDD.getOrElse(createBaseRdd(files)),
samplingRatio,
- sqlContext.conf.columnNameOfCorruptRecord)
+ sqlContext.conf.columnNameOfCorruptRecord,
+ primitivesAsString)
}
checkConstraints(jsonSchema)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 7540223bf2..d3fd409291 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -632,6 +632,136 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
)
}
+ test("Loading a JSON dataset primitivesAsString returns schema with primitive types as strings") {
+ val dir = Utils.createTempDir()
+ dir.delete()
+ val path = dir.getCanonicalPath
+ primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path)
+ val jsonDF = sqlContext.read.option("primitivesAsString", "true").json(path)
+
+ val expectedSchema = StructType(
+ StructField("bigInteger", StringType, true) ::
+ StructField("boolean", StringType, true) ::
+ StructField("double", StringType, true) ::
+ StructField("integer", StringType, true) ::
+ StructField("long", StringType, true) ::
+ StructField("null", StringType, true) ::
+ StructField("string", StringType, true) :: Nil)
+
+ assert(expectedSchema === jsonDF.schema)
+
+ jsonDF.registerTempTable("jsonTable")
+
+ checkAnswer(
+ sql("select * from jsonTable"),
+ Row("92233720368547758070",
+ "true",
+ "1.7976931348623157E308",
+ "10",
+ "21474836470",
+ null,
+ "this is a simple string.")
+ )
+ }
+
+ test("Loading a JSON dataset primitivesAsString returns complex fields as strings") {
+ val jsonDF = sqlContext.read.option("primitivesAsString", "true").json(complexFieldAndType1)
+
+ val expectedSchema = StructType(
+ StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) ::
+ StructField("arrayOfArray2", ArrayType(ArrayType(StringType, true), true), true) ::
+ StructField("arrayOfBigInteger", ArrayType(StringType, true), true) ::
+ StructField("arrayOfBoolean", ArrayType(StringType, true), true) ::
+ StructField("arrayOfDouble", ArrayType(StringType, true), true) ::
+ StructField("arrayOfInteger", ArrayType(StringType, true), true) ::
+ StructField("arrayOfLong", ArrayType(StringType, true), true) ::
+ StructField("arrayOfNull", ArrayType(StringType, true), true) ::
+ StructField("arrayOfString", ArrayType(StringType, true), true) ::
+ StructField("arrayOfStruct", ArrayType(
+ StructType(
+ StructField("field1", StringType, true) ::
+ StructField("field2", StringType, true) ::
+ StructField("field3", StringType, true) :: Nil), true), true) ::
+ StructField("struct", StructType(
+ StructField("field1", StringType, true) ::
+ StructField("field2", StringType, true) :: Nil), true) ::
+ StructField("structWithArrayFields", StructType(
+ StructField("field1", ArrayType(StringType, true), true) ::
+ StructField("field2", ArrayType(StringType, true), true) :: Nil), true) :: Nil)
+
+ assert(expectedSchema === jsonDF.schema)
+
+ jsonDF.registerTempTable("jsonTable")
+
+ // Access elements of a primitive array.
+ checkAnswer(
+ sql("select arrayOfString[0], arrayOfString[1], arrayOfString[2] from jsonTable"),
+ Row("str1", "str2", null)
+ )
+
+ // Access an array of null values.
+ checkAnswer(
+ sql("select arrayOfNull from jsonTable"),
+ Row(Seq(null, null, null, null))
+ )
+
+ // Access elements of a BigInteger array (we use DecimalType internally).
+ checkAnswer(
+ sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from jsonTable"),
+ Row("922337203685477580700", "-922337203685477580800", null)
+ )
+
+ // Access elements of an array of arrays.
+ checkAnswer(
+ sql("select arrayOfArray1[0], arrayOfArray1[1] from jsonTable"),
+ Row(Seq("1", "2", "3"), Seq("str1", "str2"))
+ )
+
+ // Access elements of an array of arrays.
+ checkAnswer(
+ sql("select arrayOfArray2[0], arrayOfArray2[1] from jsonTable"),
+ Row(Seq("1", "2", "3"), Seq("1.1", "2.1", "3.1"))
+ )
+
+ // Access elements of an array inside a filed with the type of ArrayType(ArrayType).
+ checkAnswer(
+ sql("select arrayOfArray1[1][1], arrayOfArray2[1][1] from jsonTable"),
+ Row("str2", "2.1")
+ )
+
+ // Access elements of an array of structs.
+ checkAnswer(
+ sql("select arrayOfStruct[0], arrayOfStruct[1], arrayOfStruct[2], arrayOfStruct[3] " +
+ "from jsonTable"),
+ Row(
+ Row("true", "str1", null),
+ Row("false", null, null),
+ Row(null, null, null),
+ null)
+ )
+
+ // Access a struct and fields inside of it.
+ checkAnswer(
+ sql("select struct, struct.field1, struct.field2 from jsonTable"),
+ Row(
+ Row("true", "92233720368547758070"),
+ "true",
+ "92233720368547758070") :: Nil
+ )
+
+ // Access an array field of a struct.
+ checkAnswer(
+ sql("select structWithArrayFields.field1, structWithArrayFields.field2 from jsonTable"),
+ Row(Seq("4", "5", "6"), Seq("str1", "str2"))
+ )
+
+ // Access elements of an array field of a struct.
+ checkAnswer(
+ sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from jsonTable"),
+ Row("5", null)
+ )
+ }
+
test("Loading a JSON dataset from a text file with SQL") {
val dir = Utils.createTempDir()
dir.delete()
@@ -960,9 +1090,9 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
val jsonDF = sqlContext.read.json(primitiveFieldAndType)
val primTable = sqlContext.read.json(jsonDF.toJSON)
- primTable.registerTempTable("primativeTable")
+ primTable.registerTempTable("primitiveTable")
checkAnswer(
- sql("select * from primativeTable"),
+ sql("select * from primitiveTable"),
Row(new java.math.BigDecimal("92233720368547758070"),
true,
1.7976931348623157E308,
@@ -1039,24 +1169,28 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
val relation0 = new JSONRelation(
Some(empty),
1.0,
+ false,
Some(StructType(StructField("a", IntegerType, true) :: Nil)),
None, None)(sqlContext)
val logicalRelation0 = LogicalRelation(relation0)
val relation1 = new JSONRelation(
Some(singleRow),
1.0,
+ false,
Some(StructType(StructField("a", IntegerType, true) :: Nil)),
None, None)(sqlContext)
val logicalRelation1 = LogicalRelation(relation1)
val relation2 = new JSONRelation(
Some(singleRow),
0.5,
+ false,
Some(StructType(StructField("a", IntegerType, true) :: Nil)),
None, None)(sqlContext)
val logicalRelation2 = LogicalRelation(relation2)
val relation3 = new JSONRelation(
Some(singleRow),
1.0,
+ false,
Some(StructType(StructField("b", IntegerType, true) :: Nil)),
None, None)(sqlContext)
val logicalRelation3 = LogicalRelation(relation3)