From bd11b01ebaf62df8b0d8c0b63b51b66e58f50960 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Wed, 27 May 2015 22:23:22 -0700 Subject: [SPARK-7896] Allow ChainedBuffer to store more than 2 GB Author: Sandy Ryza Closes #6440 from sryza/sandy-spark-7896 and squashes the following commits: 49d8a0d [Sandy Ryza] Fix bug introduced when reading over record boundaries 6006856 [Sandy Ryza] Fix overflow issues 006b4b2 [Sandy Ryza] Fix scalastyle by removing non ascii characters 8b000ca [Sandy Ryza] Add ascii art to describe layout of data in metaBuffer f2053c0 [Sandy Ryza] Fix negative overflow issue 0368c78 [Sandy Ryza] Initialize size as 0 a5a4820 [Sandy Ryza] Use explicit types for all numbers in ChainedBuffer b7e0213 [Sandy Ryza] SPARK-7896. Allow ChainedBuffer to store more than 2 GB --- .../spark/util/collection/ChainedBuffer.scala | 46 +++++++++---------- .../PartitionedSerializedPairBuffer.scala | 51 +++++++++++++--------- 2 files changed, 55 insertions(+), 42 deletions(-) (limited to 'core/src') diff --git a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala index a60bffe611..516aaa44d0 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala @@ -28,11 +28,13 @@ import scala.collection.mutable.ArrayBuffer * occupy a contiguous segment of memory. */ private[spark] class ChainedBuffer(chunkSize: Int) { - private val chunkSizeLog2 = (math.log(chunkSize) / math.log(2)).toInt - assert(math.pow(2, chunkSizeLog2).toInt == chunkSize, + + private val chunkSizeLog2: Int = java.lang.Long.numberOfTrailingZeros( + java.lang.Long.highestOneBit(chunkSize)) + assert((1 << chunkSizeLog2) == chunkSize, s"ChainedBuffer chunk size $chunkSize must be a power of two") private val chunks: ArrayBuffer[Array[Byte]] = new ArrayBuffer[Array[Byte]]() - private var _size: Int = _ + private var _size: Long = 0 /** * Feed bytes from this buffer into a BlockObjectWriter. @@ -41,16 +43,16 @@ private[spark] class ChainedBuffer(chunkSize: Int) { * @param os OutputStream to read into. * @param len Number of bytes to read. */ - def read(pos: Int, os: OutputStream, len: Int): Unit = { + def read(pos: Long, os: OutputStream, len: Int): Unit = { if (pos + len > _size) { throw new IndexOutOfBoundsException( s"Read of $len bytes at position $pos would go past size ${_size} of buffer") } - var chunkIndex = pos >> chunkSizeLog2 - var posInChunk = pos - (chunkIndex << chunkSizeLog2) - var written = 0 + var chunkIndex: Int = (pos >> chunkSizeLog2).toInt + var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt + var written: Int = 0 while (written < len) { - val toRead = math.min(len - written, chunkSize - posInChunk) + val toRead: Int = math.min(len - written, chunkSize - posInChunk) os.write(chunks(chunkIndex), posInChunk, toRead) written += toRead chunkIndex += 1 @@ -66,16 +68,16 @@ private[spark] class ChainedBuffer(chunkSize: Int) { * @param offs Offset in the byte array to read to. * @param len Number of bytes to read. */ - def read(pos: Int, bytes: Array[Byte], offs: Int, len: Int): Unit = { + def read(pos: Long, bytes: Array[Byte], offs: Int, len: Int): Unit = { if (pos + len > _size) { throw new IndexOutOfBoundsException( s"Read of $len bytes at position $pos would go past size of buffer") } - var chunkIndex = pos >> chunkSizeLog2 - var posInChunk = pos - (chunkIndex << chunkSizeLog2) - var written = 0 + var chunkIndex: Int = (pos >> chunkSizeLog2).toInt + var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt + var written: Int = 0 while (written < len) { - val toRead = math.min(len - written, chunkSize - posInChunk) + val toRead: Int = math.min(len - written, chunkSize - posInChunk) System.arraycopy(chunks(chunkIndex), posInChunk, bytes, offs + written, toRead) written += toRead chunkIndex += 1 @@ -91,22 +93,22 @@ private[spark] class ChainedBuffer(chunkSize: Int) { * @param offs Offset in the byte array to write from. * @param len Number of bytes to write. */ - def write(pos: Int, bytes: Array[Byte], offs: Int, len: Int): Unit = { + def write(pos: Long, bytes: Array[Byte], offs: Int, len: Int): Unit = { if (pos > _size) { throw new IndexOutOfBoundsException( s"Write at position $pos starts after end of buffer ${_size}") } // Grow if needed - val endChunkIndex = (pos + len - 1) >> chunkSizeLog2 + val endChunkIndex: Int = ((pos + len - 1) >> chunkSizeLog2).toInt while (endChunkIndex >= chunks.length) { chunks += new Array[Byte](chunkSize) } - var chunkIndex = pos >> chunkSizeLog2 - var posInChunk = pos - (chunkIndex << chunkSizeLog2) - var written = 0 + var chunkIndex: Int = (pos >> chunkSizeLog2).toInt + var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt + var written: Int = 0 while (written < len) { - val toWrite = math.min(len - written, chunkSize - posInChunk) + val toWrite: Int = math.min(len - written, chunkSize - posInChunk) System.arraycopy(bytes, offs + written, chunks(chunkIndex), posInChunk, toWrite) written += toWrite chunkIndex += 1 @@ -119,19 +121,19 @@ private[spark] class ChainedBuffer(chunkSize: Int) { /** * Total size of buffer that can be written to without allocating additional memory. */ - def capacity: Int = chunks.size * chunkSize + def capacity: Long = chunks.size.toLong * chunkSize /** * Size of the logical buffer. */ - def size: Int = _size + def size: Long = _size } /** * Output stream that writes to a ChainedBuffer. */ private[spark] class ChainedBufferOutputStream(chainedBuffer: ChainedBuffer) extends OutputStream { - private var pos = 0 + private var pos: Long = 0 override def write(b: Int): Unit = { throw new UnsupportedOperationException() diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala index ac9ea63936..554d88206e 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala @@ -41,6 +41,13 @@ import org.apache.spark.util.collection.PartitionedSerializedPairBuffer._ * * Currently, only sorting by partition is supported. * + * Each record is laid out inside the the metaBuffer as follows. keyStart, a long, is split across + * two integers: + * + * +-------------+------------+------------+-------------+ + * | keyStart | keyValLen | partitionId | + * +-------------+------------+------------+-------------+ + * * @param metaInitialRecords The initial number of entries in the metadata buffer. * @param kvBlockSize The size of each byte buffer in the ChainedBuffer used to store the records. * @param serializerInstance the serializer used for serializing inserted records. @@ -68,19 +75,15 @@ private[spark] class PartitionedSerializedPairBuffer[K, V]( } val keyStart = kvBuffer.size - if (keyStart < 0) { - throw new Exception(s"Can't grow buffer beyond ${1 << 31} bytes") - } kvSerializationStream.writeKey[Any](key) - kvSerializationStream.flush() - val valueStart = kvBuffer.size kvSerializationStream.writeValue[Any](value) kvSerializationStream.flush() - val valueEnd = kvBuffer.size + val keyValLen = (kvBuffer.size - keyStart).toInt - metaBuffer.put(keyStart) - metaBuffer.put(valueStart) - metaBuffer.put(valueEnd) + // keyStart, a long, gets split across two ints + metaBuffer.put(keyStart.toInt) + metaBuffer.put((keyStart >> 32).toInt) + metaBuffer.put(keyValLen) metaBuffer.put(partition) } @@ -114,7 +117,7 @@ private[spark] class PartitionedSerializedPairBuffer[K, V]( } } - override def estimateSize: Long = metaBuffer.capacity * 4 + kvBuffer.capacity + override def estimateSize: Long = metaBuffer.capacity * 4L + kvBuffer.capacity override def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]]) : WritablePartitionedIterator = { @@ -128,10 +131,10 @@ private[spark] class PartitionedSerializedPairBuffer[K, V]( var pos = 0 def writeNext(writer: BlockObjectWriter): Unit = { - val keyStart = metaBuffer.get(pos + KEY_START) - val valueEnd = metaBuffer.get(pos + VAL_END) + val keyStart = getKeyStartPos(metaBuffer, pos) + val keyValLen = metaBuffer.get(pos + KEY_VAL_LEN) pos += RECORD_SIZE - kvBuffer.read(keyStart, writer, valueEnd - keyStart) + kvBuffer.read(keyStart, writer, keyValLen) writer.recordWritten() } def nextPartition(): Int = metaBuffer.get(pos + PARTITION) @@ -163,9 +166,11 @@ private[spark] class PartitionedSerializedPairBuffer[K, V]( private[spark] class OrderedInputStream(metaBuffer: IntBuffer, kvBuffer: ChainedBuffer) extends InputStream { + import PartitionedSerializedPairBuffer._ + private var metaBufferPos = 0 private var kvBufferPos = - if (metaBuffer.position > 0) metaBuffer.get(metaBufferPos + KEY_START) else 0 + if (metaBuffer.position > 0) getKeyStartPos(metaBuffer, metaBufferPos) else 0 override def read(bytes: Array[Byte]): Int = read(bytes, 0, bytes.length) @@ -173,13 +178,14 @@ private[spark] class OrderedInputStream(metaBuffer: IntBuffer, kvBuffer: Chained if (metaBufferPos >= metaBuffer.position) { return -1 } - val bytesRemainingInRecord = metaBuffer.get(metaBufferPos + VAL_END) - kvBufferPos + val bytesRemainingInRecord = (metaBuffer.get(metaBufferPos + KEY_VAL_LEN) - + (kvBufferPos - getKeyStartPos(metaBuffer, metaBufferPos))).toInt val toRead = math.min(bytesRemainingInRecord, len) kvBuffer.read(kvBufferPos, bytes, offs, toRead) if (toRead == bytesRemainingInRecord) { metaBufferPos += RECORD_SIZE if (metaBufferPos < metaBuffer.position) { - kvBufferPos = metaBuffer.get(metaBufferPos + KEY_START) + kvBufferPos = getKeyStartPos(metaBuffer, metaBufferPos) } } else { kvBufferPos += toRead @@ -246,9 +252,14 @@ private[spark] class SerializedSortDataFormat extends SortDataFormat[Int, IntBuf } private[spark] object PartitionedSerializedPairBuffer { - val KEY_START = 0 - val VAL_START = 1 - val VAL_END = 2 + val KEY_START = 0 // keyStart, a long, gets split across two ints + val KEY_VAL_LEN = 2 val PARTITION = 3 - val RECORD_SIZE = Seq(KEY_START, VAL_START, VAL_END, PARTITION).size // num ints of metadata + val RECORD_SIZE = PARTITION + 1 // num ints of metadata + + def getKeyStartPos(metaBuffer: IntBuffer, metaBufferPos: Int): Long = { + val lower32 = metaBuffer.get(metaBufferPos + KEY_START) + val upper32 = metaBuffer.get(metaBufferPos + KEY_START + 1) + (upper32.toLong << 32) | (lower32 & 0xFFFFFFFFL) + } } -- cgit v1.2.3