diff options
4 files changed, 94 insertions, 24 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 5323b79c57..8319ec0a82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -95,7 +95,8 @@ object HiveTypeCoercion { Some(t1) // Promote numeric types to the highest of the two - case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) => + case (t1: NumericType, t2: NumericType) + if !t1.isInstanceOf[DecimalType] && !t2.isInstanceOf[DecimalType] => val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2) Some(numericPrecedence(index)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala index 66f123682e..1fb2e2404c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala @@ -84,18 +84,20 @@ sealed class Metadata private[types] (private[types] val map: Map[String, Any]) override def equals(obj: Any): Boolean = { obj match { - case that: Metadata => - if (map.keySet == that.map.keySet) { - map.keys.forall { k => - (map(k), that.map(k)) match { - case (v0: Array[_], v1: Array[_]) => - v0.view == v1.view - case (v0, v1) => - v0 == v1 - } + case that: Metadata if map.size == that.map.size => + map.keysIterator.forall { key => + that.map.get(key) match { + case Some(otherValue) => + val ourValue = map.get(key).get + (ourValue, otherValue) match { + case (v0: Array[Long], v1: Array[Long]) => java.util.Arrays.equals(v0, v1) + case (v0: Array[Double], v1: Array[Double]) => java.util.Arrays.equals(v0, v1) + case (v0: Array[Boolean], v1: Array[Boolean]) => java.util.Arrays.equals(v0, v1) + case (v0: Array[AnyRef], v1: Array[AnyRef]) => java.util.Arrays.equals(v0, v1) + case (v0, v1) => v0 == v1 + } + case None => false } - } else { - false } case other => false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index b06aa7bc52..fd2b524e22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -103,6 +103,17 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap + override def equals(that: Any): Boolean = { + that match { + case StructType(otherFields) => + java.util.Arrays.equals( + fields.asInstanceOf[Array[AnyRef]], otherFields.asInstanceOf[Array[AnyRef]]) + case _ => false + } + } + + override def hashCode(): Int = java.util.Arrays.hashCode(fields.asInstanceOf[Array[AnyRef]]) + /** * Creates a new [[StructType]] by adding a new field. * {{{ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala index 8e8238a594..42c82625fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.json +import java.util.Comparator + import com.fasterxml.jackson.core._ import org.apache.spark.rdd.RDD @@ -63,9 +65,7 @@ private[sql] object InferSchema { None } } - }.treeAggregate[DataType]( - StructType(Seq()))( - compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord), + }.fold(StructType(Seq()))( compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord)) canonicalizeType(rootType) match { @@ -76,6 +76,23 @@ private[sql] object InferSchema { } } + private[this] val structFieldComparator = new Comparator[StructField] { + override def compare(o1: StructField, o2: StructField): Int = { + o1.name.compare(o2.name) + } + } + + private def isSorted(arr: Array[StructField]): Boolean = { + var i: Int = 0 + while (i < arr.length - 1) { + if (structFieldComparator.compare(arr(i), arr(i + 1)) > 0) { + return false + } + i += 1 + } + true + } + /** * Infer the type of a json document from the parser's token stream */ @@ -99,15 +116,17 @@ private[sql] object InferSchema { case VALUE_STRING => StringType case START_OBJECT => - val builder = Seq.newBuilder[StructField] + val builder = Array.newBuilder[StructField] while (nextUntil(parser, END_OBJECT)) { builder += StructField( parser.getCurrentName, inferField(parser, configOptions), nullable = true) } - - StructType(builder.result().sortBy(_.name)) + val fields: Array[StructField] = builder.result() + // Note: other code relies on this sorting for correctness, so don't remove it! + java.util.Arrays.sort(fields, structFieldComparator) + StructType(fields) case START_ARRAY => // If this JSON array is empty, we use NullType as a placeholder. @@ -191,7 +210,11 @@ private[sql] object InferSchema { if (!struct.fieldNames.contains(columnNameOfCorruptRecords)) { // If this given struct does not have a column used for corrupt records, // add this field. - struct.add(columnNameOfCorruptRecords, StringType, nullable = true) + val newFields: Array[StructField] = + StructField(columnNameOfCorruptRecords, StringType, nullable = true) +: struct.fields + // Note: other code relies on this sorting for correctness, so don't remove it! + java.util.Arrays.sort(newFields, structFieldComparator) + StructType(newFields) } else { // Otherwise, just return this struct. struct @@ -223,6 +246,8 @@ private[sql] object InferSchema { case (ty1, ty2) => compatibleType(ty1, ty2) } + private[this] val emptyStructFieldArray = Array.empty[StructField] + /** * Returns the most general data type for two given data types. */ @@ -246,12 +271,43 @@ private[sql] object InferSchema { } case (StructType(fields1), StructType(fields2)) => - val newFields = (fields1 ++ fields2).groupBy(field => field.name).map { - case (name, fieldTypes) => - val dataType = fieldTypes.view.map(_.dataType).reduce(compatibleType) - StructField(name, dataType, nullable = true) + // Both fields1 and fields2 should be sorted by name, since inferField performs sorting. + // Therefore, we can take advantage of the fact that we're merging sorted lists and skip + // building a hash map or performing additional sorting. + assert(isSorted(fields1), s"StructType's fields were not sorted: ${fields1.toSeq}") + assert(isSorted(fields2), s"StructType's fields were not sorted: ${fields2.toSeq}") + + val newFields = new java.util.ArrayList[StructField]() + + var f1Idx = 0 + var f2Idx = 0 + + while (f1Idx < fields1.length && f2Idx < fields2.length) { + val f1Name = fields1(f1Idx).name + val f2Name = fields2(f2Idx).name + val comp = f1Name.compareTo(f2Name) + if (comp == 0) { + val dataType = compatibleType(fields1(f1Idx).dataType, fields2(f2Idx).dataType) + newFields.add(StructField(f1Name, dataType, nullable = true)) + f1Idx += 1 + f2Idx += 1 + } else if (comp < 0) { // f1Name < f2Name + newFields.add(fields1(f1Idx)) + f1Idx += 1 + } else { // f1Name > f2Name + newFields.add(fields2(f2Idx)) + f2Idx += 1 + } + } + while (f1Idx < fields1.length) { + newFields.add(fields1(f1Idx)) + f1Idx += 1 + } + while (f2Idx < fields2.length) { + newFields.add(fields2(f2Idx)) + f2Idx += 1 } - StructType(newFields.toSeq.sortBy(_.name)) + StructType(newFields.toArray(emptyStructFieldArray)) case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) => ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) |