author    Reynold Xin <rxin@apache.org>  2014-09-01 20:32:31 -0700
committer Reynold Xin <rxin@apache.org>  2014-09-01 20:32:31 -0700
commit    db160676c56de54efc4d42c6466847c2c3b6a963 (patch)
tree      5988039eb8777d8e382a10b3bd023ce41abdb0c1 /core
parent    1f98add92675eb34c80be9bd7d10ea4608a9a6c2 (diff)
[SPARK-3135] Avoid extra mem copy in TorrentBroadcast via ByteArrayChunkOutputStream
This also enables supporting broadcast variables larger than 2G.

Author: Reynold Xin <rxin@apache.org>

Closes #2054 from rxin/ByteArrayChunkOutputStream and squashes the following commits:

618d9c8 [Reynold Xin] Code review.
93f5a51 [Reynold Xin] Added comments.
ee88e73 [Reynold Xin] to -> until
bbd1cb1 [Reynold Xin] Renamed a variable.
36f4d01 [Reynold Xin] Sort imports.
8f1a8eb [Reynold Xin] [SPARK-3135] Created ByteArrayChunkOutputStream and used it to avoid memory copy in TorrentBroadcast.
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala              |  22
-rw-r--r--  core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala      |  94
-rw-r--r--  core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala | 109
3 files changed, 206 insertions(+), 19 deletions(-)
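Why the chunked stream helps: a JVM byte array is indexed by Int, so serializing the whole object into a single ByteArrayOutputStream both forces an extra full-size toByteArray copy and caps the broadcast near 2 GB. A minimal sketch of the arithmetic (illustrative only, not part of the patch; the 4 MB figure assumes the default spark.broadcast.blockSize of 4096 KB):

    object ChunkSizing {
      def main(args: Array[String]): Unit = {
        val blockSize = 4 * 1024 * 1024          // bytes per chunk
        val payload = 5L * 1024 * 1024 * 1024    // a hypothetical 5 GB broadcast
        // A single Int-indexed Array[Byte] cannot hold the payload...
        assert(payload > Int.MaxValue)
        // ...but an Array[Array[Byte]] of fixed-size chunks can:
        val numChunks = ((payload + blockSize - 1) / blockSize).toInt
        println(s"$numChunks chunks of $blockSize bytes")  // 1280 chunks of 4194304 bytes
      }
    }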
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 6173fd3a69..42d58682a1 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -28,6 +28,7 @@ import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
import org.apache.spark.util.ByteBufferInputStream
+import org.apache.spark.util.io.ByteArrayChunkOutputStream
/**
* A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]].
@@ -201,29 +202,12 @@ private object TorrentBroadcast extends Logging {
}
def blockifyObject[T: ClassTag](obj: T): Array[ByteBuffer] = {
- // TODO: Create a special ByteArrayOutputStream that splits the output directly into chunks
- // so we don't need to do the extra memory copy.
- val bos = new ByteArrayOutputStream()
+ val bos = new ByteArrayChunkOutputStream(BLOCK_SIZE)
val out: OutputStream = if (compress) compressionCodec.compressedOutputStream(bos) else bos
val ser = SparkEnv.get.serializer.newInstance()
val serOut = ser.serializeStream(out)
serOut.writeObject[T](obj).close()
- val byteArray = bos.toByteArray
- val bais = new ByteArrayInputStream(byteArray)
- val numBlocks = math.ceil(byteArray.length.toDouble / BLOCK_SIZE).toInt
- val blocks = new Array[ByteBuffer](numBlocks)
-
- var blockId = 0
- for (i <- 0 until (byteArray.length, BLOCK_SIZE)) {
- val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i)
- val tempByteArray = new Array[Byte](thisBlockSize)
- bais.read(tempByteArray, 0, thisBlockSize)
-
- blocks(blockId) = ByteBuffer.wrap(tempByteArray)
- blockId += 1
- }
- bais.close()
- blocks
+ bos.toArrays.map(ByteBuffer.wrap)
}
def unBlockifyObject[T: ClassTag](blocks: Array[ByteBuffer]): T = {
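The net effect of the hunk above: blockifyObject now serializes straight into fixed-size chunks and wraps each one in a ByteBuffer, with no intermediate full-size array. A standalone sketch of that flow (serializer and compression omitted; the class is private[spark], so this would only compile inside Spark's own packages, and the object name here is hypothetical):

    package org.apache.spark.util.io

    import java.nio.ByteBuffer

    object BlockifySketch {
      def main(args: Array[String]): Unit = {
        val bos = new ByteArrayChunkOutputStream(4 * 1024 * 1024)
        bos.write(Array.fill[Byte](10 * 1024 * 1024)(1))  // a 10 MB payload
        val blocks: Array[ByteBuffer] = bos.toArrays.map(ByteBuffer.wrap)
        println(blocks.length)  // 3 blocks: 4 MB + 4 MB + 2 MB
      }
    }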
diff --git a/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala b/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala
new file mode 100644
index 0000000000..daac6f971e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.io
+
+import java.io.OutputStream
+
+import scala.collection.mutable.ArrayBuffer
+
+
+/**
+ * An OutputStream that writes to fixed-size chunks of byte arrays.
+ *
+ * @param chunkSize size of each chunk, in bytes.
+ */
+private[spark]
+class ByteArrayChunkOutputStream(chunkSize: Int) extends OutputStream {
+
+ private val chunks = new ArrayBuffer[Array[Byte]]
+
+ /** Index of the last chunk. Starting with -1 when the chunks array is empty. */
+ private var lastChunkIndex = -1
+
+ /**
+ * Next position to write in the last chunk.
+ *
+ * If this equals chunkSize, it means for next write we need to allocate a new chunk.
+ * This can also never be 0.
+ */
+ private var position = chunkSize
+
+ override def write(b: Int): Unit = {
+ allocateNewChunkIfNeeded()
+ chunks(lastChunkIndex)(position) = b.toByte
+ position += 1
+ }
+
+ override def write(bytes: Array[Byte], off: Int, len: Int): Unit = {
+ var written = 0
+ while (written < len) {
+ allocateNewChunkIfNeeded()
+ val thisBatch = math.min(chunkSize - position, len - written)
+ System.arraycopy(bytes, written + off, chunks(lastChunkIndex), position, thisBatch)
+ written += thisBatch
+ position += thisBatch
+ }
+ }
+
+ @inline
+ private def allocateNewChunkIfNeeded(): Unit = {
+ if (position == chunkSize) {
+ chunks += new Array[Byte](chunkSize)
+ lastChunkIndex += 1
+ position = 0
+ }
+ }
+
+ def toArrays: Array[Array[Byte]] = {
+ if (lastChunkIndex == -1) {
+ new Array[Array[Byte]](0)
+ } else {
+ // Copy the first n-1 chunks to the output, and then create an array that fits the last chunk.
+ // An alternative would have been returning an array of ByteBuffers, with the last buffer
+ // bounded to only the last chunk's position. However, given our use case in Spark (to put
+ // the chunks in block manager), only limiting the view bound of the buffer would still
+ // require the block manager to store the whole chunk.
+ val ret = new Array[Array[Byte]](chunks.size)
+ for (i <- 0 until chunks.size - 1) {
+ ret(i) = chunks(i)
+ }
+ if (position == chunkSize) {
+ ret(lastChunkIndex) = chunks(lastChunkIndex)
+ } else {
+ ret(lastChunkIndex) = new Array[Byte](position)
+ System.arraycopy(chunks(lastChunkIndex), 0, ret(lastChunkIndex), 0, position)
+ }
+ ret
+ }
+ }
+}
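Note the design choice in toArrays above: full chunks are handed back by reference, and only the partially filled tail chunk is copied into a right-sized array, so the copy cost is bounded by one chunk. A small illustration (a sketch, not part of the patch):

    val out = new ByteArrayChunkOutputStream(4)
    out.write(Array[Byte](1, 2, 3, 4, 5, 6))
    val arrays = out.toArrays
    assert(arrays.map(_.length).toSeq == Seq(4, 2))              // tail trimmed to 2 bytes
    assert(arrays.flatten.toSeq == Seq[Byte](1, 2, 3, 4, 5, 6))  // contents round-trip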
diff --git a/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala b/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala
new file mode 100644
index 0000000000..f855831b8e
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.io
+
+import scala.util.Random
+
+import org.scalatest.FunSuite
+
+
+class ByteArrayChunkOutputStreamSuite extends FunSuite {
+
+ test("empty output") {
+ val o = new ByteArrayChunkOutputStream(1024)
+ assert(o.toArrays.length === 0)
+ }
+
+ test("write a single byte") {
+ val o = new ByteArrayChunkOutputStream(1024)
+ o.write(10)
+ assert(o.toArrays.length === 1)
+ assert(o.toArrays.head.toSeq === Seq(10.toByte))
+ }
+
+ test("write a single near boundary") {
+ val o = new ByteArrayChunkOutputStream(10)
+ o.write(new Array[Byte](9))
+ o.write(99)
+ assert(o.toArrays.length === 1)
+ assert(o.toArrays.head(9) === 99.toByte)
+ }
+
+ test("write a single at boundary") {
+ val o = new ByteArrayChunkOutputStream(10)
+ o.write(new Array[Byte](10))
+ o.write(99)
+ assert(o.toArrays.length === 2)
+ assert(o.toArrays(1).length === 1)
+ assert(o.toArrays(1)(0) === 99.toByte)
+ }
+
+ test("single chunk output") {
+ val ref = new Array[Byte](8)
+ Random.nextBytes(ref)
+ val o = new ByteArrayChunkOutputStream(10)
+ o.write(ref)
+ val arrays = o.toArrays
+ assert(arrays.length === 1)
+ assert(arrays.head.length === ref.length)
+ assert(arrays.head.toSeq === ref.toSeq)
+ }
+
+ test("single chunk output at boundary size") {
+ val ref = new Array[Byte](10)
+ Random.nextBytes(ref)
+ val o = new ByteArrayChunkOutputStream(10)
+ o.write(ref)
+ val arrays = o.toArrays
+ assert(arrays.length === 1)
+ assert(arrays.head.length === ref.length)
+ assert(arrays.head.toSeq === ref.toSeq)
+ }
+
+ test("multiple chunk output") {
+ val ref = new Array[Byte](26)
+ Random.nextBytes(ref)
+ val o = new ByteArrayChunkOutputStream(10)
+ o.write(ref)
+ val arrays = o.toArrays
+ assert(arrays.length === 3)
+ assert(arrays(0).length === 10)
+ assert(arrays(1).length === 10)
+ assert(arrays(2).length === 6)
+
+ assert(arrays(0).toSeq === ref.slice(0, 10))
+ assert(arrays(1).toSeq === ref.slice(10, 20))
+ assert(arrays(2).toSeq === ref.slice(20, 26))
+ }
+
+ test("multiple chunk output at boundary size") {
+ val ref = new Array[Byte](30)
+ Random.nextBytes(ref)
+ val o = new ByteArrayChunkOutputStream(10)
+ o.write(ref)
+ val arrays = o.toArrays
+ assert(arrays.length === 3)
+ assert(arrays(0).length === 10)
+ assert(arrays(1).length === 10)
+ assert(arrays(2).length === 10)
+
+ assert(arrays(0).toSeq === ref.slice(0, 10))
+ assert(arrays(1).toSeq === ref.slice(10, 20))
+ assert(arrays(2).toSeq === ref.slice(20, 30))
+ }
+}
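One property the suite above checks only piecewise is that concatenating toArrays round-trips the written bytes for arbitrary chunk sizes; a hypothetical extra test in the same FunSuite style (not part of this commit) could read:

    test("round trip across chunk sizes") {
      for (chunkSize <- Seq(1, 3, 10, 100)) {
        val ref = new Array[Byte](37)
        Random.nextBytes(ref)
        val o = new ByteArrayChunkOutputStream(chunkSize)
        o.write(ref)
        assert(o.toArrays.flatten.toSeq === ref.toSeq)
      }
    }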