diff options
author | Reynold Xin <rxin@apache.org> | 2014-09-01 20:32:31 -0700 |
---|---|---|
committer | Reynold Xin <rxin@apache.org> | 2014-09-01 20:32:31 -0700 |
commit | db160676c56de54efc4d42c6466847c2c3b6a963 (patch) | |
tree | 5988039eb8777d8e382a10b3bd023ce41abdb0c1 | |
parent | 1f98add92675eb34c80be9bd7d10ea4608a9a6c2 (diff) | |
download | spark-db160676c56de54efc4d42c6466847c2c3b6a963.tar.gz spark-db160676c56de54efc4d42c6466847c2c3b6a963.tar.bz2 spark-db160676c56de54efc4d42c6466847c2c3b6a963.zip |
[SPARK-3135] Avoid extra mem copy in TorrentBroadcast via ByteArrayChunkOutputStream
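This also enables support for broadcast variables larger than 2 GB, because the serialized data is written directly into fixed-size chunks instead of a single byte array.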
Author: Reynold Xin <rxin@apache.org>
Closes #2054 from rxin/ByteArrayChunkOutputStream and squashes the following commits:
618d9c8 [Reynold Xin] Code review.
93f5a51 [Reynold Xin] Added comments.
ee88e73 [Reynold Xin] to -> until
bbd1cb1 [Reynold Xin] Renamed a variable.
36f4d01 [Reynold Xin] Sort imports.
8f1a8eb [Reynold Xin] [SPARK-3135] Created ByteArrayChunkOutputStream and used it to avoid memory copy in TorrentBroadcast.
3 files changed, 206 insertions, 19 deletions
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 6173fd3a69..42d58682a1 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -28,6 +28,7 @@ import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException} import org.apache.spark.io.CompressionCodec import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} import org.apache.spark.util.ByteBufferInputStream +import org.apache.spark.util.io.ByteArrayChunkOutputStream /** * A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]]. @@ -201,29 +202,12 @@ private object TorrentBroadcast extends Logging { } def blockifyObject[T: ClassTag](obj: T): Array[ByteBuffer] = { - // TODO: Create a special ByteArrayOutputStream that splits the output directly into chunks - // so we don't need to do the extra memory copy. 
- val bos = new ByteArrayOutputStream() + val bos = new ByteArrayChunkOutputStream(BLOCK_SIZE) val out: OutputStream = if (compress) compressionCodec.compressedOutputStream(bos) else bos val ser = SparkEnv.get.serializer.newInstance() val serOut = ser.serializeStream(out) serOut.writeObject[T](obj).close() - val byteArray = bos.toByteArray - val bais = new ByteArrayInputStream(byteArray) - val numBlocks = math.ceil(byteArray.length.toDouble / BLOCK_SIZE).toInt - val blocks = new Array[ByteBuffer](numBlocks) - - var blockId = 0 - for (i <- 0 until (byteArray.length, BLOCK_SIZE)) { - val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i) - val tempByteArray = new Array[Byte](thisBlockSize) - bais.read(tempByteArray, 0, thisBlockSize) - - blocks(blockId) = ByteBuffer.wrap(tempByteArray) - blockId += 1 - } - bais.close() - blocks + bos.toArrays.map(ByteBuffer.wrap) } def unBlockifyObject[T: ClassTag](blocks: Array[ByteBuffer]): T = { diff --git a/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala b/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala new file mode 100644 index 0000000000..daac6f971e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.io + +import java.io.OutputStream + +import scala.collection.mutable.ArrayBuffer + + +/** + * An OutputStream that writes to fixed-size chunks of byte arrays. + * + * @param chunkSize size of each chunk, in bytes. + */ +private[spark] +class ByteArrayChunkOutputStream(chunkSize: Int) extends OutputStream { + + private val chunks = new ArrayBuffer[Array[Byte]] + + /** Index of the last chunk. Starting with -1 when the chunks array is empty. */ + private var lastChunkIndex = -1 + + /** + * Next position to write in the last chunk. + * + * If this equals chunkSize, it means for next write we need to allocate a new chunk. + * This can also never be 0. + */ + private var position = chunkSize + + override def write(b: Int): Unit = { + allocateNewChunkIfNeeded() + chunks(lastChunkIndex)(position) = b.toByte + position += 1 + } + + override def write(bytes: Array[Byte], off: Int, len: Int): Unit = { + var written = 0 + while (written < len) { + allocateNewChunkIfNeeded() + val thisBatch = math.min(chunkSize - position, len - written) + System.arraycopy(bytes, written + off, chunks(lastChunkIndex), position, thisBatch) + written += thisBatch + position += thisBatch + } + } + + @inline + private def allocateNewChunkIfNeeded(): Unit = { + if (position == chunkSize) { + chunks += new Array[Byte](chunkSize) + lastChunkIndex += 1 + position = 0 + } + } + + def toArrays: Array[Array[Byte]] = { + if (lastChunkIndex == -1) { + new Array[Array[Byte]](0) + } else { + // Copy the first n-1 chunks to the output, and then create an array that fits the last chunk. + // An alternative would have been returning an array of ByteBuffers, with the last buffer + // bounded to only the last chunk's position. 
However, given our use case in Spark (to put + // the chunks in block manager), only limiting the view bound of the buffer would still + // require the block manager to store the whole chunk. + val ret = new Array[Array[Byte]](chunks.size) + for (i <- 0 until chunks.size - 1) { + ret(i) = chunks(i) + } + if (position == chunkSize) { + ret(lastChunkIndex) = chunks(lastChunkIndex) + } else { + ret(lastChunkIndex) = new Array[Byte](position) + System.arraycopy(chunks(lastChunkIndex), 0, ret(lastChunkIndex), 0, position) + } + ret + } + } +} diff --git a/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala b/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala new file mode 100644 index 0000000000..f855831b8e --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util.io + +import scala.util.Random + +import org.scalatest.FunSuite + + +class ByteArrayChunkOutputStreamSuite extends FunSuite { + + test("empty output") { + val o = new ByteArrayChunkOutputStream(1024) + assert(o.toArrays.length === 0) + } + + test("write a single byte") { + val o = new ByteArrayChunkOutputStream(1024) + o.write(10) + assert(o.toArrays.length === 1) + assert(o.toArrays.head.toSeq === Seq(10.toByte)) + } + + test("write a single near boundary") { + val o = new ByteArrayChunkOutputStream(10) + o.write(new Array[Byte](9)) + o.write(99) + assert(o.toArrays.length === 1) + assert(o.toArrays.head(9) === 99.toByte) + } + + test("write a single at boundary") { + val o = new ByteArrayChunkOutputStream(10) + o.write(new Array[Byte](10)) + o.write(99) + assert(o.toArrays.length === 2) + assert(o.toArrays(1).length === 1) + assert(o.toArrays(1)(0) === 99.toByte) + } + + test("single chunk output") { + val ref = new Array[Byte](8) + Random.nextBytes(ref) + val o = new ByteArrayChunkOutputStream(10) + o.write(ref) + val arrays = o.toArrays + assert(arrays.length === 1) + assert(arrays.head.length === ref.length) + assert(arrays.head.toSeq === ref.toSeq) + } + + test("single chunk output at boundary size") { + val ref = new Array[Byte](10) + Random.nextBytes(ref) + val o = new ByteArrayChunkOutputStream(10) + o.write(ref) + val arrays = o.toArrays + assert(arrays.length === 1) + assert(arrays.head.length === ref.length) + assert(arrays.head.toSeq === ref.toSeq) + } + + test("multiple chunk output") { + val ref = new Array[Byte](26) + Random.nextBytes(ref) + val o = new ByteArrayChunkOutputStream(10) + o.write(ref) + val arrays = o.toArrays + assert(arrays.length === 3) + assert(arrays(0).length === 10) + assert(arrays(1).length === 10) + assert(arrays(2).length === 6) + + assert(arrays(0).toSeq === ref.slice(0, 10)) + assert(arrays(1).toSeq === ref.slice(10, 20)) + assert(arrays(2).toSeq === ref.slice(20, 26)) + } + + 
test("multiple chunk output at boundary size") { + val ref = new Array[Byte](30) + Random.nextBytes(ref) + val o = new ByteArrayChunkOutputStream(10) + o.write(ref) + val arrays = o.toArrays + assert(arrays.length === 3) + assert(arrays(0).length === 10) + assert(arrays(1).length === 10) + assert(arrays(2).length === 10) + + assert(arrays(0).toSeq === ref.slice(0, 10)) + assert(arrays(1).toSeq === ref.slice(10, 20)) + assert(arrays(2).toSeq === ref.slice(20, 30)) + } +} |