diff options
author | Wenchen Fan <cloud0fan@outlook.com> | 2015-08-03 04:21:15 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-08-03 04:21:15 -0700 |
commit | 137f47865df6e98ab70ae5ba30dc4d441fb41166 (patch) | |
tree | caf945fa544fefe5e4157cf8283f5dc738575cfe /sql | |
parent | 95dccc63350c45045f038bab9f8a5080b4e1f8cc (diff) | |
download | spark-137f47865df6e98ab70ae5ba30dc4d441fb41166.tar.gz spark-137f47865df6e98ab70ae5ba30dc4d441fb41166.tar.bz2 spark-137f47865df6e98ab70ae5ba30dc4d441fb41166.zip |
[SPARK-9551][SQL] add a cheap version of copy for UnsafeRow to reuse a copy buffer
Author: Wenchen Fan <cloud0fan@outlook.com>
Closes #7885 from cloud-fan/cheap-copy and squashes the following commits:
0900ca1 [Wenchen Fan] replace == with ===
73f4ada [Wenchen Fan] add tests
07b865a [Wenchen Fan] add a cheap version of copy
Diffstat (limited to 'sql')
-rw-r--r-- | sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java | 32 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala | 38 |
2 files changed, 70 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index c5d42d73a4..f4230cfaba 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -464,6 +464,38 @@ public final class UnsafeRow extends MutableRow { } /** + * Creates an empty UnsafeRow from a byte array with specified numBytes and numFields. + * The returned row is invalid until we call copyFrom on it. + */ + public static UnsafeRow createFromByteArray(int numBytes, int numFields) { + final UnsafeRow row = new UnsafeRow(); + row.pointTo(new byte[numBytes], numFields, numBytes); + return row; + } + + /** + * Copies the input UnsafeRow to this UnsafeRow, and resize the underlying byte[] when the + * input row is larger than this row. + */ + public void copyFrom(UnsafeRow row) { + // copyFrom is only available for UnsafeRow created from byte array. + assert (baseObject instanceof byte[]) && baseOffset == PlatformDependent.BYTE_ARRAY_OFFSET; + if (row.sizeInBytes > this.sizeInBytes) { + // resize the underlying byte[] if it's not large enough. + this.baseObject = new byte[row.sizeInBytes]; + } + PlatformDependent.copyMemory( + row.baseObject, + row.baseOffset, + this.baseObject, + this.baseOffset, + row.sizeInBytes + ); + // update the sizeInBytes. + this.sizeInBytes = row.sizeInBytes; + } + + /** * Write this UnsafeRow's underlying bytes to the given OutputStream. * * @param out the stream to write to. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index e72a1bc6c4..c5faaa663e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -82,4 +82,42 @@ class UnsafeRowSuite extends SparkFunSuite { assert(unsafeRow.get(0, dataType) === null) } } + + test("createFromByteArray and copyFrom") { + val row = InternalRow(1, UTF8String.fromString("abc")) + val converter = UnsafeProjection.create(Array[DataType](IntegerType, StringType)) + val unsafeRow = converter.apply(row) + + val emptyRow = UnsafeRow.createFromByteArray(64, 2) + val buffer = emptyRow.getBaseObject + + emptyRow.copyFrom(unsafeRow) + assert(emptyRow.getSizeInBytes() === unsafeRow.getSizeInBytes) + assert(emptyRow.getInt(0) === unsafeRow.getInt(0)) + assert(emptyRow.getUTF8String(1) === unsafeRow.getUTF8String(1)) + // make sure we reuse the buffer. + assert(emptyRow.getBaseObject === buffer) + + // make sure we really copied the input row. + unsafeRow.setInt(0, 2) + assert(emptyRow.getInt(0) === 1) + + val longString = UTF8String.fromString((1 to 100).map(_ => "abc").reduce(_ + _)) + val row2 = InternalRow(3, longString) + val unsafeRow2 = converter.apply(row2) + + // make sure we can resize. + emptyRow.copyFrom(unsafeRow2) + assert(emptyRow.getSizeInBytes() === unsafeRow2.getSizeInBytes) + assert(emptyRow.getInt(0) === 3) + assert(emptyRow.getUTF8String(1) === longString) + // make sure we really resized. + assert(emptyRow.getBaseObject != buffer) + + // make sure we can still handle small rows after resize. + emptyRow.copyFrom(unsafeRow) + assert(emptyRow.getSizeInBytes() === unsafeRow.getSizeInBytes) + assert(emptyRow.getInt(0) === unsafeRow.getInt(0)) + assert(emptyRow.getUTF8String(1) === unsafeRow.getUTF8String(1)) + } } |