aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala21
2 files changed, 24 insertions, 1 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
index 19421e5667..917b346086 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
@@ -115,7 +115,9 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
override def dataType: DataType = {
assert(resolved, s"Invalid dataType of mixed ArrayType ${childTypes.mkString(",")}")
- ArrayType(childTypes.headOption.getOrElse(NullType))
+ ArrayType(
+ childTypes.headOption.getOrElse(NullType),
+ containsNull = children.exists(_.nullable))
}
override def nullable: Boolean = false
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index 5dd19dd12d..ff1dc03069 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -171,6 +171,27 @@ object DataType {
case _ =>
}
}
+
+ /**
+ * Compares two types, ignoring nullability of ArrayType, MapType, StructType.
+ */
+ def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
+ (left, right) match {
+ case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) =>
+ equalsIgnoreNullability(leftElementType, rightElementType)
+ case (MapType(leftKeyType, leftValueType, _), MapType(rightKeyType, rightValueType, _)) =>
+ equalsIgnoreNullability(leftKeyType, rightKeyType) &&
+ equalsIgnoreNullability(leftValueType, rightValueType)
+ case (StructType(leftFields), StructType(rightFields)) =>
+ leftFields.size == rightFields.size &&
+ leftFields.zip(rightFields)
+ .forall{
+ case (left, right) =>
+ left.name == right.name && equalsIgnoreNullability(left.dataType, right.dataType)
+ }
+ case (left, right) => left == right
+ }
+ }
}
abstract class DataType {