aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorLiwei Lin <lwlin7@gmail.com>2016-08-25 11:24:40 +0200
committerHerman van Hovell <hvanhovell@databricks.com>2016-08-25 11:24:40 +0200
commite0b20f9f24d5c3304bf517a4dcfb0da93be5bc75 (patch)
treef59230d7e3c65c874647bd09c37c4decbd23f7d6 /sql/catalyst
parent2bcd5d5ce3eaf0eb1600a12a2b55ddb40927533b (diff)
downloadspark-e0b20f9f24d5c3304bf517a4dcfb0da93be5bc75.tar.gz
spark-e0b20f9f24d5c3304bf517a4dcfb0da93be5bc75.tar.bz2
spark-e0b20f9f24d5c3304bf517a4dcfb0da93be5bc75.zip
[SPARK-17061][SPARK-17093][SQL] `MapObjects` should make copies of unsafe-backed data
## What changes were proposed in this pull request? Currently `MapObjects` does not make copies of unsafe-backed data, leading to problems like [SPARK-17061](https://issues.apache.org/jira/browse/SPARK-17061) [SPARK-17093](https://issues.apache.org/jira/browse/SPARK-17093). This patch makes `MapObjects` make copies of unsafe-backed data. Generated code - prior to this patch: ```java ... /* 295 */ if (isNull12) { /* 296 */ convertedArray1[loopIndex1] = null; /* 297 */ } else { /* 298 */ convertedArray1[loopIndex1] = value12; /* 299 */ } ... ``` Generated code - after this patch: ```java ... /* 295 */ if (isNull12) { /* 296 */ convertedArray1[loopIndex1] = null; /* 297 */ } else { /* 298 */ convertedArray1[loopIndex1] = value12 instanceof UnsafeRow? value12.copy() : value12; /* 299 */ } ... ``` ## How was this patch tested? Add a new test case which would fail without this patch. Author: Liwei Lin <lwlin7@gmail.com> Closes #14698 from lw-lin/mapobjects-copy.
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala12
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala2
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala34
3 files changed, 46 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 31ed485317..4da74a0a27 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -494,6 +494,16 @@ case class MapObjects private(
s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)"
}
+ // Make a copy of the data if it's unsafe-backed
+ def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) =
+ s"$value instanceof ${clazz.getSimpleName}? ${value}.copy() : $value"
+ val genFunctionValue = lambdaFunction.dataType match {
+ case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value)
+ case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value)
+ case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value)
+ case _ => genFunction.value
+ }
+
val loopNullCheck = inputDataType match {
case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);"
// The element of primitive array will never be null.
@@ -521,7 +531,7 @@ case class MapObjects private(
if (${genFunction.isNull}) {
$convertedArray[$loopIndex] = null;
} else {
- $convertedArray[$loopIndex] = ${genFunction.value};
+ $convertedArray[$loopIndex] = $genFunctionValue;
}
$loopIndex += 1;
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index d6a9672d1f..668543a28b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -136,7 +136,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
// some expression is reusing variable names across different instances.
// This behavior is tested in ExpressionEvalHelperSuite.
val plan = generateProject(
- GenerateUnsafeProjection.generate(
+ UnsafeProjection.create(
Alias(expression, s"Optimized($expression)1")() ::
Alias(expression, s"Optimized($expression)2")() :: Nil),
expression)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index ee65826cd5..3edcc02f15 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -19,7 +19,9 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.types.{IntegerType, ObjectType}
@@ -32,4 +34,36 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val invoke = Invoke(inputObject, "_2", IntegerType)
checkEvaluationWithGeneratedMutableProjection(invoke, null, inputRow)
}
+
+ test("MapObjects should make copies of unsafe-backed data") {
+ // test UnsafeRow-backed data
+ val structEncoder = ExpressionEncoder[Array[Tuple2[java.lang.Integer, java.lang.Integer]]]
+ val structInputRow = InternalRow.fromSeq(Seq(Array((1, 2), (3, 4))))
+ val structExpected = new GenericArrayData(
+ Array(InternalRow.fromSeq(Seq(1, 2)), InternalRow.fromSeq(Seq(3, 4))))
+ checkEvalutionWithUnsafeProjection(
+ structEncoder.serializer.head, structExpected, structInputRow)
+
+ // test UnsafeArray-backed data
+ val arrayEncoder = ExpressionEncoder[Array[Array[Int]]]
+ val arrayInputRow = InternalRow.fromSeq(Seq(Array(Array(1, 2), Array(3, 4))))
+ val arrayExpected = new GenericArrayData(
+ Array(new GenericArrayData(Array(1, 2)), new GenericArrayData(Array(3, 4))))
+ checkEvalutionWithUnsafeProjection(
+ arrayEncoder.serializer.head, arrayExpected, arrayInputRow)
+
+ // test UnsafeMap-backed data
+ val mapEncoder = ExpressionEncoder[Array[Map[Int, Int]]]
+ val mapInputRow = InternalRow.fromSeq(Seq(Array(
+ Map(1 -> 100, 2 -> 200), Map(3 -> 300, 4 -> 400))))
+ val mapExpected = new GenericArrayData(Seq(
+ new ArrayBasedMapData(
+ new GenericArrayData(Array(1, 2)),
+ new GenericArrayData(Array(100, 200))),
+ new ArrayBasedMapData(
+ new GenericArrayData(Array(3, 4)),
+ new GenericArrayData(Array(300, 400)))))
+ checkEvalutionWithUnsafeProjection(
+ mapEncoder.serializer.head, mapExpected, mapInputRow)
+ }
}