aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
author Wenchen Fan <cloud0fan@outlook.com> 2015-08-03 04:21:15 -0700
committer Reynold Xin <rxin@databricks.com> 2015-08-03 04:21:15 -0700
commit 137f47865df6e98ab70ae5ba30dc4d441fb41166 (patch)
tree caf945fa544fefe5e4157cf8283f5dc738575cfe /sql
parent 95dccc63350c45045f038bab9f8a5080b4e1f8cc (diff)
download spark-137f47865df6e98ab70ae5ba30dc4d441fb41166.tar.gz
spark-137f47865df6e98ab70ae5ba30dc4d441fb41166.tar.bz2
spark-137f47865df6e98ab70ae5ba30dc4d441fb41166.zip
[SPARK-9551][SQL] add a cheap version of copy for UnsafeRow to reuse a copy buffer
Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #7885 from cloud-fan/cheap-copy and squashes the following commits:

0900ca1 [Wenchen Fan] replace == with ===
73f4ada [Wenchen Fan] add tests
07b865a [Wenchen Fan] add a cheap version of copy
Diffstat (limited to 'sql')
-rw-r--r-- sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java 32
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala 38
2 files changed, 70 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index c5d42d73a4..f4230cfaba 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -464,6 +464,38 @@ public final class UnsafeRow extends MutableRow {
}
/**
+ * Allocates a new byte array of numBytes bytes and returns an UnsafeRow with numFields fields
+ * that points at it. The row holds no valid data until copyFrom is called on it.
+ */
+ public static UnsafeRow createFromByteArray(int numBytes, int numFields) {
+   final UnsafeRow emptyRow = new UnsafeRow();
+   emptyRow.pointTo(new byte[numBytes], numFields, numBytes);
+   return emptyRow;
+ }
+
+ /**
+ * Copies the bytes of the input UnsafeRow into this row's backing byte[], replacing that array
+ * with a larger one only when it cannot hold the input row. Updates sizeInBytes to match.
+ */
+ public void copyFrom(UnsafeRow row) {
+   // copyFrom is only available for UnsafeRow created from byte array.
+   assert (baseObject instanceof byte[]) && baseOffset == PlatformDependent.BYTE_ARRAY_OFFSET;
+   if (row.sizeInBytes > ((byte[]) this.baseObject).length) {
+     // grow only when the array itself is too small; sizeInBytes can be below its capacity.
+     this.baseObject = new byte[row.sizeInBytes];
+   }
+   PlatformDependent.copyMemory(
+     row.baseObject,
+     row.baseOffset,
+     this.baseObject,
+     this.baseOffset,
+     row.sizeInBytes
+   );
+   // record the logical size of the row now stored in the buffer.
+   this.sizeInBytes = row.sizeInBytes;
+ }
+
+ /**
* Write this UnsafeRow's underlying bytes to the given OutputStream.
*
* @param out the stream to write to.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
index e72a1bc6c4..c5faaa663e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
@@ -82,4 +82,42 @@ class UnsafeRowSuite extends SparkFunSuite {
assert(unsafeRow.get(0, dataType) === null)
}
}
+
+ test("createFromByteArray and copyFrom") {
+   val row = InternalRow(1, UTF8String.fromString("abc"))
+   val converter = UnsafeProjection.create(Array[DataType](IntegerType, StringType))
+   val unsafeRow = converter.apply(row)
+
+   val emptyRow = UnsafeRow.createFromByteArray(64, 2)
+   val buffer = emptyRow.getBaseObject
+
+   emptyRow.copyFrom(unsafeRow)
+   assert(emptyRow.getSizeInBytes() === unsafeRow.getSizeInBytes)
+   assert(emptyRow.getInt(0) === unsafeRow.getInt(0))
+   assert(emptyRow.getUTF8String(1) === unsafeRow.getUTF8String(1))
+   // make sure we reuse the buffer: check reference identity, not array contents.
+   assert(emptyRow.getBaseObject eq buffer)
+
+   // make sure we really copied the input row.
+   unsafeRow.setInt(0, 2)
+   assert(emptyRow.getInt(0) === 1)
+
+   val longString = UTF8String.fromString("abc" * 100)
+   val row2 = InternalRow(3, longString)
+   val unsafeRow2 = converter.apply(row2)
+
+   // make sure we can resize.
+   emptyRow.copyFrom(unsafeRow2)
+   assert(emptyRow.getSizeInBytes() === unsafeRow2.getSizeInBytes)
+   assert(emptyRow.getInt(0) === 3)
+   assert(emptyRow.getUTF8String(1) === longString)
+   // make sure we really resized: the backing array must be a different object.
+   assert(emptyRow.getBaseObject ne buffer)
+
+   // make sure we can still handle small rows after resize.
+   emptyRow.copyFrom(unsafeRow)
+   assert(emptyRow.getSizeInBytes() === unsafeRow.getSizeInBytes)
+   assert(emptyRow.getInt(0) === unsafeRow.getInt(0))
+   assert(emptyRow.getUTF8String(1) === unsafeRow.getUTF8String(1))
+ }
}