aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
author Davies Liu <davies@databricks.com> 2015-10-23 01:33:14 -0700
committer Reynold Xin <rxin@databricks.com> 2015-10-23 01:33:14 -0700
commit 487d409e71767c76399217a07af8de1bb0da7aa8 (patch)
tree 714042cab2bca2f575c5ca14b1e5113a0fec6eb1 /sql
parent 16dc9f344c08deee104090106cb0a537a90e33fc (diff)
download spark-487d409e71767c76399217a07af8de1bb0da7aa8.tar.gz
spark-487d409e71767c76399217a07af8de1bb0da7aa8.tar.bz2
spark-487d409e71767c76399217a07af8de1bb0da7aa8.zip
[SPARK-11243][SQL] zero out padding bytes in UnsafeRow
For a nested StructType, the underlying buffer may previously have been used for other data, so we should zero out the padding bytes for primitive types that occupy fewer than 8 bytes. cc cloud-fan Author: Davies Liu <davies@databricks.com> Closes #9217 from davies/zero_out.
Diffstat (limited to 'sql')
-rw-r--r-- sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java 20
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala 20
2 files changed, 35 insertions, 5 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
index adbe262187..048b7749d8 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
@@ -100,19 +100,27 @@ public class UnsafeRowWriter {
}
public void write(int ordinal, boolean value) {
- Platform.putBoolean(holder.buffer, getFieldOffset(ordinal), value);
+ final long offset = getFieldOffset(ordinal);
+ Platform.putLong(holder.buffer, offset, 0L);
+ Platform.putBoolean(holder.buffer, offset, value);
}
public void write(int ordinal, byte value) {
- Platform.putByte(holder.buffer, getFieldOffset(ordinal), value);
+ final long offset = getFieldOffset(ordinal);
+ Platform.putLong(holder.buffer, offset, 0L);
+ Platform.putByte(holder.buffer, offset, value);
}
public void write(int ordinal, short value) {
- Platform.putShort(holder.buffer, getFieldOffset(ordinal), value);
+ final long offset = getFieldOffset(ordinal);
+ Platform.putLong(holder.buffer, offset, 0L);
+ Platform.putShort(holder.buffer, offset, value);
}
public void write(int ordinal, int value) {
- Platform.putInt(holder.buffer, getFieldOffset(ordinal), value);
+ final long offset = getFieldOffset(ordinal);
+ Platform.putLong(holder.buffer, offset, 0L);
+ Platform.putInt(holder.buffer, offset, value);
}
public void write(int ordinal, long value) {
@@ -123,7 +131,9 @@ public class UnsafeRowWriter {
if (Float.isNaN(value)) {
value = Float.NaN;
}
- Platform.putFloat(holder.buffer, getFieldOffset(ordinal), value);
+ final long offset = getFieldOffset(ordinal);
+ Platform.putLong(holder.buffer, offset, 0L);
+ Platform.putFloat(holder.buffer, offset, value);
}
public void write(int ordinal, double value) {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala
index 5adcac39c6..1522ee34e4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala
@@ -99,4 +99,24 @@ class GeneratedProjectionSuite extends SparkFunSuite {
val row2 = safeProj(unsafeRow)
assert(row2 === row)
}
+
+ test("padding bytes should be zeroed out") {
+ val types = Seq(BooleanType, ByteType, ShortType, IntegerType, FloatType, BinaryType,
+ StringType)
+ val struct = StructType(types.map(StructField("", _, true)))
+ val fields = Array[DataType](StringType, struct)
+ val unsafeProj = UnsafeProjection.create(fields)
+
+ val innerRow = InternalRow(false, 1.toByte, 2.toShort, 3, 4.0f, "".getBytes,
+ UTF8String.fromString(""))
+ val row1 = InternalRow(UTF8String.fromString(""), innerRow)
+ val unsafe1 = unsafeProj(row1).copy()
+ // create a Row with long String before the inner struct
+ val row2 = InternalRow(UTF8String.fromString("a_long_string").repeat(10), innerRow)
+ val unsafe2 = unsafeProj(row2).copy()
+ assert(unsafe1.getStruct(1, 7) === unsafe2.getStruct(1, 7))
+ val unsafe3 = unsafeProj(row1).copy()
+ assert(unsafe1 === unsafe3)
+ assert(unsafe1.getStruct(1, 7) === unsafe3.getStruct(1, 7))
+ }
}