author     Josh Rosen <joshrosen@databricks.com>   2015-07-08 20:28:05 -0700
committer  Reynold Xin <rxin@databricks.com>       2015-07-08 20:28:05 -0700
commit     b55499a44ab74e33378211fb0d6940905d7c6318 (patch)
tree       9c60c12645458ccc8a7c76b2d7e3bbc791752c44 /sql
parent     a290814877308c6fa9b0f78b1a81145db7651ca4 (diff)
[SPARK-8932] Support copy() for UnsafeRows that do not use ObjectPools
We call Row.copy() in many places throughout SQL but UnsafeRow currently throws UnsupportedOperationException when copy() is called.

Supporting copying when ObjectPool is used may be difficult, since we may need to handle deep-copying of objects in the pool. In addition, this copy() method needs to produce a self-contained row object which may be passed around / buffered by downstream code which does not understand the UnsafeRow format.

In the long run, we'll need to figure out how to handle the ObjectPool corner cases, but this may be unnecessary if other changes are made. Therefore, in order to unblock my sort patch (#6444) I propose that we support copy() for the cases where UnsafeRow does not use an ObjectPool and continue to throw UnsupportedOperationException when an ObjectPool is used.

This patch accomplishes this by modifying UnsafeRow so that it knows the size of the row's backing data in order to be able to copy it into a byte array.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #7306 from JoshRosen/SPARK-8932 and squashes the following commits:

338e6bf [Josh Rosen] Support copy for UnsafeRows that do not use ObjectPools.
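The essence of the change is easy to see in miniature: once a row records the size of its backing data, copy() is just a byte-for-byte transfer of that region into a fresh array, after which the copy can no longer observe mutations of the original buffer. A minimal Scala sketch of that idea (SimpleRow is a hypothetical stand-in for UnsafeRow, and System.arraycopy plays the role of PlatformDependent.copyMemory):

    // Hypothetical miniature of the copy() mechanism this patch adds: a row
    // is a view over (data, offset, sizeInBytes), and copying clones exactly
    // sizeInBytes bytes into a fresh array that the copy owns.
    class SimpleRow {
      private var data: Array[Byte] = _
      private var offset: Int = _
      private var sizeInBytes: Int = _

      def pointTo(data: Array[Byte], offset: Int, sizeInBytes: Int): Unit = {
        this.data = data
        this.offset = offset
        this.sizeInBytes = sizeInBytes
      }

      def copy(): SimpleRow = {
        // Clone the backing region; System.arraycopy stands in for
        // PlatformDependent.copyMemory in the real patch.
        val rowDataCopy = new Array[Byte](sizeInBytes)
        System.arraycopy(data, offset, rowDataCopy, 0, sizeInBytes)
        val rowCopy = new SimpleRow
        rowCopy.pointTo(rowDataCopy, 0, sizeInBytes)
        rowCopy // self-contained: mutating the original cannot affect the copy
      }
    }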
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java   12
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java                        32
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala             10
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala        52
4 files changed, 87 insertions(+), 19 deletions(-)
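As the diffstat shows, the signature change fans out to every call site: the caller computes the size up front with getSizeRequirement, allocates the buffer from it, and threads that same size through writeRow (and ultimately pointTo) so the row can remember it. A hedged sketch of that caller contract, where ConverterLike is a hypothetical trait mirroring the widened UnsafeRowConverter API in the diff below:

    // Hypothetical trait mirroring the widened writeRow signature from this
    // patch (the real method is UnsafeRowConverter.writeRow, shown below).
    trait ConverterLike[Row] {
      def getSizeRequirement(row: Row): Int
      def writeRow(row: Row, baseObject: AnyRef, baseOffset: Long,
          rowLengthInBytes: Int, pool: AnyRef): Int
    }

    // The call-site pattern the patch enforces: compute the size first,
    // allocate from it, then pass the same size into writeRow.
    def writeSized[Row](converter: ConverterLike[Row], row: Row): Int = {
      val sizeRequired = converter.getSizeRequirement(row)
      val buffer = new Array[Long](sizeRequired / 8)
      // 0L stands in for PlatformDependent.LONG_ARRAY_OFFSET in the real code.
      val written = converter.writeRow(row, buffer, 0L, sizeRequired, null)
      assert(written == sizeRequired, "Size requirement calculation was wrong!")
      written
    }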
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
index 1e79f4b2e8..79d55b36da 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
@@ -120,9 +120,11 @@ public final class UnsafeFixedWidthAggregationMap {
this.bufferPool = new ObjectPool(initialCapacity);
InternalRow initRow = initProjection.apply(emptyRow);
- this.emptyBuffer = new byte[bufferConverter.getSizeRequirement(initRow)];
+ int emptyBufferSize = bufferConverter.getSizeRequirement(initRow);
+ this.emptyBuffer = new byte[emptyBufferSize];
int writtenLength = bufferConverter.writeRow(
- initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, bufferPool);
+ initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, emptyBufferSize,
+ bufferPool);
assert (writtenLength == emptyBuffer.length): "Size requirement calculation was wrong!";
// re-use the empty buffer only when there is no object saved in pool.
reuseEmptyBuffer = bufferPool.size() == 0;
@@ -142,6 +144,7 @@ public final class UnsafeFixedWidthAggregationMap {
groupingKey,
groupingKeyConversionScratchSpace,
PlatformDependent.BYTE_ARRAY_OFFSET,
+ groupingKeySize,
keyPool);
assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!";
@@ -157,7 +160,7 @@ public final class UnsafeFixedWidthAggregationMap {
// There are some objects referenced by emptyBuffer, so generate a new one
InternalRow initRow = initProjection.apply(emptyRow);
bufferConverter.writeRow(initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET,
- bufferPool);
+ groupingKeySize, bufferPool);
}
loc.putNewKey(
groupingKeyConversionScratchSpace,
@@ -175,6 +178,7 @@ public final class UnsafeFixedWidthAggregationMap {
address.getBaseObject(),
address.getBaseOffset(),
bufferConverter.numFields(),
+ loc.getValueLength(),
bufferPool
);
return currentBuffer;
@@ -214,12 +218,14 @@ public final class UnsafeFixedWidthAggregationMap {
keyAddress.getBaseObject(),
keyAddress.getBaseOffset(),
keyConverter.numFields(),
+ loc.getKeyLength(),
keyPool
);
entry.value.pointTo(
valueAddress.getBaseObject(),
valueAddress.getBaseOffset(),
bufferConverter.numFields(),
+ loc.getValueLength(),
bufferPool
);
return entry;
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index aeb64b0458..edb7202245 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -68,6 +68,9 @@ public final class UnsafeRow extends MutableRow {
/** The number of fields in this row, used for calculating the bitset width (and in assertions) */
private int numFields;
+ /** The size of this row's backing data, in bytes */
+ private int sizeInBytes;
+
public int length() { return numFields; }
/** The width of the null tracking bit set, in bytes */
@@ -95,14 +98,17 @@ public final class UnsafeRow extends MutableRow {
* @param baseObject the base object
* @param baseOffset the offset within the base object
* @param numFields the number of fields in this row
+ * @param sizeInBytes the size of this row's backing data, in bytes
* @param pool the object pool to hold arbitrary objects
*/
- public void pointTo(Object baseObject, long baseOffset, int numFields, ObjectPool pool) {
+ public void pointTo(
+ Object baseObject, long baseOffset, int numFields, int sizeInBytes, ObjectPool pool) {
assert numFields >= 0 : "numFields should be >= 0";
this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields);
this.baseObject = baseObject;
this.baseOffset = baseOffset;
this.numFields = numFields;
+ this.sizeInBytes = sizeInBytes;
this.pool = pool;
}
@@ -336,9 +342,31 @@ public final class UnsafeRow extends MutableRow {
}
}
+ /**
+ * Copies this row, returning a self-contained UnsafeRow that stores its data in an internal
+ * byte array rather than referencing data stored in a data page.
+ * <p>
+ * This method is only supported on UnsafeRows that do not use ObjectPools.
+ */
@Override
public InternalRow copy() {
- throw new UnsupportedOperationException();
+ if (pool != null) {
+ throw new UnsupportedOperationException(
+ "Copy is not supported for UnsafeRows that use object pools");
+ } else {
+ UnsafeRow rowCopy = new UnsafeRow();
+ final byte[] rowDataCopy = new byte[sizeInBytes];
+ PlatformDependent.copyMemory(
+ baseObject,
+ baseOffset,
+ rowDataCopy,
+ PlatformDependent.BYTE_ARRAY_OFFSET,
+ sizeInBytes
+ );
+ rowCopy.pointTo(
+ rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, sizeInBytes, null);
+ return rowCopy;
+ }
}
@Override
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
index 1f395497a9..6af5e6200e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
@@ -70,10 +70,16 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) {
* @param row the row to convert
* @param baseObject the base object of the destination address
* @param baseOffset the base offset of the destination address
+ * @param rowLengthInBytes the length calculated by `getSizeRequirement(row)`
* @return the number of bytes written. This should be equal to `getSizeRequirement(row)`.
*/
- def writeRow(row: InternalRow, baseObject: Object, baseOffset: Long, pool: ObjectPool): Int = {
- unsafeRow.pointTo(baseObject, baseOffset, writers.length, pool)
+ def writeRow(
+ row: InternalRow,
+ baseObject: Object,
+ baseOffset: Long,
+ rowLengthInBytes: Int,
+ pool: ObjectPool): Int = {
+ unsafeRow.pointTo(baseObject, baseOffset, writers.length, rowLengthInBytes, pool)
if (writers.length > 0) {
// zero-out the bitset
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index 96d4e64ea3..d00aeb4dfb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -44,19 +44,32 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val sizeRequired: Int = converter.getSizeRequirement(row)
assert(sizeRequired === 8 + (3 * 8))
val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
- val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
+ val numBytesWritten =
+ converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null)
assert(numBytesWritten === sizeRequired)
val unsafeRow = new UnsafeRow()
- unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+ unsafeRow.pointTo(
+ buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, null)
assert(unsafeRow.getLong(0) === 0)
assert(unsafeRow.getLong(1) === 1)
assert(unsafeRow.getInt(2) === 2)
+ // We can copy UnsafeRows as long as they don't reference ObjectPools
+ val unsafeRowCopy = unsafeRow.copy()
+ assert(unsafeRowCopy.getLong(0) === 0)
+ assert(unsafeRowCopy.getLong(1) === 1)
+ assert(unsafeRowCopy.getInt(2) === 2)
+
unsafeRow.setLong(1, 3)
assert(unsafeRow.getLong(1) === 3)
unsafeRow.setInt(2, 4)
assert(unsafeRow.getInt(2) === 4)
+
+ // Mutating the original row should not have changed the copy
+ assert(unsafeRowCopy.getLong(0) === 0)
+ assert(unsafeRowCopy.getLong(1) === 1)
+ assert(unsafeRowCopy.getInt(2) === 2)
}
test("basic conversion with primitive, string and binary types") {
@@ -73,12 +86,14 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length) +
ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length))
val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
- val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
+ val numBytesWritten = converter.writeRow(
+ row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null)
assert(numBytesWritten === sizeRequired)
val unsafeRow = new UnsafeRow()
val pool = new ObjectPool(10)
- unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool)
+ unsafeRow.pointTo(
+ buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, pool)
assert(unsafeRow.getLong(0) === 0)
assert(unsafeRow.getString(1) === "Hello")
assert(unsafeRow.get(2) === "World".getBytes)
@@ -96,6 +111,11 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
unsafeRow.update(2, "Hello World".getBytes)
assert(unsafeRow.get(2) === "Hello World".getBytes)
assert(pool.size === 2)
+
+ // We do not support copy() for UnsafeRows that reference ObjectPools
+ intercept[UnsupportedOperationException] {
+ unsafeRow.copy()
+ }
}
test("basic conversion with primitive, decimal and array") {
@@ -111,12 +131,14 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val sizeRequired: Int = converter.getSizeRequirement(row)
assert(sizeRequired === 8 + (8 * 3))
val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
- val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, pool)
+ val numBytesWritten =
+ converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, pool)
assert(numBytesWritten === sizeRequired)
assert(pool.size === 2)
val unsafeRow = new UnsafeRow()
- unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool)
+ unsafeRow.pointTo(
+ buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, pool)
assert(unsafeRow.getLong(0) === 0)
assert(unsafeRow.get(1) === Decimal(1))
assert(unsafeRow.get(2) === Array(2))
@@ -142,11 +164,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(sizeRequired === 8 + (8 * 4) +
ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length))
val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
- val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
+ val numBytesWritten =
+ converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null)
assert(numBytesWritten === sizeRequired)
val unsafeRow = new UnsafeRow()
- unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+ unsafeRow.pointTo(
+ buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, null)
assert(unsafeRow.getLong(0) === 0)
assert(unsafeRow.getString(1) === "Hello")
// Date is represented as Int in unsafeRow
@@ -190,12 +214,14 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val sizeRequired: Int = converter.getSizeRequirement(rowWithAllNullColumns)
val createdFromNullBuffer: Array[Long] = new Array[Long](sizeRequired / 8)
val numBytesWritten = converter.writeRow(
- rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
+ rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET,
+ sizeRequired, null)
assert(numBytesWritten === sizeRequired)
val createdFromNull = new UnsafeRow()
createdFromNull.pointTo(
- createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+ createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length,
+ sizeRequired, null)
for (i <- 0 to fieldTypes.length - 1) {
assert(createdFromNull.isNullAt(i))
}
@@ -233,10 +259,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val pool = new ObjectPool(1)
val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8 + 2)
converter.writeRow(
- rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, pool)
+ rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET,
+ sizeRequired, pool)
val setToNullAfterCreation = new UnsafeRow()
setToNullAfterCreation.pointTo(
- setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool)
+ setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length,
+ sizeRequired, pool)
assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0))
assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1))