path: root/sql/core
author     Nathan Howell <nhowell@godaddy.com>    2015-05-06 22:56:53 -0700
committer  Yin Huai <yhuai@databricks.com>        2015-05-06 22:56:53 -0700
commit     2d6612cc8b98f767d73c4d15e4065bf3d6c12ea7 (patch)
tree       b8b410071f36da0a5aaa22cdcd7fc2cdbb66aa16 /sql/core
parent     9cfa9a516ed991de6c5900c7285b47380a396142 (diff)
[SPARK-5938] [SPARK-5443] [SQL] Improve JsonRDD performance
This patch comprises a few related pieces of work:

* Schema inference is performed directly on the JSON token stream
* `String => Row` conversion populates Spark SQL structures without intermediate types
* Projection pushdown is implemented via CatalystScan for DataFrame queries
* Support for the legacy parser is retained by setting `spark.sql.json.useJacksonStreamingAPI` to `false`

Performance improvements depend on the schema and queries being executed, but it should be faster across the board. Below are benchmarks using the last.fm Million Song dataset:

```
Command                                             | Baseline | Patched
----------------------------------------------------|----------|--------
import sqlContext.implicits._                       |          |
val df = sqlContext.jsonFile("/tmp/lastfm.json")    |    70.0s |   14.6s
df.count()                                          |    28.8s |    6.2s
df.rdd.count()                                      |    35.3s |   21.5s
df.where($"artist" === "Robert Hood").collect()     |    28.3s |   16.9s
```

To prepare this dataset for benchmarking, follow these steps:

```
# Fetch the datasets from http://labrosa.ee.columbia.edu/millionsong/lastfm
wget http://labrosa.ee.columbia.edu/millionsong/sites/default/files/lastfm/lastfm_test.zip \
     http://labrosa.ee.columbia.edu/millionsong/sites/default/files/lastfm/lastfm_train.zip

# Decompress and combine, pipe through `jq -c` to ensure there is one record per line
unzip -p lastfm_test.zip lastfm_train.zip | jq -c . > lastfm.json
```

Author: Nathan Howell <nhowell@godaddy.com>

Closes #5801 from NathanHowell/json-performance and squashes the following commits:

26fea31 [Nathan Howell] Recreate the baseRDD each for each scan operation
a7ebeb2 [Nathan Howell] Increase coverage of inserts into a JSONRelation
e06a1dd [Nathan Howell] Add comments to the `useJacksonStreamingAPI` config flag
6822712 [Nathan Howell] Split up JsonRDD2 into multiple objects
fa8234f [Nathan Howell] Wrap long lines
b31917b [Nathan Howell] Rename `useJsonRDD2` to `useJacksonStreamingAPI`
15c5d1b [Nathan Howell] JSONRelation's baseRDD need not be lazy
f8add6e [Nathan Howell] Add comments on lack of support for precision and scale DecimalTypes
fa0be47 [Nathan Howell] Remove unused default case in the field parser
80dba17 [Nathan Howell] Add comments regarding null handling and empty strings
842846d [Nathan Howell] Point the empty schema inference test at JsonRDD2
ab6ee87 [Nathan Howell] Add projection pushdown support to JsonRDD/JsonRDD2
f636c14 [Nathan Howell] Enable JsonRDD2 by default, add a flag to switch back to JsonRDD
0bbc445 [Nathan Howell] Improve JSON parsing and type inference performance
7ca70c1 [Nathan Howell] Eliminate arrow pattern, replace with pattern matches
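For context, here is a minimal sketch (not part of the patch) of how the two code paths can be exercised against the Spark 1.4-era `SQLContext` API. The app name, object name, and input path are hypothetical placeholders; the flag name and the `jsonFile` usage come from the patch itself:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object JsonParserComparison {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("json-parser-comparison").setMaster("local[*]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Default after this patch: the Jackson streaming path. Schema inference runs
    // directly on the token stream and rows are populated without intermediate types.
    val df = sqlContext.jsonFile("/tmp/lastfm.json") // hypothetical input path
    df.count()
    df.where($"artist" === "Robert Hood").collect()

    // Opt back into the legacy JsonRDD parser (scheduled for removal in Spark 1.5.0).
    sqlContext.setConf("spark.sql.json.useJacksonStreamingAPI", "false")
    val legacyDf = sqlContext.jsonFile("/tmp/lastfm.json")
    legacyDf.count()

    sc.stop()
  }
}
```

Note that the flag is read when the DataFrame is created (in `SQLContext.jsonRDD` and in the `JSONRelation` constructor), so it must be set before calling `jsonFile` or `jsonRDD`.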
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala             |   4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala               |   8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala            |  34
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala      | 171
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala     |  99
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala |  77
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala    | 215
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/json/JacksonUtils.scala     |  32
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala          |  50
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala        |  51
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala   |  55
11 files changed, 688 insertions, 108 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 9d2cd7aae3..79fbf50300 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -42,7 +42,7 @@ import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
import org.apache.spark.sql.jdbc.JDBCWriteDetails
-import org.apache.spark.sql.json.JsonRDD
+import org.apache.spark.sql.json.{JacksonGenerator, JsonRDD}
import org.apache.spark.sql.types._
import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect}
import org.apache.spark.util.Utils
@@ -1415,7 +1415,7 @@ class DataFrame private[sql](
new Iterator[String] {
override def hasNext: Boolean = iter.hasNext
override def next(): String = {
- JsonRDD.rowToJSON(rowSchema, gen)(iter.next())
+ JacksonGenerator(rowSchema, gen)(iter.next())
gen.flush()
val json = writer.toString
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 3ffc2091d6..bfaddd0f2c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -73,6 +73,8 @@ private[spark] object SQLConf {
val USE_SQL_SERIALIZER2 = "spark.sql.useSerializer2"
+ val USE_JACKSON_STREAMING_API = "spark.sql.json.useJacksonStreamingAPI"
+
object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
}
@@ -167,6 +169,12 @@ private[sql] class SQLConf extends Serializable {
private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2, "true").toBoolean
/**
+ * Selects between the new (true) and old (false) JSON handlers, to be removed in Spark 1.5.0
+ */
+ private[spark] def useJacksonStreamingAPI: Boolean =
+ getConf(USE_JACKSON_STREAMING_API, "true").toBoolean
+
+ /**
* Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to
* a broadcast value during the physical executions of join operations. Setting this to -1
* effectively disables auto conversion.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 7eabb93c1e..0563430a6f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -659,13 +659,17 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
@Experimental
def jsonRDD(json: RDD[String], schema: StructType): DataFrame = {
- val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord
- val appliedSchema =
- Option(schema).getOrElse(
- JsonRDD.nullTypeToStringType(
- JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord)))
- val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
- createDataFrame(rowRDD, appliedSchema, needsConversion = false)
+ if (conf.useJacksonStreamingAPI) {
+ baseRelationToDataFrame(new JSONRelation(() => json, None, 1.0, Some(schema))(this))
+ } else {
+ val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord
+ val appliedSchema =
+ Option(schema).getOrElse(
+ JsonRDD.nullTypeToStringType(
+ JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord)))
+ val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
+ createDataFrame(rowRDD, appliedSchema, needsConversion = false)
+ }
}
/**
@@ -689,12 +693,16 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
@Experimental
def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = {
- val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord
- val appliedSchema =
- JsonRDD.nullTypeToStringType(
- JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord))
- val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
- createDataFrame(rowRDD, appliedSchema, needsConversion = false)
+ if (conf.useJacksonStreamingAPI) {
+ baseRelationToDataFrame(new JSONRelation(() => json, None, samplingRatio, None)(this))
+ } else {
+ val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord
+ val appliedSchema =
+ JsonRDD.nullTypeToStringType(
+ JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord))
+ val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
+ createDataFrame(rowRDD, appliedSchema, needsConversion = false)
+ }
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
new file mode 100644
index 0000000000..9c58b8e4bb
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.json
+
+import com.fasterxml.jackson.core._
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
+import org.apache.spark.sql.json.JacksonUtils.nextUntil
+import org.apache.spark.sql.types._
+
+private[sql] object InferSchema {
+ /**
+ * Infer the type of a collection of json records in three stages:
+ * 1. Infer the type of each record
+ * 2. Merge types by choosing the lowest type necessary to cover equal keys
+ * 3. Replace any remaining null fields with string, the top type
+ */
+ def apply(
+ json: RDD[String],
+ samplingRatio: Double = 1.0,
+ columnNameOfCorruptRecords: String): StructType = {
+ require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0")
+ val schemaData = if (samplingRatio > 0.99) {
+ json
+ } else {
+ json.sample(withReplacement = false, samplingRatio, 1)
+ }
+
+ // perform schema inference on each row and merge afterwards
+ schemaData.mapPartitions { iter =>
+ val factory = new JsonFactory()
+ iter.map { row =>
+ try {
+ val parser = factory.createParser(row)
+ parser.nextToken()
+ inferField(parser)
+ } catch {
+ case _: JsonParseException =>
+ StructType(Seq(StructField(columnNameOfCorruptRecords, StringType)))
+ }
+ }
+ }.treeAggregate[DataType](StructType(Seq()))(compatibleRootType, compatibleRootType) match {
+ case st: StructType => nullTypeToStringType(st)
+ }
+ }
+
+ /**
+ * Infer the type of a json document from the parser's token stream
+ */
+ private def inferField(parser: JsonParser): DataType = {
+ import com.fasterxml.jackson.core.JsonToken._
+ parser.getCurrentToken match {
+ case null | VALUE_NULL => NullType
+
+ case FIELD_NAME =>
+ parser.nextToken()
+ inferField(parser)
+
+ case VALUE_STRING if parser.getTextLength < 1 =>
+ // Zero length strings and nulls have special handling to deal
+ // with JSON generators that do not distinguish between the two.
+ // To accurately infer types for empty strings that are really
+ // meant to represent nulls we assume that the two are isomorphic
+ // but will defer treating null fields as strings until all the
+ // record fields' types have been combined.
+ NullType
+
+ case VALUE_STRING => StringType
+ case START_OBJECT =>
+ val builder = Seq.newBuilder[StructField]
+ while (nextUntil(parser, END_OBJECT)) {
+ builder += StructField(parser.getCurrentName, inferField(parser), nullable = true)
+ }
+
+ StructType(builder.result().sortBy(_.name))
+
+ case START_ARRAY =>
+ // If this JSON array is empty, we use NullType as a placeholder.
+ // If this array is not empty in other JSON objects, we can resolve
+ // the type as we pass through all JSON objects.
+ var elementType: DataType = NullType
+ while (nextUntil(parser, END_ARRAY)) {
+ elementType = compatibleType(elementType, inferField(parser))
+ }
+
+ ArrayType(elementType)
+
+ case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT =>
+ import JsonParser.NumberType._
+ parser.getNumberType match {
+ // For Integer values, use LongType by default.
+ case INT | LONG => LongType
+ // Since we do not have a data type backed by BigInteger,
+ // when we see a Java BigInteger, we use DecimalType.
+ case BIG_INTEGER | BIG_DECIMAL => DecimalType.Unlimited
+ case FLOAT | DOUBLE => DoubleType
+ }
+
+ case VALUE_TRUE | VALUE_FALSE => BooleanType
+ }
+ }
+
+ private def nullTypeToStringType(struct: StructType): StructType = {
+ val fields = struct.fields.map {
+ case StructField(fieldName, dataType, nullable, _) =>
+ val newType = dataType match {
+ case NullType => StringType
+ case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull)
+ case ArrayType(struct: StructType, containsNull) =>
+ ArrayType(nullTypeToStringType(struct), containsNull)
+ case struct: StructType =>nullTypeToStringType(struct)
+ case other: DataType => other
+ }
+
+ StructField(fieldName, newType, nullable)
+ }
+
+ StructType(fields)
+ }
+
+ /**
+ * Remove top-level ArrayType wrappers and merge the remaining schemas
+ */
+ private def compatibleRootType: (DataType, DataType) => DataType = {
+ case (ArrayType(ty1, _), ty2) => compatibleRootType(ty1, ty2)
+ case (ty1, ArrayType(ty2, _)) => compatibleRootType(ty1, ty2)
+ case (ty1, ty2) => compatibleType(ty1, ty2)
+ }
+
+ /**
+ * Returns the most general data type for two given data types.
+ */
+ private[json] def compatibleType(t1: DataType, t2: DataType): DataType = {
+ HiveTypeCoercion.findTightestCommonType(t1, t2).getOrElse {
+ // t1 or t2 is a StructType, ArrayType, or an unexpected type.
+ (t1, t2) match {
+ case (other: DataType, NullType) => other
+ case (NullType, other: DataType) => other
+ case (StructType(fields1), StructType(fields2)) =>
+ val newFields = (fields1 ++ fields2).groupBy(field => field.name).map {
+ case (name, fieldTypes) =>
+ val dataType = fieldTypes.view.map(_.dataType).reduce(compatibleType)
+ StructField(name, dataType, nullable = true)
+ }
+ StructType(newFields.toSeq.sortBy(_.name))
+
+ case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
+ ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
+
+ // strings and every string is a Json object.
+ case (_, _) => StringType
+ }
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
index e3352d0278..c772cd1f53 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
@@ -22,14 +22,16 @@ import java.io.IOException
import org.apache.hadoop.fs.Path
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.expressions.Row
+import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute, Row}
import org.apache.spark.sql.sources._
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
private[sql] class DefaultSource
- extends RelationProvider with SchemaRelationProvider with CreatableRelationProvider {
+ extends RelationProvider
+ with SchemaRelationProvider
+ with CreatableRelationProvider {
private def checkPath(parameters: Map[String, String]): String = {
parameters.getOrElse("path", sys.error("'path' must be specified for json data."))
@@ -42,7 +44,7 @@ private[sql] class DefaultSource
val path = checkPath(parameters)
val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
- JSONRelation(path, samplingRatio, None)(sqlContext)
+ new JSONRelation(path, samplingRatio, None, sqlContext)
}
/** Returns a new base relation with the given schema and parameters. */
@@ -53,7 +55,7 @@ private[sql] class DefaultSource
val path = checkPath(parameters)
val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
- JSONRelation(path, samplingRatio, Some(schema))(sqlContext)
+ new JSONRelation(path, samplingRatio, Some(schema), sqlContext)
}
override def createRelation(
@@ -101,32 +103,87 @@ private[sql] class DefaultSource
}
}
-private[sql] case class JSONRelation(
- path: String,
- samplingRatio: Double,
+private[sql] class JSONRelation(
+ // baseRDD is not immutable with respect to INSERT OVERWRITE
+ // and so it must be recreated at least as often as the
+ // underlying inputs are modified. To be safe, a function is
+ // used instead of a regular RDD value to ensure a fresh RDD is
+ // recreated for each and every operation.
+ baseRDD: () => RDD[String],
+ val path: Option[String],
+ val samplingRatio: Double,
userSpecifiedSchema: Option[StructType])(
@transient val sqlContext: SQLContext)
extends BaseRelation
with TableScan
- with InsertableRelation {
-
- // TODO: Support partitioned JSON relation.
- private def baseRDD = sqlContext.sparkContext.textFile(path)
+ with InsertableRelation
+ with CatalystScan {
+
+ def this(
+ path: String,
+ samplingRatio: Double,
+ userSpecifiedSchema: Option[StructType],
+ sqlContext: SQLContext) =
+ this(
+ () => sqlContext.sparkContext.textFile(path),
+ Some(path),
+ samplingRatio,
+ userSpecifiedSchema)(sqlContext)
+
+ private val useJacksonStreamingAPI: Boolean = sqlContext.conf.useJacksonStreamingAPI
override val needConversion: Boolean = false
- override val schema = userSpecifiedSchema.getOrElse(
- JsonRDD.nullTypeToStringType(
- JsonRDD.inferSchema(
- baseRDD,
+ override lazy val schema = userSpecifiedSchema.getOrElse {
+ if (useJacksonStreamingAPI) {
+ InferSchema(
+ baseRDD(),
samplingRatio,
- sqlContext.conf.columnNameOfCorruptRecord)))
+ sqlContext.conf.columnNameOfCorruptRecord)
+ } else {
+ JsonRDD.nullTypeToStringType(
+ JsonRDD.inferSchema(
+ baseRDD(),
+ samplingRatio,
+ sqlContext.conf.columnNameOfCorruptRecord))
+ }
+ }
- override def buildScan(): RDD[Row] =
- JsonRDD.jsonStringToRow(baseRDD, schema, sqlContext.conf.columnNameOfCorruptRecord)
+ override def buildScan(): RDD[Row] = {
+ if (useJacksonStreamingAPI) {
+ JacksonParser(
+ baseRDD(),
+ schema,
+ sqlContext.conf.columnNameOfCorruptRecord)
+ } else {
+ JsonRDD.jsonStringToRow(
+ baseRDD(),
+ schema,
+ sqlContext.conf.columnNameOfCorruptRecord)
+ }
+ }
+
+ override def buildScan(requiredColumns: Seq[Attribute], filters: Seq[Expression]): RDD[Row] = {
+ if (useJacksonStreamingAPI) {
+ JacksonParser(
+ baseRDD(),
+ StructType.fromAttributes(requiredColumns),
+ sqlContext.conf.columnNameOfCorruptRecord)
+ } else {
+ JsonRDD.jsonStringToRow(
+ baseRDD(),
+ StructType.fromAttributes(requiredColumns),
+ sqlContext.conf.columnNameOfCorruptRecord)
+ }
+ }
override def insert(data: DataFrame, overwrite: Boolean): Unit = {
- val filesystemPath = new Path(path)
+ val filesystemPath = path match {
+ case Some(p) => new Path(p)
+ case None =>
+ throw new IOException(s"Cannot INSERT into table with no path defined")
+ }
+
val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
if (overwrite) {
@@ -147,7 +204,7 @@ private[sql] case class JSONRelation(
}
}
// Write the data.
- data.toJSON.saveAsTextFile(path)
+ data.toJSON.saveAsTextFile(filesystemPath.toString)
// Right now, we assume that the schema is not changed. We will not update the schema.
// schema = data.schema
} else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala
new file mode 100644
index 0000000000..80bf74aa02
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.json
+
+import scala.collection.Map
+
+import com.fasterxml.jackson.core._
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types._
+
+private[sql] object JacksonGenerator {
+ /** Transforms a single Row to JSON using Jackson
+ *
+ * @param rowSchema the schema object used for conversion
+ * @param gen a JsonGenerator object
+ * @param row The row to convert
+ */
+ def apply(rowSchema: StructType, gen: JsonGenerator)(row: Row): Unit = {
+ def valWriter: (DataType, Any) => Unit = {
+ case (_, null) | (NullType, _) => gen.writeNull()
+ case (StringType, v: String) => gen.writeString(v)
+ case (TimestampType, v: java.sql.Timestamp) => gen.writeString(v.toString)
+ case (IntegerType, v: Int) => gen.writeNumber(v)
+ case (ShortType, v: Short) => gen.writeNumber(v)
+ case (FloatType, v: Float) => gen.writeNumber(v)
+ case (DoubleType, v: Double) => gen.writeNumber(v)
+ case (LongType, v: Long) => gen.writeNumber(v)
+ case (DecimalType(), v: java.math.BigDecimal) => gen.writeNumber(v)
+ case (ByteType, v: Byte) => gen.writeNumber(v.toInt)
+ case (BinaryType, v: Array[Byte]) => gen.writeBinary(v)
+ case (BooleanType, v: Boolean) => gen.writeBoolean(v)
+ case (DateType, v) => gen.writeString(v.toString)
+ case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, udt.serialize(v))
+
+ case (ArrayType(ty, _), v: Seq[_] ) =>
+ gen.writeStartArray()
+ v.foreach(valWriter(ty,_))
+ gen.writeEndArray()
+
+ case (MapType(kv,vv, _), v: Map[_,_]) =>
+ gen.writeStartObject()
+ v.foreach { p =>
+ gen.writeFieldName(p._1.toString)
+ valWriter(vv,p._2)
+ }
+ gen.writeEndObject()
+
+ case (StructType(ty), v: Row) =>
+ gen.writeStartObject()
+ ty.zip(v.toSeq).foreach {
+ case (_, null) =>
+ case (field, v) =>
+ gen.writeFieldName(field.name)
+ valWriter(field.dataType, v)
+ }
+ gen.writeEndObject()
+ }
+
+ valWriter(rowSchema, row)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
new file mode 100644
index 0000000000..a8e69ae611
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.json
+
+import java.io.ByteArrayOutputStream
+import java.sql.Timestamp
+
+import scala.collection.Map
+
+import com.fasterxml.jackson.core._
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.json.JacksonUtils.nextUntil
+import org.apache.spark.sql.types._
+
+private[sql] object JacksonParser {
+ def apply(
+ json: RDD[String],
+ schema: StructType,
+ columnNameOfCorruptRecords: String): RDD[Row] = {
+ parseJson(json, schema, columnNameOfCorruptRecords)
+ }
+
+ /**
+ * Parse the current token (and related children) according to a desired schema
+ */
+ private[sql] def convertField(
+ factory: JsonFactory,
+ parser: JsonParser,
+ schema: DataType): Any = {
+ import com.fasterxml.jackson.core.JsonToken._
+ (parser.getCurrentToken, schema) match {
+ case (null | VALUE_NULL, _) =>
+ null
+
+ case (FIELD_NAME, _) =>
+ parser.nextToken()
+ convertField(factory, parser, schema)
+
+ case (VALUE_STRING, StringType) =>
+ UTF8String(parser.getText)
+
+ case (VALUE_STRING, _) if parser.getTextLength < 1 =>
+ // guard the non string type
+ null
+
+ case (VALUE_STRING, DateType) =>
+ DateUtils.millisToDays(DateUtils.stringToTime(parser.getText).getTime)
+
+ case (VALUE_STRING, TimestampType) =>
+ new Timestamp(DateUtils.stringToTime(parser.getText).getTime)
+
+ case (VALUE_NUMBER_INT, TimestampType) =>
+ new Timestamp(parser.getLongValue)
+
+ case (_, StringType) =>
+ val writer = new ByteArrayOutputStream()
+ val generator = factory.createGenerator(writer, JsonEncoding.UTF8)
+ generator.copyCurrentStructure(parser)
+ generator.close()
+ UTF8String(writer.toByteArray)
+
+ case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, FloatType) =>
+ parser.getFloatValue
+
+ case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, DoubleType) =>
+ parser.getDoubleValue
+
+ case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, DecimalType()) =>
+ // TODO: add fixed precision and scale handling
+ Decimal(parser.getDecimalValue)
+
+ case (VALUE_NUMBER_INT, ByteType) =>
+ parser.getByteValue
+
+ case (VALUE_NUMBER_INT, ShortType) =>
+ parser.getShortValue
+
+ case (VALUE_NUMBER_INT, IntegerType) =>
+ parser.getIntValue
+
+ case (VALUE_NUMBER_INT, LongType) =>
+ parser.getLongValue
+
+ case (VALUE_TRUE, BooleanType) =>
+ true
+
+ case (VALUE_FALSE, BooleanType) =>
+ false
+
+ case (START_OBJECT, st: StructType) =>
+ convertObject(factory, parser, st)
+
+ case (START_ARRAY, ArrayType(st, _)) =>
+ convertList(factory, parser, st)
+
+ case (START_OBJECT, ArrayType(st, _)) =>
+ // the business end of SPARK-3308:
+ // when an object is found but an array is requested just wrap it in a list
+ convertField(factory, parser, st) :: Nil
+
+ case (START_OBJECT, MapType(StringType, kt, _)) =>
+ convertMap(factory, parser, kt)
+
+ case (_, udt: UserDefinedType[_]) =>
+ udt.deserialize(convertField(factory, parser, udt.sqlType))
+ }
+ }
+
+ /**
+ * Parse an object from the token stream into a new Row representing the schema.
+ *
+ * Fields in the json that are not defined in the requested schema will be dropped.
+ */
+ private def convertObject(factory: JsonFactory, parser: JsonParser, schema: StructType): Row = {
+ val row = new GenericMutableRow(schema.length)
+ while (nextUntil(parser, JsonToken.END_OBJECT)) {
+ schema.getFieldIndex(parser.getCurrentName) match {
+ case Some(index) =>
+ row.update(index, convertField(factory, parser, schema(index).dataType))
+
+ case None =>
+ parser.skipChildren()
+ }
+ }
+
+ row
+ }
+
+ /**
+ * Parse an object as a Map, preserving all fields
+ */
+ private def convertMap(
+ factory: JsonFactory,
+ parser: JsonParser,
+ valueType: DataType): Map[String, Any] = {
+ val builder = Map.newBuilder[String, Any]
+ while (nextUntil(parser, JsonToken.END_OBJECT)) {
+ builder += parser.getCurrentName -> convertField(factory, parser, valueType)
+ }
+
+ builder.result()
+ }
+
+ private def convertList(
+ factory: JsonFactory,
+ parser: JsonParser,
+ schema: DataType): Seq[Any] = {
+ val builder = Seq.newBuilder[Any]
+ while (nextUntil(parser, JsonToken.END_ARRAY)) {
+ builder += convertField(factory, parser, schema)
+ }
+
+ builder.result()
+ }
+
+ private def parseJson(
+ json: RDD[String],
+ schema: StructType,
+ columnNameOfCorruptRecords: String): RDD[Row] = {
+
+ def failedRecord(record: String): Seq[Row] = {
+ // create a row even if no corrupt record column is present
+ val row = new GenericMutableRow(schema.length)
+ for (corruptIndex <- schema.getFieldIndex(columnNameOfCorruptRecords)) {
+ require(schema(corruptIndex).dataType == StringType)
+ row.update(corruptIndex, record)
+ }
+
+ Seq(row)
+ }
+
+ json.mapPartitions { iter =>
+ val factory = new JsonFactory()
+
+ iter.flatMap { record =>
+ try {
+ val parser = factory.createParser(record)
+ parser.nextToken()
+
+ // to support both object and arrays (see SPARK-3308) we'll start
+ // by converting the StructType schema to an ArrayType and let
+ // convertField wrap an object into a single value array when necessary.
+ convertField(factory, parser, ArrayType(schema)) match {
+ case null => failedRecord(record)
+ case list: Seq[Row @unchecked] => list
+ case _ =>
+ sys.error(
+ s"Failed to parse record $record. Please make sure that each line of the file " +
+ "(or each string in the RDD) is a valid JSON object or an array of JSON objects.")
+ }
+ } catch {
+ case _: JsonProcessingException =>
+ failedRecord(record)
+ }
+ }
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonUtils.scala
new file mode 100644
index 0000000000..fde96852ce
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonUtils.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.json
+
+import com.fasterxml.jackson.core.{JsonParser, JsonToken}
+
+private object JacksonUtils {
+ /**
+ * Advance the parser until a null or a specific token is found
+ */
+ def nextUntil(parser: JsonParser, stopOn: JsonToken): Boolean = {
+ parser.nextToken() match {
+ case null => false
+ case x => x != stopOn
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index 6e94e7056e..f62973d5fc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -440,54 +440,4 @@ private[sql] object JsonRDD extends Logging {
row
}
-
- /** Transforms a single Row to JSON using Jackson
- *
- * @param rowSchema the schema object used for conversion
- * @param gen a JsonGenerator object
- * @param row The row to convert
- */
- private[sql] def rowToJSON(rowSchema: StructType, gen: JsonGenerator)(row: Row) = {
- def valWriter: (DataType, Any) => Unit = {
- case (_, null) | (NullType, _) => gen.writeNull()
- case (StringType, v: String) => gen.writeString(v)
- case (TimestampType, v: java.sql.Timestamp) => gen.writeString(v.toString)
- case (IntegerType, v: Int) => gen.writeNumber(v)
- case (ShortType, v: Short) => gen.writeNumber(v)
- case (FloatType, v: Float) => gen.writeNumber(v)
- case (DoubleType, v: Double) => gen.writeNumber(v)
- case (LongType, v: Long) => gen.writeNumber(v)
- case (DecimalType(), v: java.math.BigDecimal) => gen.writeNumber(v)
- case (ByteType, v: Byte) => gen.writeNumber(v.toInt)
- case (BinaryType, v: Array[Byte]) => gen.writeBinary(v)
- case (BooleanType, v: Boolean) => gen.writeBoolean(v)
- case (DateType, v) => gen.writeString(v.toString)
- case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, v)
-
- case (ArrayType(ty, _), v: Seq[_] ) =>
- gen.writeStartArray()
- v.foreach(valWriter(ty,_))
- gen.writeEndArray()
-
- case (MapType(kv,vv, _), v: Map[_,_]) =>
- gen.writeStartObject()
- v.foreach { p =>
- gen.writeFieldName(p._1.toString)
- valWriter(vv,p._2)
- }
- gen.writeEndObject()
-
- case (StructType(ty), v: Row) =>
- gen.writeStartObject()
- ty.zip(v.toSeq).foreach {
- case (_, null) =>
- case (field, v) =>
- gen.writeFieldName(field.name)
- valWriter(field.dataType, v)
- }
- gen.writeEndObject()
- }
-
- valWriter(rowSchema, row)
- }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index fd0e2746dc..263fafba93 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -17,13 +17,15 @@
package org.apache.spark.sql.json
+import java.io.StringWriter
import java.sql.{Date, Timestamp}
+import com.fasterxml.jackson.core.JsonFactory
import org.scalactic.Tolerance._
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.json.JsonRDD.{compatibleType, enforceCorrectType}
+import org.apache.spark.sql.json.InferSchema.compatibleType
import org.apache.spark.sql.sources.LogicalRelation
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
@@ -46,6 +48,18 @@ class JsonSuite extends QueryTest {
s"${expected}(${expected.getClass}).")
}
+ val factory = new JsonFactory()
+ def enforceCorrectType(value: Any, dataType: DataType): Any = {
+ val writer = new StringWriter()
+ val generator = factory.createGenerator(writer)
+ generator.writeObject(value)
+ generator.flush()
+
+ val parser = factory.createParser(writer.toString)
+ parser.nextToken()
+ JacksonParser.convertField(factory, parser, dataType)
+ }
+
val intNumber: Int = 2147483647
checkTypePromotion(intNumber, enforceCorrectType(intNumber, IntegerType))
checkTypePromotion(intNumber.toLong, enforceCorrectType(intNumber, LongType))
@@ -439,7 +453,7 @@ class JsonSuite extends QueryTest {
val jsonDF = jsonRDD(primitiveFieldValueTypeConflict)
jsonDF.registerTempTable("jsonTable")
- // Right now, the analyzer does not promote strings in a boolean expreesion.
+ // Right now, the analyzer does not promote strings in a boolean expression.
// Number and Boolean conflict: resolve the type as boolean in this query.
checkAnswer(
sql("select num_bool from jsonTable where NOT num_bool"),
@@ -508,7 +522,7 @@ class JsonSuite extends QueryTest {
Row(Seq(), "11", "[1,2,3]", Row(null), "[]") ::
Row(null, """{"field":false}""", null, null, "{}") ::
Row(Seq(4, 5, 6), null, "str", Row(null), "[7,8,9]") ::
- Row(Seq(7), "{}","[str1,str2,33]", Row("str"), """{"field":true}""") :: Nil
+ Row(Seq(7), "{}","""["str1","str2",33]""", Row("str"), """{"field":true}""") :: Nil
)
}
@@ -566,19 +580,19 @@ class JsonSuite extends QueryTest {
val analyzed = jsonDF.queryExecution.analyzed
assert(
analyzed.isInstanceOf[LogicalRelation],
- "The DataFrame returned by jsonFile should be based on JSONRelation.")
+ "The DataFrame returned by jsonFile should be based on LogicalRelation.")
val relation = analyzed.asInstanceOf[LogicalRelation].relation
assert(
relation.isInstanceOf[JSONRelation],
"The DataFrame returned by jsonFile should be based on JSONRelation.")
- assert(relation.asInstanceOf[JSONRelation].path === path)
+ assert(relation.asInstanceOf[JSONRelation].path === Some(path))
assert(relation.asInstanceOf[JSONRelation].samplingRatio === (0.49 +- 0.001))
val schema = StructType(StructField("a", LongType, true) :: Nil)
val logicalRelation =
jsonFile(path, schema).queryExecution.analyzed.asInstanceOf[LogicalRelation]
val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation]
- assert(relationWithSchema.path === path)
+ assert(relationWithSchema.path === Some(path))
assert(relationWithSchema.schema === schema)
assert(relationWithSchema.samplingRatio > 0.99)
}
@@ -1020,15 +1034,24 @@ class JsonSuite extends QueryTest {
}
test("JSONRelation equality test") {
- val relation1 =
- JSONRelation("path", 1.0, Some(StructType(StructField("a", IntegerType, true) :: Nil)))(null)
+ val context = org.apache.spark.sql.test.TestSQLContext
+ val relation1 = new JSONRelation(
+ "path",
+ 1.0,
+ Some(StructType(StructField("a", IntegerType, true) :: Nil)),
+ context)
val logicalRelation1 = LogicalRelation(relation1)
- val relation2 =
- JSONRelation("path", 0.5, Some(StructType(StructField("a", IntegerType, true) :: Nil)))(
- org.apache.spark.sql.test.TestSQLContext)
+ val relation2 = new JSONRelation(
+ "path",
+ 0.5,
+ Some(StructType(StructField("a", IntegerType, true) :: Nil)),
+ context)
val logicalRelation2 = LogicalRelation(relation2)
- val relation3 =
- JSONRelation("path", 1.0, Some(StructType(StructField("b", StringType, true) :: Nil)))(null)
+ val relation3 = new JSONRelation(
+ "path",
+ 1.0,
+ Some(StructType(StructField("b", StringType, true) :: Nil)),
+ context)
val logicalRelation3 = LogicalRelation(relation3)
assert(relation1 === relation2)
@@ -1046,7 +1069,7 @@ class JsonSuite extends QueryTest {
test("SPARK-6245 JsonRDD.inferSchema on empty RDD") {
// This is really a test that it doesn't throw an exception
- val emptySchema = JsonRDD.inferSchema(empty, 1.0, "")
+ val emptySchema = InferSchema(empty, 1.0, "")
assert(StructType(Seq()) === emptySchema)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
index 80efe9728f..50629ea4dc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
@@ -21,7 +21,7 @@ import java.io.File
import org.scalatest.BeforeAndAfterAll
-import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.{SaveMode, AnalysisException, Row}
import org.apache.spark.util.Utils
class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
@@ -100,23 +100,48 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
test("INSERT OVERWRITE a JSONRelation multiple times") {
sql(
s"""
- |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt
- """.stripMargin)
+ |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt
+ """.stripMargin)
+ checkAnswer(
+ sql("SELECT a, b FROM jsonTable"),
+ (1 to 10).map(i => Row(i, s"str$i"))
+ )
+ // Writing the table to less part files.
+ val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 5)
+ jsonRDD(rdd1).registerTempTable("jt1")
sql(
s"""
- |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt
- """.stripMargin)
+ |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt1
+ """.stripMargin)
+ checkAnswer(
+ sql("SELECT a, b FROM jsonTable"),
+ (1 to 10).map(i => Row(i, s"str$i"))
+ )
+ // Writing the table to more part files.
+ val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 10)
+ jsonRDD(rdd2).registerTempTable("jt2")
sql(
s"""
- |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt
- """.stripMargin)
-
+ |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt2
+ """.stripMargin)
checkAnswer(
sql("SELECT a, b FROM jsonTable"),
(1 to 10).map(i => Row(i, s"str$i"))
)
+
+ sql(
+ s"""
+ |INSERT OVERWRITE TABLE jsonTable SELECT a * 10, b FROM jt1
+ """.stripMargin)
+ checkAnswer(
+ sql("SELECT a, b FROM jsonTable"),
+ (1 to 10).map(i => Row(i * 10, s"str$i"))
+ )
+
+ dropTempTable("jt1")
+ dropTempTable("jt2")
}
test("INSERT INTO not supported for JSONRelation for now") {
@@ -128,6 +153,20 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
}
}
+ test("save directly to the path of a JSON table") {
+ table("jt").selectExpr("a * 5 as a", "b").save(path.toString, "json", SaveMode.Overwrite)
+ checkAnswer(
+ sql("SELECT a, b FROM jsonTable"),
+ (1 to 10).map(i => Row(i * 5, s"str$i"))
+ )
+
+ table("jt").save(path.toString, "json", SaveMode.Overwrite)
+ checkAnswer(
+ sql("SELECT a, b FROM jsonTable"),
+ (1 to 10).map(i => Row(i, s"str$i"))
+ )
+ }
+
test("it is not allowed to write to a table while querying it.") {
val message = intercept[AnalysisException] {
sql(