diff options
author | Davies Liu <davies@databricks.com> | 2015-10-23 01:33:14 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-10-23 01:33:14 -0700 |
commit | 487d409e71767c76399217a07af8de1bb0da7aa8 (patch) | |
tree | 714042cab2bca2f575c5ca14b1e5113a0fec6eb1 /sql/catalyst | |
parent | 16dc9f344c08deee104090106cb0a537a90e33fc (diff) | |
download | spark-487d409e71767c76399217a07af8de1bb0da7aa8.tar.gz spark-487d409e71767c76399217a07af8de1bb0da7aa8.tar.bz2 spark-487d409e71767c76399217a07af8de1bb0da7aa8.zip |
[SPARK-11243][SQL] zero out padding bytes in UnsafeRow
For a nested StructType, the underlying buffer may previously have been used for other data, so we should zero out the padding bytes for primitive types that occupy fewer than 8 bytes.
cc cloud-fan
Author: Davies Liu <davies@databricks.com>
Closes #9217 from davies/zero_out.
Diffstat (limited to 'sql/catalyst')
2 files changed, 35 insertions, 5 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index adbe262187..048b7749d8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -100,19 +100,27 @@ public class UnsafeRowWriter { } public void write(int ordinal, boolean value) { - Platform.putBoolean(holder.buffer, getFieldOffset(ordinal), value); + final long offset = getFieldOffset(ordinal); + Platform.putLong(holder.buffer, offset, 0L); + Platform.putBoolean(holder.buffer, offset, value); } public void write(int ordinal, byte value) { - Platform.putByte(holder.buffer, getFieldOffset(ordinal), value); + final long offset = getFieldOffset(ordinal); + Platform.putLong(holder.buffer, offset, 0L); + Platform.putByte(holder.buffer, offset, value); } public void write(int ordinal, short value) { - Platform.putShort(holder.buffer, getFieldOffset(ordinal), value); + final long offset = getFieldOffset(ordinal); + Platform.putLong(holder.buffer, offset, 0L); + Platform.putShort(holder.buffer, offset, value); } public void write(int ordinal, int value) { - Platform.putInt(holder.buffer, getFieldOffset(ordinal), value); + final long offset = getFieldOffset(ordinal); + Platform.putLong(holder.buffer, offset, 0L); + Platform.putInt(holder.buffer, offset, value); } public void write(int ordinal, long value) { @@ -123,7 +131,9 @@ public class UnsafeRowWriter { if (Float.isNaN(value)) { value = Float.NaN; } - Platform.putFloat(holder.buffer, getFieldOffset(ordinal), value); + final long offset = getFieldOffset(ordinal); + Platform.putLong(holder.buffer, offset, 0L); + Platform.putFloat(holder.buffer, offset, value); } public void write(int ordinal, double value) { diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala index 5adcac39c6..1522ee34e4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -99,4 +99,24 @@ class GeneratedProjectionSuite extends SparkFunSuite { val row2 = safeProj(unsafeRow) assert(row2 === row) } + + test("padding bytes should be zeroed out") { + val types = Seq(BooleanType, ByteType, ShortType, IntegerType, FloatType, BinaryType, + StringType) + val struct = StructType(types.map(StructField("", _, true))) + val fields = Array[DataType](StringType, struct) + val unsafeProj = UnsafeProjection.create(fields) + + val innerRow = InternalRow(false, 1.toByte, 2.toShort, 3, 4.0f, "".getBytes, + UTF8String.fromString("")) + val row1 = InternalRow(UTF8String.fromString(""), innerRow) + val unsafe1 = unsafeProj(row1).copy() + // create a Row with long String before the inner struct + val row2 = InternalRow(UTF8String.fromString("a_long_string").repeat(10), innerRow) + val unsafe2 = unsafeProj(row2).copy() + assert(unsafe1.getStruct(1, 7) === unsafe2.getStruct(1, 7)) + val unsafe3 = unsafeProj(row1).copy() + assert(unsafe1 === unsafe3) + assert(unsafe1.getStruct(1, 7) === unsafe3.getStruct(1, 7)) + } } |