diff options
author | Wenchen Fan <cloud0fan@outlook.com> | 2015-08-07 00:00:43 -0700 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2015-08-07 00:00:43 -0700 |
commit | e57d6b56137bf3557efe5acea3ad390c1987b257 (patch) | |
tree | 9cca56b04477dd27089bbc43534c26a5e2f79e57 | |
parent | 15bd6f338dff4bcab4a1a3a2c568655022e49c32 (diff) | |
download | spark-e57d6b56137bf3557efe5acea3ad390c1987b257.tar.gz spark-e57d6b56137bf3557efe5acea3ad390c1987b257.tar.bz2 spark-e57d6b56137bf3557efe5acea3ad390c1987b257.zip |
[SPARK-9683] [SQL] copy UTF8String when convert unsafe array/map to safe
When we convert unsafe row to safe row, we will do copy if the column is struct or string type. However, the string inside unsafe array/map are not copied, which may cause problems.
Author: Wenchen Fan <cloud0fan@outlook.com>
Closes #7990 from cloud-fan/copy and squashes the following commits:
c13d1e3 [Wenchen Fan] change test name
fe36294 [Wenchen Fan] we should deep copy UTF8String when convert unsafe row to safe row
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala | 3 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala | 38 |
2 files changed, 40 insertions, 1 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala index 3caf0fb341..9b960b136f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String case class FromUnsafe(child: Expression) extends UnaryExpression with ExpectsInputTypes with CodegenFallback { @@ -52,6 +53,8 @@ case class FromUnsafe(child: Expression) extends UnaryExpression } new GenericArrayData(result) + case StringType => value.asInstanceOf[UTF8String].clone() + case MapType(kt, vt, _) => val map = value.asInstanceOf[UnsafeMapData] val safeKeyArray = convert(map.keys, ArrayType(kt)).asInstanceOf[GenericArrayData] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index 707cd9c6d9..8208b25b57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -17,9 +17,13 @@ package org.apache.spark.sql.execution +import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions.{Literal, IsNull} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, Literal, IsNull} import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.types.{GenericArrayData, ArrayType, StructType, StringType} +import org.apache.spark.unsafe.types.UTF8String class RowFormatConvertersSuite extends SparkPlanTest { @@ -87,4 +91,36 @@ class RowFormatConvertersSuite extends SparkPlanTest { input.map(Row.fromTuple) ) } + + test("SPARK-9683: copy UTF8String when convert unsafe array/map to safe") { + SparkPlan.currentContext.set(TestSQLContext) + val schema = ArrayType(StringType) + val rows = (1 to 100).map { i => + InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(i.toString)))) + } + val relation = LocalTableScan(Seq(AttributeReference("t", schema)()), rows) + + val plan = + DummyPlan( + ConvertToSafe( + ConvertToUnsafe(relation))) + assert(plan.execute().collect().map(_.getUTF8String(0).toString) === (1 to 100).map(_.toString)) + } +} + +case class DummyPlan(child: SparkPlan) extends UnaryNode { + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitions { iter => + // cache all strings to make sure we have deep copied UTF8String inside incoming + // safe InternalRow. + val strings = new scala.collection.mutable.ArrayBuffer[UTF8String] + iter.foreach { row => + strings += row.getArray(0).getUTF8String(0) + } + strings.map(InternalRow(_)).iterator + } + } + + override def output: Seq[Attribute] = Seq(AttributeReference("a", StringType)()) } |