path: root/sql/catalyst
author    Reynold Xin <rxin@databricks.com>    2017-04-12 01:30:00 -0700
committer Reynold Xin <rxin@databricks.com>    2017-04-12 01:30:00 -0700
commit    ffc57b0118b58de57520967d8e8730b11baad507 (patch)
tree      5d97391e4280eabf11cc896ad720556bddbc4d46 /sql/catalyst
parent    044f7ecbfd75ac5a13bfc8cd01990e195c9bd178 (diff)
[SPARK-20302][SQL] Short circuit cast when from and to types are structurally the same
## What changes were proposed in this pull request?

When we perform a cast expression and the from and to types are structurally the same (having the same structure but different field names), we should be able to skip the actual cast.

## How was this patch tested?

Added unit tests for the newly introduced functions.

Author: Reynold Xin <rxin@databricks.com>

Closes #17614 from rxin/SPARK-20302.
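As a quick illustration (using the same shapes as the new unit test below), two struct types whose field names all differ can still be structurally equal:

```scala
import org.apache.spark.sql.types._

// Same nesting, same element types and nullability, different field names.
val from = new StructType()
  .add("a", IntegerType)
  .add("b", new StructType().add("b1", LongType))

val to = new StructType()
  .add("a1", IntegerType)
  .add("b1", new StructType().add("b11", LongType))

// true: a cast between these two can be skipped entirely.
DataType.equalsStructurally(from, to)
```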
Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala      | 65
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala                 | 26
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala | 14
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala            | 31
4 files changed, 113 insertions(+), 23 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 1049915986..bb1273f5c3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -462,35 +462,54 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
})
}
- private[this] def cast(from: DataType, to: DataType): Any => Any = to match {
- case dt if dt == from => identity[Any]
- case StringType => castToString(from)
- case BinaryType => castToBinary(from)
- case DateType => castToDate(from)
- case decimal: DecimalType => castToDecimal(from, decimal)
- case TimestampType => castToTimestamp(from)
- case CalendarIntervalType => castToInterval(from)
- case BooleanType => castToBoolean(from)
- case ByteType => castToByte(from)
- case ShortType => castToShort(from)
- case IntegerType => castToInt(from)
- case FloatType => castToFloat(from)
- case LongType => castToLong(from)
- case DoubleType => castToDouble(from)
- case array: ArrayType => castArray(from.asInstanceOf[ArrayType].elementType, array.elementType)
- case map: MapType => castMap(from.asInstanceOf[MapType], map)
- case struct: StructType => castStruct(from.asInstanceOf[StructType], struct)
- case udt: UserDefinedType[_]
- if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass =>
- identity[Any]
- case _: UserDefinedType[_] =>
- throw new SparkException(s"Cannot cast $from to $to.")
+ private[this] def cast(from: DataType, to: DataType): Any => Any = {
+ // If the cast does not change the structure, then we don't really need to cast anything.
+ // We can return what the children return. Same thing should happen in the codegen path.
+ if (DataType.equalsStructurally(from, to)) {
+ identity
+ } else {
+ to match {
+ case dt if dt == from => identity[Any]
+ case StringType => castToString(from)
+ case BinaryType => castToBinary(from)
+ case DateType => castToDate(from)
+ case decimal: DecimalType => castToDecimal(from, decimal)
+ case TimestampType => castToTimestamp(from)
+ case CalendarIntervalType => castToInterval(from)
+ case BooleanType => castToBoolean(from)
+ case ByteType => castToByte(from)
+ case ShortType => castToShort(from)
+ case IntegerType => castToInt(from)
+ case FloatType => castToFloat(from)
+ case LongType => castToLong(from)
+ case DoubleType => castToDouble(from)
+ case array: ArrayType =>
+ castArray(from.asInstanceOf[ArrayType].elementType, array.elementType)
+ case map: MapType => castMap(from.asInstanceOf[MapType], map)
+ case struct: StructType => castStruct(from.asInstanceOf[StructType], struct)
+ case udt: UserDefinedType[_]
+ if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass =>
+ identity[Any]
+ case _: UserDefinedType[_] =>
+ throw new SparkException(s"Cannot cast $from to $to.")
+ }
+ }
}
private[this] lazy val cast: Any => Any = cast(child.dataType, dataType)
protected override def nullSafeEval(input: Any): Any = cast(input)
+ override def genCode(ctx: CodegenContext): ExprCode = {
+ // If the cast does not change the structure, then we don't really need to cast anything.
+ // We can return what the children return. Same thing should happen in the interpreted path.
+ if (DataType.equalsStructurally(child.dataType, dataType)) {
+ child.genCode(ctx)
+ } else {
+ super.genCode(ctx)
+ }
+ }
+
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val eval = child.genCode(ctx)
val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx)
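A minimal sketch of what the short circuit does at runtime, reusing the `from`/`to` types from the illustration above (not part of the commit itself):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}

// `from` and `to` are the structurally equal struct types sketched earlier.
val expr = Cast(Literal.create(Row(10, Row(12L)), from), to)

// Interpreted path: the private `cast` function resolves to `identity`, so
// nullSafeEval hands back the child's InternalRow untouched.
// Codegen path: genCode returns the child's ExprCode directly, emitting no
// per-field copy code at all.
expr.eval()
```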
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 520aff5e2b..30745c6a9d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -288,4 +288,30 @@ object DataType {
case (fromDataType, toDataType) => fromDataType == toDataType
}
}
+
+ /**
+ * Returns true if the two data types share the same "shape", i.e. the types (including
+ * nullability) are the same, but the field names don't need to be the same.
+ */
+ def equalsStructurally(from: DataType, to: DataType): Boolean = {
+ (from, to) match {
+ case (left: ArrayType, right: ArrayType) =>
+ equalsStructurally(left.elementType, right.elementType) &&
+ left.containsNull == right.containsNull
+
+ case (left: MapType, right: MapType) =>
+ equalsStructurally(left.keyType, right.keyType) &&
+ equalsStructurally(left.valueType, right.valueType) &&
+ left.valueContainsNull == right.valueContainsNull
+
+ case (StructType(fromFields), StructType(toFields)) =>
+ fromFields.length == toFields.length &&
+ fromFields.zip(toFields)
+ .forall { case (l, r) =>
+ equalsStructurally(l.dataType, r.dataType) && l.nullable == r.nullable
+ }
+
+ case (fromDataType, toDataType) => fromDataType == toDataType
+ }
+ }
}
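A few illustrative calls, mirroring the new DataTypeSuite cases further down: field names are ignored, but nullability (including `containsNull` for arrays and `valueContainsNull` for maps) is not.

```scala
import org.apache.spark.sql.types._

DataType.equalsStructurally(
  new StructType().add("f1", IntegerType),
  new StructType().add("f2", IntegerType))           // true: only names differ

DataType.equalsStructurally(
  new StructType().add("f1", IntegerType),
  new StructType().add("f2", IntegerType, false))    // false: nullability differs

DataType.equalsStructurally(
  ArrayType(IntegerType, true),
  ArrayType(IntegerType, false))                     // false: containsNull differs
```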
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index 8eccadbdd8..a7ffa884d2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -813,4 +813,18 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(cast(1.0.toFloat, DateType).checkInputDataTypes().isFailure)
assert(cast(1.0, DateType).checkInputDataTypes().isFailure)
}
+
+ test("SPARK-20302 cast with same structure") {
+ val from = new StructType()
+ .add("a", IntegerType)
+ .add("b", new StructType().add("b1", LongType))
+
+ val to = new StructType()
+ .add("a1", IntegerType)
+ .add("b1", new StructType().add("b11", LongType))
+
+ val input = Row(10, Row(12L))
+
+ checkEvaluation(cast(Literal.create(input, from), to), input)
+ }
}
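Note that `checkEvaluation` (from `ExpressionEvalHelper`) runs the expression through both the interpreted and the generated-code evaluation paths, so this single test exercises both of the short circuits introduced in Cast.scala above.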
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index f078ef0133..c4635c8f12 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -411,4 +411,35 @@ class DataTypeSuite extends SparkFunSuite {
checkCatalogString(ArrayType(createStruct(40)))
checkCatalogString(MapType(IntegerType, StringType))
checkCatalogString(MapType(IntegerType, createStruct(40)))
+
+ def checkEqualsStructurally(from: DataType, to: DataType, expected: Boolean): Unit = {
+ val testName = s"equalsStructurally: (from: $from, to: $to)"
+ test(testName) {
+ assert(DataType.equalsStructurally(from, to) === expected)
+ }
+ }
+
+ checkEqualsStructurally(BooleanType, BooleanType, true)
+ checkEqualsStructurally(IntegerType, IntegerType, true)
+ checkEqualsStructurally(IntegerType, LongType, false)
+ checkEqualsStructurally(ArrayType(IntegerType, true), ArrayType(IntegerType, true), true)
+ checkEqualsStructurally(ArrayType(IntegerType, true), ArrayType(IntegerType, false), false)
+
+ checkEqualsStructurally(
+ new StructType().add("f1", IntegerType),
+ new StructType().add("f2", IntegerType),
+ true)
+ checkEqualsStructurally(
+ new StructType().add("f1", IntegerType),
+ new StructType().add("f2", IntegerType, false),
+ false)
+
+ checkEqualsStructurally(
+ new StructType().add("f1", IntegerType).add("f", new StructType().add("f2", StringType)),
+ new StructType().add("f2", IntegerType).add("g", new StructType().add("f1", StringType)),
+ true)
+ checkEqualsStructurally(
+ new StructType().add("f1", IntegerType).add("f", new StructType().add("f2", StringType, false)),
+ new StructType().add("f2", IntegerType).add("g", new StructType().add("f1", StringType)),
+ false)
}