aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author rowan <rowan.chattaway@googlemail.com> 2015-05-26 18:17:16 -0700
committer Michael Armbrust <michael@databricks.com> 2015-05-26 18:17:16 -0700
commit 03668348e29eb52c1a7d57a1e0ed7fca6c323890 (patch)
tree 2cf07a61dbb121b87c83be158569c970e02e2b07
parent 0463428b6e8f364f0b1f39445a60cd85ae7c07bc (diff)
downloadspark-03668348e29eb52c1a7d57a1e0ed7fca6c323890.tar.gz
spark-03668348e29eb52c1a7d57a1e0ed7fca6c323890.tar.bz2
spark-03668348e29eb52c1a7d57a1e0ed7fca6c323890.zip
[SPARK-7637] [SQL] O(N) merge implementation for StructType merge
Contribution is my original work and I license the work to the project under the project's open source license. Author: rowan <rowan.chattaway@googlemail.com> Closes #6259 from rowan000/SPARK-7637 and squashes the following commits: c479df4 [rowan] SPARK-7637: rename mapFields to fieldsMap as per comments on GitHub. 8d2e419 [rowan] SPARK-7637: fix up whitespace changes 0e9d662 [rowan] SPARK-7637: O(N) merge implementation for StructType merge
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala 12
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala 73
2 files changed, 81 insertions, 4 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 7e00a27dfe..a4f30c825b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -230,10 +230,10 @@ object StructType {
case (StructType(leftFields), StructType(rightFields)) =>
val newFields = ArrayBuffer.empty[StructField]
+ val rightMapped = fieldsMap(rightFields)
leftFields.foreach {
case leftField @ StructField(leftName, leftType, leftNullable, _) =>
- rightFields
- .find(_.name == leftName)
+ rightMapped.get(leftName)
.map { case rightField @ StructField(_, rightType, rightNullable, _) =>
leftField.copy(
dataType = merge(leftType, rightType),
@@ -243,8 +243,9 @@ object StructType {
.foreach(newFields += _)
}
+ val leftMapped = fieldsMap(leftFields)
rightFields
- .filterNot(f => leftFields.map(_.name).contains(f.name))
+ .filterNot(f => leftMapped.get(f.name).nonEmpty)
.foreach(newFields += _)
StructType(newFields)
@@ -264,4 +265,9 @@ object StructType {
case _ =>
throw new SparkException(s"Failed to merge incompatible data types $left and $right")
}
+
+ private[sql] def fieldsMap(fields: Array[StructField]): Map[String, StructField] = {
+ import scala.collection.breakOut
+ fields.map(s => (s.name, s))(breakOut)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index d797510f36..a73317c869 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.types
+import org.apache.spark.SparkException
import org.scalatest.FunSuite
class DataTypeSuite extends FunSuite {
@@ -69,6 +70,76 @@ class DataTypeSuite extends FunSuite {
}
}
+ test("fieldsMap returns map of name to StructField") {
+ val struct = StructType(
+ StructField("a", LongType) ::
+ StructField("b", FloatType) :: Nil)
+
+ val mapped = StructType.fieldsMap(struct.fields)
+
+ val expected = Map(
+ "a" -> StructField("a", LongType),
+ "b" -> StructField("b", FloatType))
+
+ assert(mapped === expected)
+ }
+
+ test("merge where right is empty") {
+ val left = StructType(
+ StructField("a", LongType) ::
+ StructField("b", FloatType) :: Nil)
+
+ val right = StructType(List())
+ val merged = left.merge(right)
+
+ assert(merged === left)
+ }
+
+ test("merge where left is empty") {
+
+ val left = StructType(List())
+
+ val right = StructType(
+ StructField("a", LongType) ::
+ StructField("b", FloatType) :: Nil)
+
+ val merged = left.merge(right)
+
+ assert(right === merged)
+
+ }
+
+ test("merge where both are non-empty") {
+ val left = StructType(
+ StructField("a", LongType) ::
+ StructField("b", FloatType) :: Nil)
+
+ val right = StructType(
+ StructField("c", LongType) :: Nil)
+
+ val expected = StructType(
+ StructField("a", LongType) ::
+ StructField("b", FloatType) ::
+ StructField("c", LongType) :: Nil)
+
+ val merged = left.merge(right)
+
+ assert(merged === expected)
+ }
+
+ test("merge where right contains type conflict") {
+ val left = StructType(
+ StructField("a", LongType) ::
+ StructField("b", FloatType) :: Nil)
+
+ val right = StructType(
+ StructField("b", LongType) :: Nil)
+
+ intercept[SparkException] {
+ left.merge(right)
+ }
+ }
+
def checkDataTypeJsonRepr(dataType: DataType): Unit = {
test(s"JSON - $dataType") {
assert(DataType.fromJson(dataType.json) === dataType)
@@ -120,7 +191,7 @@ class DataTypeSuite extends FunSuite {
checkDefaultSize(DecimalType(10, 5), 4096)
checkDefaultSize(DecimalType.Unlimited, 4096)
checkDefaultSize(DateType, 4)
- checkDefaultSize(TimestampType,12)
+ checkDefaultSize(TimestampType, 12)
checkDefaultSize(StringType, 4096)
checkDefaultSize(BinaryType, 4096)
checkDefaultSize(ArrayType(DoubleType, true), 800)