aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala16
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala8
2 files changed, 17 insertions, 7 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index ff97cd3211..6392ff42d7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -130,6 +130,22 @@ class CodegenContext {
mutableStates += ((javaType, variableName, initCode))
}
+ /**
+ * Add buffer variable which stores data coming from an [[InternalRow]]. This method guarantees
+ * that the variable is safely stored, which is important for (potentially) byte array backed
+ * data types like: UTF8String, ArrayData, MapData & InternalRow.
+ */
+ def addBufferedState(dataType: DataType, variableName: String, initCode: String): ExprCode = {
+ val value = freshName(variableName)
+ addMutableState(javaType(dataType), value, "")
+ val code = dataType match {
+ case StringType => s"$value = $initCode.clone();"
+ case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();"
+ case _ => s"$value = $initCode;"
+ }
+ ExprCode(code, "false", value)
+ }
+
def declareMutableStates(): String = {
// It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in
// `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 32f0bc5bf9..fac6b8de8e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -336,13 +336,7 @@ case class SortMergeJoinExec(
private def copyKeys(ctx: CodegenContext, vars: Seq[ExprCode]): Seq[ExprCode] = {
vars.zipWithIndex.map { case (ev, i) =>
- val value = ctx.freshName("value")
- ctx.addMutableState(ctx.javaType(leftKeys(i).dataType), value, "")
- val code =
- s"""
- |$value = ${ev.value};
- """.stripMargin
- ExprCode(code, "false", value)
+ ctx.addBufferedState(leftKeys(i).dataType, "value", ev.value)
}
}