-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala   3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala                       24
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala                     11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala   80
4 files changed, 94 insertions(+), 24 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 5323b79c57..8319ec0a82 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -95,7 +95,8 @@ object HiveTypeCoercion {
Some(t1)
// Promote numeric types to the highest of the two
- case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) =>
+ case (t1: NumericType, t2: NumericType)
+ if !t1.isInstanceOf[DecimalType] && !t2.isInstanceOf[DecimalType] =>
val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2)
Some(numericPrecedence(index))
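
The rewritten guard pattern-matches on `NumericType` (excluding `DecimalType`) instead of allocating a temporary `Seq` and scanning `numericPrecedence` twice with `forall` on every call. For reference, a minimal standalone sketch of the precedence lookup the body performs, using hypothetical stand-in types rather than Spark's:

```scala
// Standalone sketch of precedence-based promotion; ByteT/IntT/DoubleT are
// hypothetical stand-ins for Spark's numeric types.
sealed trait NumericT
case object ByteT extends NumericT
case object IntT extends NumericT
case object DoubleT extends NumericT

// Types appear in ascending order of precedence.
val numericPrecedence: IndexedSeq[NumericT] = IndexedSeq(ByteT, IntT, DoubleT)

def tightestCommonType(t1: NumericT, t2: NumericT): NumericT = {
  // Promote to whichever of the two types appears later in the list.
  val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2)
  numericPrecedence(index)
}

assert(tightestCommonType(ByteT, DoubleT) == DoubleT)
```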
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala
index 66f123682e..1fb2e2404c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala
@@ -84,18 +84,20 @@ sealed class Metadata private[types] (private[types] val map: Map[String, Any])
override def equals(obj: Any): Boolean = {
obj match {
- case that: Metadata =>
- if (map.keySet == that.map.keySet) {
- map.keys.forall { k =>
- (map(k), that.map(k)) match {
- case (v0: Array[_], v1: Array[_]) =>
- v0.view == v1.view
- case (v0, v1) =>
- v0 == v1
- }
+ case that: Metadata if map.size == that.map.size =>
+ map.keysIterator.forall { key =>
+ that.map.get(key) match {
+ case Some(otherValue) =>
+ val ourValue = map.get(key).get
+ (ourValue, otherValue) match {
+ case (v0: Array[Long], v1: Array[Long]) => java.util.Arrays.equals(v0, v1)
+ case (v0: Array[Double], v1: Array[Double]) => java.util.Arrays.equals(v0, v1)
+ case (v0: Array[Boolean], v1: Array[Boolean]) => java.util.Arrays.equals(v0, v1)
+ case (v0: Array[AnyRef], v1: Array[AnyRef]) => java.util.Arrays.equals(v0, v1)
+ case (v0, v1) => v0 == v1
+ }
+ case None => false
}
- } else {
- false
}
case other =>
false
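
The old branch compared `v0.view == v1.view`, which goes through generic `Seq` equality and boxes primitive elements; the replacement dispatches on the array's element type so the primitive overloads of `java.util.Arrays.equals` can compare without boxing. A quick standalone illustration of why plain `==` is not enough on arrays:

```scala
// Scala's == on Array is reference equality, so structurally equal arrays
// compare unequal unless checked element-wise.
val a = Array(1L, 2L, 3L)
val b = Array(1L, 2L, 3L)

assert(!(a == b))                      // reference comparison: false
assert(java.util.Arrays.equals(a, b))  // primitive overload: element-wise, no boxing
```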
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index b06aa7bc52..fd2b524e22 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -103,6 +103,17 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap
private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap
+ override def equals(that: Any): Boolean = {
+ that match {
+ case StructType(otherFields) =>
+ java.util.Arrays.equals(
+ fields.asInstanceOf[Array[AnyRef]], otherFields.asInstanceOf[Array[AnyRef]])
+ case _ => false
+ }
+ }
+
+ override def hashCode(): Int = java.util.Arrays.hashCode(fields.asInstanceOf[Array[AnyRef]])
+
/**
* Creates a new [[StructType]] by adding a new field.
* {{{
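
`StructType` is a case class wrapping an `Array[StructField]`, and the compiler-generated `equals`/`hashCode` compare that array by reference, so two structurally identical schemas would otherwise be unequal and hash differently. A standalone demonstration of the pitfall, with a hypothetical `Wrapper` case class:

```scala
// Case classes holding arrays inherit Array's reference-based equals/hashCode.
case class Wrapper(xs: Array[Int])

val w1 = Wrapper(Array(1, 2))
val w2 = Wrapper(Array(1, 2))

assert(w1 != w2)                               // compares the arrays by reference
assert(java.util.Arrays.equals(w1.xs, w2.xs))  // element-wise comparison succeeds
assert(java.util.Arrays.hashCode(w1.xs) ==     // content-based hash, matching the
       java.util.Arrays.hashCode(w2.xs))       // overrides added above
```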
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
index 8e8238a594..42c82625fa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.datasources.json
+import java.util.Comparator
+
import com.fasterxml.jackson.core._
import org.apache.spark.rdd.RDD
@@ -63,9 +65,7 @@ private[sql] object InferSchema {
None
}
}
- }.treeAggregate[DataType](
- StructType(Seq()))(
- compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord),
+ }.fold(StructType(Seq()))(
compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord))
canonicalizeType(rootType) match {
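
Since the merge function is the same for the per-partition and cross-partition steps, `RDD.fold` (one associative, commutative op plus a zero value) expresses this more directly than `treeAggregate`, which takes separate seqOp/combOp functions and builds a multi-level reduction tree. A hypothetical minimal sketch of the same aggregation shape:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.sql.types.{DataType, StructType}

// Hypothetical sketch: fold per-record schemas into one root schema using a
// single merge function (a stand-in for compatibleRootType).
def inferRoot(
    sc: SparkContext,
    perRecordTypes: Seq[DataType],
    merge: (DataType, DataType) => DataType): DataType = {
  // fold applies merge within each partition, then across partition results,
  // starting from the neutral empty struct.
  sc.parallelize(perRecordTypes).fold(StructType(Seq()))(merge)
}
```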
@@ -76,6 +76,23 @@ private[sql] object InferSchema {
}
}
+ private[this] val structFieldComparator = new Comparator[StructField] {
+ override def compare(o1: StructField, o2: StructField): Int = {
+ o1.name.compare(o2.name)
+ }
+ }
+
+ private def isSorted(arr: Array[StructField]): Boolean = {
+ var i: Int = 0
+ while (i < arr.length - 1) {
+ if (structFieldComparator.compare(arr(i), arr(i + 1)) > 0) {
+ return false
+ }
+ i += 1
+ }
+ true
+ }
+
/**
* Infer the type of a json document from the parser's token stream
*/
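
`structFieldComparator` plus the hand-rolled `isSorted` loop give an allocation-free O(n) sortedness check, used below only in assertions. A standalone sketch of the same `java.util.Arrays.sort`-with-`Comparator` pattern:

```scala
import java.util.Comparator

// Hypothetical Field class standing in for StructField.
case class Field(name: String)

val byName = new Comparator[Field] {
  override def compare(a: Field, b: Field): Int = a.name.compareTo(b.name)
}

val fields = Array(Field("b"), Field("a"), Field("c"))
// Sorts in place, avoiding the extra copy that Scala's sortBy would make.
java.util.Arrays.sort(fields, byName)
assert(fields.map(_.name).sameElements(Array("a", "b", "c")))
```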
@@ -99,15 +116,17 @@ private[sql] object InferSchema {
case VALUE_STRING => StringType
case START_OBJECT =>
- val builder = Seq.newBuilder[StructField]
+ val builder = Array.newBuilder[StructField]
while (nextUntil(parser, END_OBJECT)) {
builder += StructField(
parser.getCurrentName,
inferField(parser, configOptions),
nullable = true)
}
-
- StructType(builder.result().sortBy(_.name))
+ val fields: Array[StructField] = builder.result()
+ // Note: other code relies on this sorting for correctness, so don't remove it!
+ java.util.Arrays.sort(fields, structFieldComparator)
+ StructType(fields)
case START_ARRAY =>
// If this JSON array is empty, we use NullType as a placeholder.
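
For context on the streaming pattern here: Jackson's pull parser walks the document token by token, and the loop above collects one `StructField` per object key before sorting, so every inferred `StructType` comes out sorted by name. A rough standalone sketch of such a token loop (hand-rolled, not Spark's `nextUntil` helper):

```scala
import com.fasterxml.jackson.core.{JsonFactory, JsonToken}

// Rough sketch: stream over an object's tokens and collect its field names.
val parser = new JsonFactory().createParser("""{"b": 1, "a": {"x": 2}}""")
parser.nextToken()                                   // consume START_OBJECT
val names = Array.newBuilder[String]
while (parser.nextToken() != JsonToken.END_OBJECT) { // now at a FIELD_NAME
  names += parser.getCurrentName
  parser.nextToken()                                 // advance to the value
  parser.skipChildren()                              // skip nested object/array; no-op otherwise
}
val sortedNames = names.result().sorted              // inferField likewise sorts by name
assert(sortedNames.sameElements(Array("a", "b")))
```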
@@ -191,7 +210,11 @@ private[sql] object InferSchema {
if (!struct.fieldNames.contains(columnNameOfCorruptRecords)) {
// If this given struct does not have a column used for corrupt records,
// add this field.
- struct.add(columnNameOfCorruptRecords, StringType, nullable = true)
+ val newFields: Array[StructField] =
+ StructField(columnNameOfCorruptRecords, StringType, nullable = true) +: struct.fields
+ // Note: other code relies on this sorting for correctness, so don't remove it!
+ java.util.Arrays.sort(newFields, structFieldComparator)
+ StructType(newFields)
} else {
// Otherwise, just return this struct.
struct
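
Prepending the corrupt-record column and re-sorting (instead of `struct.add`, which appends) preserves the sorted-by-name invariant that the merge in `compatibleType` below now depends on. A standalone sketch of the step:

```scala
import org.apache.spark.sql.types._

// Sketch: add a corrupt-record column while keeping fields sorted by name.
// "_corrupt_record" is Spark's default corrupt-record column name.
val struct = StructType(Seq(
  StructField("a", LongType, nullable = true),
  StructField("c", StringType, nullable = true)))

val newFields: Array[StructField] =
  StructField("_corrupt_record", StringType, nullable = true) +: struct.fields
// Scala Orderings are java.util.Comparators, so Arrays.sort accepts one.
java.util.Arrays.sort(newFields, Ordering.by((f: StructField) => f.name))
val withCorrupt = StructType(newFields)
```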
@@ -223,6 +246,8 @@ private[sql] object InferSchema {
case (ty1, ty2) => compatibleType(ty1, ty2)
}
+ private[this] val emptyStructFieldArray = Array.empty[StructField]
+
/**
* Returns the most general data type for two given data types.
*/
@@ -246,12 +271,43 @@ private[sql] object InferSchema {
}
case (StructType(fields1), StructType(fields2)) =>
- val newFields = (fields1 ++ fields2).groupBy(field => field.name).map {
- case (name, fieldTypes) =>
- val dataType = fieldTypes.view.map(_.dataType).reduce(compatibleType)
- StructField(name, dataType, nullable = true)
+ // Both fields1 and fields2 should be sorted by name, since inferField performs sorting.
+ // Therefore, we can take advantage of the fact that we're merging sorted lists and skip
+ // building a hash map or performing additional sorting.
+ assert(isSorted(fields1), s"StructType's fields were not sorted: ${fields1.toSeq}")
+ assert(isSorted(fields2), s"StructType's fields were not sorted: ${fields2.toSeq}")
+
+ val newFields = new java.util.ArrayList[StructField]()
+
+ var f1Idx = 0
+ var f2Idx = 0
+
+ while (f1Idx < fields1.length && f2Idx < fields2.length) {
+ val f1Name = fields1(f1Idx).name
+ val f2Name = fields2(f2Idx).name
+ val comp = f1Name.compareTo(f2Name)
+ if (comp == 0) {
+ val dataType = compatibleType(fields1(f1Idx).dataType, fields2(f2Idx).dataType)
+ newFields.add(StructField(f1Name, dataType, nullable = true))
+ f1Idx += 1
+ f2Idx += 1
+ } else if (comp < 0) { // f1Name < f2Name
+ newFields.add(fields1(f1Idx))
+ f1Idx += 1
+ } else { // f1Name > f2Name
+ newFields.add(fields2(f2Idx))
+ f2Idx += 1
+ }
+ }
+ while (f1Idx < fields1.length) {
+ newFields.add(fields1(f1Idx))
+ f1Idx += 1
+ }
+ while (f2Idx < fields2.length) {
+ newFields.add(fields2(f2Idx))
+ f2Idx += 1
}
- StructType(newFields.toSeq.sortBy(_.name))
+ StructType(newFields.toArray(emptyStructFieldArray))
case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
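
The two-pointer walk merges the two sorted field arrays in O(|fields1| + |fields2|) without the hash map the previous `groupBy` built, and without the final `sortBy`, since the output stays sorted. The same pattern in a small standalone form, with a plain combine function standing in for `compatibleType`:

```scala
// Standalone sketch: merge two name-sorted arrays, combining values on equal keys.
def mergeSorted(
    xs: Array[(String, Int)],
    ys: Array[(String, Int)])(combine: (Int, Int) => Int): Array[(String, Int)] = {
  val out = Array.newBuilder[(String, Int)]
  var i = 0
  var j = 0
  while (i < xs.length && j < ys.length) {
    val c = xs(i)._1.compareTo(ys(j)._1)
    if (c == 0) {        // same key: combine the two values
      out += ((xs(i)._1, combine(xs(i)._2, ys(j)._2)))
      i += 1; j += 1
    } else if (c < 0) {  // xs's key is smaller
      out += xs(i); i += 1
    } else {             // ys's key is smaller
      out += ys(j); j += 1
    }
  }
  while (i < xs.length) { out += xs(i); i += 1 }
  while (j < ys.length) { out += ys(j); j += 1 }
  out.result()
}

// mergeSorted(Array("a" -> 1, "c" -> 2), Array("b" -> 3, "c" -> 4))(_ + _)
//   == Array("a" -> 1, "b" -> 3, "c" -> 7)
```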