author    Josh Rosen <joshrosen@databricks.com>    2015-07-20 23:28:35 -0700
committer Reynold Xin <rxin@databricks.com>        2015-07-20 23:28:35 -0700
commit    48f8fd46b32973f1f3b865da80345698cb1a71c7 (patch)
tree      8786cf9fce6aef663124662720e66addad2fe13a
parent    67570beed5950974126a91eacd48fd0fedfeb141 (diff)
[SPARK-9023] [SQL] Followup for #7456 (Efficiency improvements for UnsafeRows in Exchange)
This patch addresses code review feedback from #7456.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #7551 from JoshRosen/unsafe-exchange-followup and squashes the following commits:

76dbdf8 [Josh Rosen] Add comments + more methods to UnsafeRowSerializer
3d7a1f2 [Josh Rosen] Add writeToStream() method to UnsafeRow
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java | 33
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala   | 80
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala                  | 71
3 files changed, 161 insertions(+), 23 deletions(-)
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 8cd9e7bc60..6ce03a48e9 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -17,6 +17,9 @@
package org.apache.spark.sql.catalyst.expressions;
+import java.io.IOException;
+import java.io.OutputStream;
+
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.util.ObjectPool;
import org.apache.spark.unsafe.PlatformDependent;
@@ -371,6 +374,36 @@ public final class UnsafeRow extends MutableRow {
}
}
+ /**
+ * Write this UnsafeRow's underlying bytes to the given OutputStream.
+ *
+ * @param out the stream to write to.
+ * @param writeBuffer a byte array for buffering chunks of off-heap data while writing to the
+ * output stream. If this row is backed by an on-heap byte array, then this
+ * buffer will not be used and may be null.
+ */
+ public void writeToStream(OutputStream out, byte[] writeBuffer) throws IOException {
+ if (baseObject instanceof byte[]) {
+ int offsetInByteArray = (int) (baseOffset - PlatformDependent.BYTE_ARRAY_OFFSET);
+ out.write((byte[]) baseObject, offsetInByteArray, sizeInBytes);
+ } else {
+ int dataRemaining = sizeInBytes;
+ long rowReadPosition = baseOffset;
+ while (dataRemaining > 0) {
+ int toTransfer = Math.min(writeBuffer.length, dataRemaining);
+ PlatformDependent.copyMemory(
+ baseObject,
+ rowReadPosition,
+ writeBuffer,
+ PlatformDependent.BYTE_ARRAY_OFFSET,
+ toTransfer);
+ out.write(writeBuffer, 0, toTransfer);
+ rowReadPosition += toTransfer;
+ dataRemaining -= toTransfer;
+ }
+ }
+ }
+
@Override
public boolean anyNull() {
return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes);
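
The new writeToStream() method has two paths: a row backed by an on-heap byte[] is written with a single out.write() call, while an off-heap row is copied out in writeBuffer-sized chunks. A minimal caller-side sketch in Scala (hypothetical, not code from the patch; it assumes an UnsafeRow `row` obtained from an UnsafeProjection, as in the test file below):

    import java.io.{ByteArrayOutputStream, DataOutputStream}

    val baos = new ByteArrayOutputStream()
    val dOut = new DataOutputStream(baos)
    // Scratch buffer, reusable across rows; only touched when the row is off-heap.
    val writeBuffer = new Array[Byte](4096)
    dOut.writeInt(row.getSizeInBytes)    // 4-byte big-endian length prefix
    row.writeToStream(dOut, writeBuffer) // a single write() call if the row is on-heap

This mirrors how UnsafeRowSerializer uses the method in the next file.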
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
index 19503ed000..318550e5ed 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
@@ -49,8 +49,16 @@ private[sql] class UnsafeRowSerializer(numFields: Int) extends Serializer with S
private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInstance {
+ /**
+ * Marks the end of a stream written with [[serializeStream()]].
+ */
private[this] val EOF: Int = -1
+ /**
+ * Serializes a stream of UnsafeRows. Within the stream, each record consists of a record
+ * length (stored as a 4-byte integer, written high byte first), followed by the record's bytes.
+ * The end of the stream is denoted by a record with the special length `EOF` (-1).
+ */
override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream {
private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096)
private[this] val dOut: DataOutputStream = new DataOutputStream(out)
@@ -59,32 +67,31 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
val row = value.asInstanceOf[UnsafeRow]
assert(row.getPool == null, "UnsafeRowSerializer does not support ObjectPool")
dOut.writeInt(row.getSizeInBytes)
- var dataRemaining: Int = row.getSizeInBytes
- val baseObject = row.getBaseObject
- var rowReadPosition: Long = row.getBaseOffset
- while (dataRemaining > 0) {
- val toTransfer: Int = Math.min(writeBuffer.length, dataRemaining)
- PlatformDependent.copyMemory(
- baseObject,
- rowReadPosition,
- writeBuffer,
- PlatformDependent.BYTE_ARRAY_OFFSET,
- toTransfer)
- out.write(writeBuffer, 0, toTransfer)
- rowReadPosition += toTransfer
- dataRemaining -= toTransfer
- }
+ row.writeToStream(out, writeBuffer)
this
}
+
override def writeKey[T: ClassTag](key: T): SerializationStream = {
+ // The key is only needed on the map side when computing partition ids. It does not need to
+ // be shuffled.
assert(key.isInstanceOf[Int])
this
}
- override def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream =
+
+ override def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream = {
+ // This method is never called by shuffle code.
throw new UnsupportedOperationException
- override def writeObject[T: ClassTag](t: T): SerializationStream =
+ }
+
+ override def writeObject[T: ClassTag](t: T): SerializationStream = {
+ // This method is never called by shuffle code.
throw new UnsupportedOperationException
- override def flush(): Unit = dOut.flush()
+ }
+
+ override def flush(): Unit = {
+ dOut.flush()
+ }
+
override def close(): Unit = {
writeBuffer = null
dOut.writeInt(EOF)
@@ -95,6 +102,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
override def deserializeStream(in: InputStream): DeserializationStream = {
new DeserializationStream {
private[this] val dIn: DataInputStream = new DataInputStream(in)
+ // 1024 is a default buffer size; this buffer will grow to accommodate larger rows
private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024)
private[this] var row: UnsafeRow = new UnsafeRow()
private[this] var rowTuple: (Int, UnsafeRow) = (0, row)
@@ -126,14 +134,40 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
}
}
}
- override def asIterator: Iterator[Any] = throw new UnsupportedOperationException
- override def readKey[T: ClassTag](): T = throw new UnsupportedOperationException
- override def readValue[T: ClassTag](): T = throw new UnsupportedOperationException
- override def readObject[T: ClassTag](): T = throw new UnsupportedOperationException
- override def close(): Unit = dIn.close()
+
+ override def asIterator: Iterator[Any] = {
+ // This method is never called by shuffle code.
+ throw new UnsupportedOperationException
+ }
+
+ override def readKey[T: ClassTag](): T = {
+ // We skipped serialization of the key in writeKey(), so just return a dummy value since
+ // this is going to be discarded anyway.
+ null.asInstanceOf[T]
+ }
+
+ override def readValue[T: ClassTag](): T = {
+ val rowSize = dIn.readInt()
+ if (rowBuffer.length < rowSize) {
+ rowBuffer = new Array[Byte](rowSize)
+ }
+ ByteStreams.readFully(in, rowBuffer, 0, rowSize)
+ row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize, null)
+ row.asInstanceOf[T]
+ }
+
+ override def readObject[T: ClassTag](): T = {
+ // This method is never called by shuffle code.
+ throw new UnsupportedOperationException
+ }
+
+ override def close(): Unit = {
+ dIn.close()
+ }
}
}
+ // These methods are never called by shuffle code.
override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException
override def deserialize[T: ClassTag](bytes: ByteBuffer): T =
throw new UnsupportedOperationException
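
The stream format documented above is straightforward to consume: each record is a 4-byte big-endian length followed by that many row bytes, and the stream ends with the EOF sentinel (-1). A standalone reader sketch illustrating the framing (an assumption-labeled illustration, not code from the patch):

    import java.io.DataInputStream

    def readFramedRows(dIn: DataInputStream)(handle: Array[Byte] => Unit): Unit = {
      var rowSize = dIn.readInt()
      while (rowSize != -1) {            // -1 is the EOF marker written by close()
        val rowBytes = new Array[Byte](rowSize)
        dIn.readFully(rowBytes, 0, rowSize)
        handle(rowBytes)
        rowSize = dIn.readInt()
      }
    }

Unlike this sketch, the real deserializer avoids the per-record allocation by growing and reusing a single rowBuffer and pointing one mutable UnsafeRow at it, which is why a caller that retains rows across iterations must copy them first.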
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
new file mode 100644
index 0000000000..3854dc1b7a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.io.ByteArrayOutputStream
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection}
+import org.apache.spark.sql.types.{IntegerType, StringType}
+import org.apache.spark.unsafe.PlatformDependent
+import org.apache.spark.unsafe.memory.MemoryAllocator
+import org.apache.spark.unsafe.types.UTF8String
+
+class UnsafeRowSuite extends SparkFunSuite {
+ test("writeToStream") {
+ val row = InternalRow.apply(UTF8String.fromString("hello"), UTF8String.fromString("world"), 123)
+ val arrayBackedUnsafeRow: UnsafeRow =
+ UnsafeProjection.create(Seq(StringType, StringType, IntegerType)).apply(row)
+ assert(arrayBackedUnsafeRow.getBaseObject.isInstanceOf[Array[Byte]])
+ val bytesFromArrayBackedRow: Array[Byte] = {
+ val baos = new ByteArrayOutputStream()
+ arrayBackedUnsafeRow.writeToStream(baos, null)
+ baos.toByteArray
+ }
+ val bytesFromOffheapRow: Array[Byte] = {
+ val offheapRowPage = MemoryAllocator.UNSAFE.allocate(arrayBackedUnsafeRow.getSizeInBytes)
+ try {
+ PlatformDependent.copyMemory(
+ arrayBackedUnsafeRow.getBaseObject,
+ arrayBackedUnsafeRow.getBaseOffset,
+ offheapRowPage.getBaseObject,
+ offheapRowPage.getBaseOffset,
+ arrayBackedUnsafeRow.getSizeInBytes
+ )
+ val offheapUnsafeRow: UnsafeRow = new UnsafeRow()
+ offheapUnsafeRow.pointTo(
+ offheapRowPage.getBaseObject,
+ offheapRowPage.getBaseOffset,
+ 3, // num fields
+ arrayBackedUnsafeRow.getSizeInBytes,
+ null // object pool
+ )
+ assert(offheapUnsafeRow.getBaseObject === null)
+ val baos = new ByteArrayOutputStream()
+ val writeBuffer = new Array[Byte](1024)
+ offheapUnsafeRow.writeToStream(baos, writeBuffer)
+ baos.toByteArray
+ } finally {
+ MemoryAllocator.UNSAFE.free(offheapRowPage)
+ }
+ }
+
+ assert(bytesFromArrayBackedRow === bytesFromOffheapRow)
+ }
+}
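
For context, a hedged end-to-end sketch of the serializer round trip (hypothetical test-style code; UnsafeRowSerializer is private[sql], so this would have to live in an org.apache.spark.sql package, and `unsafeRow` is assumed to come from an UnsafeProjection as in the test above):

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
    import org.apache.spark.sql.execution.UnsafeRowSerializer

    val instance = new UnsafeRowSerializer(numFields = 3).newInstance()
    val baos = new ByteArrayOutputStream()
    val serStream = instance.serializeStream(baos)
    serStream.writeKey(0)           // key is asserted to be an Int and never written
    serStream.writeValue(unsafeRow) // writes [4-byte length][row bytes]
    serStream.close()               // writes the EOF (-1) marker

    val deserStream = instance.deserializeStream(new ByteArrayInputStream(baos.toByteArray))
    val (key, value) = deserStream.asKeyValueIterator.next() // key is a dummy; value is a reused UnsafeRow
    deserStream.close()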