aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorJosh Rosen <joshrosen@databricks.com>2016-05-09 13:11:18 -0700
committerDavies Liu <davies.liu@gmail.com>2016-05-09 13:11:18 -0700
commitc3350cadb8369ad016f89135bbcbe126705c463c (patch)
tree0a18d5d85de0de13ac67a5c041b73d96ec58bd9a /sql
parent2adb11f6db591a7d8199e42dd23c7fb23ef5df3b (diff)
downloadspark-c3350cadb8369ad016f89135bbcbe126705c463c.tar.gz
spark-c3350cadb8369ad016f89135bbcbe126705c463c.tar.bz2
spark-c3350cadb8369ad016f89135bbcbe126705c463c.zip
[SPARK-14972] Improve performance of JSON schema inference's compatibleType method
This patch improves the performance of `InferSchema.compatibleType` and `inferField`. The net result of this patch is a 6x speedup in local benchmarks running against cached data with a massive nested schema. The key idea is to remove unnecessary sorting in `compatibleType`'s `StructType` merging code. This code takes two structs, merges the fields with matching names, and copies over the unique fields, producing a new schema which is the union of the two structs' schemas. Previously, this code performed a very inefficient `groupBy()` to match up fields with the same name, but this is unnecessary because `inferField` already sorts structs' fields by name: since both lists of fields are sorted, we can simply merge them in a single pass. This patch also speeds up the existing field sorting in `inferField`: the old sorting code allocated unnecessary intermediate collections, while the new code uses mutable collects and performs in-place sorting. I rewrote inefficient `equals()` implementations in `StructType` and `Metadata`, significantly reducing object allocations in those methods. Finally, I replaced a `treeAggregate` call with `fold`: I doubt that `treeAggregate` will benefit us very much because the schemas would have to be enormous to realize large savings in network traffic. Since most schemas are probably fairly small in serialized form, they should typically fit within a direct task result and therefore can be incrementally merged at the driver as individual tasks finish. This change eliminates an entire (short) scheduler stage. Author: Josh Rosen <joshrosen@databricks.com> Closes #12750 from JoshRosen/schema-inference-speedups.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala3
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala24
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala11
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala80
4 files changed, 94 insertions, 24 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 5323b79c57..8319ec0a82 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -95,7 +95,8 @@ object HiveTypeCoercion {
Some(t1)
// Promote numeric types to the highest of the two
- case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) =>
+ case (t1: NumericType, t2: NumericType)
+ if !t1.isInstanceOf[DecimalType] && !t2.isInstanceOf[DecimalType] =>
val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2)
Some(numericPrecedence(index))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala
index 66f123682e..1fb2e2404c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala
@@ -84,18 +84,20 @@ sealed class Metadata private[types] (private[types] val map: Map[String, Any])
override def equals(obj: Any): Boolean = {
obj match {
- case that: Metadata =>
- if (map.keySet == that.map.keySet) {
- map.keys.forall { k =>
- (map(k), that.map(k)) match {
- case (v0: Array[_], v1: Array[_]) =>
- v0.view == v1.view
- case (v0, v1) =>
- v0 == v1
- }
+ case that: Metadata if map.size == that.map.size =>
+ map.keysIterator.forall { key =>
+ that.map.get(key) match {
+ case Some(otherValue) =>
+ val ourValue = map.get(key).get
+ (ourValue, otherValue) match {
+ case (v0: Array[Long], v1: Array[Long]) => java.util.Arrays.equals(v0, v1)
+ case (v0: Array[Double], v1: Array[Double]) => java.util.Arrays.equals(v0, v1)
+ case (v0: Array[Boolean], v1: Array[Boolean]) => java.util.Arrays.equals(v0, v1)
+ case (v0: Array[AnyRef], v1: Array[AnyRef]) => java.util.Arrays.equals(v0, v1)
+ case (v0, v1) => v0 == v1
+ }
+ case None => false
}
- } else {
- false
}
case other =>
false
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index b06aa7bc52..fd2b524e22 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -103,6 +103,17 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap
private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap
+ override def equals(that: Any): Boolean = {
+ that match {
+ case StructType(otherFields) =>
+ java.util.Arrays.equals(
+ fields.asInstanceOf[Array[AnyRef]], otherFields.asInstanceOf[Array[AnyRef]])
+ case _ => false
+ }
+ }
+
+ override def hashCode(): Int = java.util.Arrays.hashCode(fields.asInstanceOf[Array[AnyRef]])
+
/**
* Creates a new [[StructType]] by adding a new field.
* {{{
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
index 8e8238a594..42c82625fa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.datasources.json
+import java.util.Comparator
+
import com.fasterxml.jackson.core._
import org.apache.spark.rdd.RDD
@@ -63,9 +65,7 @@ private[sql] object InferSchema {
None
}
}
- }.treeAggregate[DataType](
- StructType(Seq()))(
- compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord),
+ }.fold(StructType(Seq()))(
compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord))
canonicalizeType(rootType) match {
@@ -76,6 +76,23 @@ private[sql] object InferSchema {
}
}
+ private[this] val structFieldComparator = new Comparator[StructField] {
+ override def compare(o1: StructField, o2: StructField): Int = {
+ o1.name.compare(o2.name)
+ }
+ }
+
+ private def isSorted(arr: Array[StructField]): Boolean = {
+ var i: Int = 0
+ while (i < arr.length - 1) {
+ if (structFieldComparator.compare(arr(i), arr(i + 1)) > 0) {
+ return false
+ }
+ i += 1
+ }
+ true
+ }
+
/**
* Infer the type of a json document from the parser's token stream
*/
@@ -99,15 +116,17 @@ private[sql] object InferSchema {
case VALUE_STRING => StringType
case START_OBJECT =>
- val builder = Seq.newBuilder[StructField]
+ val builder = Array.newBuilder[StructField]
while (nextUntil(parser, END_OBJECT)) {
builder += StructField(
parser.getCurrentName,
inferField(parser, configOptions),
nullable = true)
}
-
- StructType(builder.result().sortBy(_.name))
+ val fields: Array[StructField] = builder.result()
+ // Note: other code relies on this sorting for correctness, so don't remove it!
+ java.util.Arrays.sort(fields, structFieldComparator)
+ StructType(fields)
case START_ARRAY =>
// If this JSON array is empty, we use NullType as a placeholder.
@@ -191,7 +210,11 @@ private[sql] object InferSchema {
if (!struct.fieldNames.contains(columnNameOfCorruptRecords)) {
// If this given struct does not have a column used for corrupt records,
// add this field.
- struct.add(columnNameOfCorruptRecords, StringType, nullable = true)
+ val newFields: Array[StructField] =
+ StructField(columnNameOfCorruptRecords, StringType, nullable = true) +: struct.fields
+ // Note: other code relies on this sorting for correctness, so don't remove it!
+ java.util.Arrays.sort(newFields, structFieldComparator)
+ StructType(newFields)
} else {
// Otherwise, just return this struct.
struct
@@ -223,6 +246,8 @@ private[sql] object InferSchema {
case (ty1, ty2) => compatibleType(ty1, ty2)
}
+ private[this] val emptyStructFieldArray = Array.empty[StructField]
+
/**
* Returns the most general data type for two given data types.
*/
@@ -246,12 +271,43 @@ private[sql] object InferSchema {
}
case (StructType(fields1), StructType(fields2)) =>
- val newFields = (fields1 ++ fields2).groupBy(field => field.name).map {
- case (name, fieldTypes) =>
- val dataType = fieldTypes.view.map(_.dataType).reduce(compatibleType)
- StructField(name, dataType, nullable = true)
+ // Both fields1 and fields2 should be sorted by name, since inferField performs sorting.
+ // Therefore, we can take advantage of the fact that we're merging sorted lists and skip
+ // building a hash map or performing additional sorting.
+ assert(isSorted(fields1), s"StructType's fields were not sorted: ${fields1.toSeq}")
+ assert(isSorted(fields2), s"StructType's fields were not sorted: ${fields2.toSeq}")
+
+ val newFields = new java.util.ArrayList[StructField]()
+
+ var f1Idx = 0
+ var f2Idx = 0
+
+ while (f1Idx < fields1.length && f2Idx < fields2.length) {
+ val f1Name = fields1(f1Idx).name
+ val f2Name = fields2(f2Idx).name
+ val comp = f1Name.compareTo(f2Name)
+ if (comp == 0) {
+ val dataType = compatibleType(fields1(f1Idx).dataType, fields2(f2Idx).dataType)
+ newFields.add(StructField(f1Name, dataType, nullable = true))
+ f1Idx += 1
+ f2Idx += 1
+ } else if (comp < 0) { // f1Name < f2Name
+ newFields.add(fields1(f1Idx))
+ f1Idx += 1
+ } else { // f1Name > f2Name
+ newFields.add(fields2(f2Idx))
+ f2Idx += 1
+ }
+ }
+ while (f1Idx < fields1.length) {
+ newFields.add(fields1(f1Idx))
+ f1Idx += 1
+ }
+ while (f2Idx < fields2.length) {
+ newFields.add(fields2(f2Idx))
+ f2Idx += 1
}
- StructType(newFields.toSeq.sortBy(_.name))
+ StructType(newFields.toArray(emptyStructFieldArray))
case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)