aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorWenchen Fan <cloud0fan@outlook.com>2015-08-07 00:00:43 -0700
committerDavies Liu <davies.liu@gmail.com>2015-08-07 00:00:43 -0700
commite57d6b56137bf3557efe5acea3ad390c1987b257 (patch)
tree9cca56b04477dd27089bbc43534c26a5e2f79e57 /sql
parent15bd6f338dff4bcab4a1a3a2c568655022e49c32 (diff)
downloadspark-e57d6b56137bf3557efe5acea3ad390c1987b257.tar.gz
spark-e57d6b56137bf3557efe5acea3ad390c1987b257.tar.bz2
spark-e57d6b56137bf3557efe5acea3ad390c1987b257.zip
[SPARK-9683] [SQL] copy UTF8String when convert unsafe array/map to safe
When we convert an unsafe row to a safe row, we copy a column if it is of struct or string type. However, strings inside an unsafe array/map are not copied, which may cause problems. Author: Wenchen Fan <cloud0fan@outlook.com> Closes #7990 from cloud-fan/copy and squashes the following commits: c13d1e3 [Wenchen Fan] change test name fe36294 [Wenchen Fan] we should deep copy UTF8String when convert unsafe row to safe row
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala3
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala38
2 files changed, 40 insertions, 1 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala
index 3caf0fb341..9b960b136f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
case class FromUnsafe(child: Expression) extends UnaryExpression
with ExpectsInputTypes with CodegenFallback {
@@ -52,6 +53,8 @@ case class FromUnsafe(child: Expression) extends UnaryExpression
}
new GenericArrayData(result)
+ case StringType => value.asInstanceOf[UTF8String].clone()
+
case MapType(kt, vt, _) =>
val map = value.asInstanceOf[UnsafeMapData]
val safeKeyArray = convert(map.keys, ArrayType(kt)).asInstanceOf[GenericArrayData]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
index 707cd9c6d9..8208b25b57 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
@@ -17,9 +17,13 @@
package org.apache.spark.sql.execution
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.expressions.{Literal, IsNull}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, Literal, IsNull}
import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.types.{GenericArrayData, ArrayType, StructType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
class RowFormatConvertersSuite extends SparkPlanTest {
@@ -87,4 +91,36 @@ class RowFormatConvertersSuite extends SparkPlanTest {
input.map(Row.fromTuple)
)
}
+
+ test("SPARK-9683: copy UTF8String when convert unsafe array/map to safe") {
+ SparkPlan.currentContext.set(TestSQLContext)
+ val schema = ArrayType(StringType)
+ val rows = (1 to 100).map { i =>
+ InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(i.toString))))
+ }
+ val relation = LocalTableScan(Seq(AttributeReference("t", schema)()), rows)
+
+ val plan =
+ DummyPlan(
+ ConvertToSafe(
+ ConvertToUnsafe(relation)))
+ assert(plan.execute().collect().map(_.getUTF8String(0).toString) === (1 to 100).map(_.toString))
+ }
+}
+
+case class DummyPlan(child: SparkPlan) extends UnaryNode {
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ child.execute().mapPartitions { iter =>
+ // cache all strings to make sure we have deep copied UTF8String inside incoming
+ // safe InternalRow.
+ val strings = new scala.collection.mutable.ArrayBuffer[UTF8String]
+ iter.foreach { row =>
+ strings += row.getArray(0).getUTF8String(0)
+ }
+ strings.map(InternalRow(_)).iterator
+ }
+ }
+
+ override def output: Seq[Attribute] = Seq(AttributeReference("a", StringType)())
}