path: root/core
author    Sandy Ryza <sandy@cloudera.com>  2015-04-30 23:14:14 -0700
committer Patrick Wendell <patrick@databricks.com>  2015-04-30 23:14:14 -0700
commit    0a2b15ce43cf6096e1a7ae060b7c8a4010ce3b92 (patch)
tree      18cb693da7cf83292e1f2af7bdc8a16a1b033454 /core
parent    a9fc50552ec96cd7817dfd19fc681b3368545ee3 (diff)
download  spark-0a2b15ce43cf6096e1a7ae060b7c8a4010ce3b92.tar.gz
          spark-0a2b15ce43cf6096e1a7ae060b7c8a4010ce3b92.tar.bz2
          spark-0a2b15ce43cf6096e1a7ae060b7c8a4010ce3b92.zip
[SPARK-4550] In sort-based shuffle, store map outputs in serialized form
Refer to the JIRA for the design doc and some perf results. I wanted to call out some of the more possibly controversial changes up front:

* Map outputs are only stored in serialized form when Kryo is in use. I'm still unsure whether Java-serialized objects can be relocated. At the very least, Java serialization writes out a stream header which causes problems with the current approach, so I decided to leave investigating this to future work.
* The shuffle now explicitly operates on key-value pairs instead of arbitrary objects. Data is written to shuffle files as alternating keys and values instead of key-value tuples. `BlockObjectWriter.write` now accepts a key argument and a value argument instead of a single object.
* The map output buffer can hold a max of Integer.MAX_VALUE bytes, though this wouldn't be terribly difficult to change.
* When spilling occurs, the objects that are still in memory at merge time end up serialized and deserialized an extra time.

Author: Sandy Ryza <sandy@cloudera.com>

Closes #4450 from sryza/sandy-spark-4550 and squashes the following commits:

8c70dd9 [Sandy Ryza] Fix serialization
9c16fe6 [Sandy Ryza] Fix a couple tests and move getAutoReset to KryoSerializerInstance
6c54e06 [Sandy Ryza] Fix scalastyle
d8462d8 [Sandy Ryza] SPARK-4550
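For readers skimming the patch, here is a minimal sketch of the alternating key/value layout described above. It is illustration only and not part of the commit; `PairWriteSketch` and `writePairs` are hypothetical names, while `SerializationStream.writeKey`/`writeValue` are the hooks added in the Serializer.scala hunk below (their default implementations delegate to `writeObject`).

import scala.reflect.ClassTag

import org.apache.spark.serializer.SerializationStream

// Hypothetical helper (not in the patch): emit records as alternating key and
// value objects, the shuffle-file layout this change introduces.
object PairWriteSketch {
  def writePairs[K: ClassTag, V: ClassTag](
      stream: SerializationStream,
      records: Iterator[(K, V)]): Unit = {
    records.foreach { case (k, v) =>
      stream.writeKey(k)    // default implementation calls writeObject(k)
      stream.writeValue(v)  // default implementation calls writeObject(v)
    }
    stream.flush()
  }
}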
Diffstat (limited to 'core')
-rw-r--r-- core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala | 10
-rw-r--r-- core/src/main/scala/org/apache/spark/serializer/Serializer.scala | 31
-rw-r--r-- core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala | 2
-rw-r--r-- core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala | 37
-rw-r--r-- core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala | 6
-rw-r--r-- core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala | 144
-rw-r--r-- core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala | 6
-rw-r--r-- core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala | 144
-rw-r--r-- core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala (renamed from core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala) | 16
-rw-r--r-- core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala | 44
-rw-r--r-- core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala (renamed from core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala) | 58
-rw-r--r-- core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala | 254
-rw-r--r-- core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala | 2
-rw-r--r-- core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala | 113
-rw-r--r-- core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala | 15
-rw-r--r-- core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala | 4
-rw-r--r-- core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala | 12
-rw-r--r-- core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala | 8
-rw-r--r-- core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala | 143
-rw-r--r-- core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala | 189
-rw-r--r-- core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala | 149
21 files changed, 1208 insertions, 179 deletions
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 754832b8a4..b7bc087855 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -200,6 +200,16 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ
override def deserializeStream(s: InputStream): DeserializationStream = {
new KryoDeserializationStream(kryo, s)
}
+
+ /**
+ * Returns true if auto-reset is on. The only reason this would be false is if the user-supplied
+ * registrator explicitly turns auto-reset off.
+ */
+ def getAutoReset(): Boolean = {
+ val field = classOf[Kryo].getDeclaredField("autoReset")
+ field.setAccessible(true)
+ field.get(kryo).asInstanceOf[Boolean]
+ }
}
/**
diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
index ca6e971d22..c381672a4f 100644
--- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
@@ -101,7 +101,12 @@ abstract class SerializerInstance {
*/
@DeveloperApi
abstract class SerializationStream {
+ /** The most general-purpose method to write an object. */
def writeObject[T: ClassTag](t: T): SerializationStream
+ /** Writes the object representing the key of a key-value pair. */
+ def writeKey[T: ClassTag](key: T): SerializationStream = writeObject(key)
+ /** Writes the object representing the value of a key-value pair. */
+ def writeValue[T: ClassTag](value: T): SerializationStream = writeObject(value)
def flush(): Unit
def close(): Unit
@@ -120,7 +125,12 @@ abstract class SerializationStream {
*/
@DeveloperApi
abstract class DeserializationStream {
+ /** The most general-purpose method to read an object. */
def readObject[T: ClassTag](): T
+ /** Reads the object representing the key of a key-value pair. */
+ def readKey[T: ClassTag](): T = readObject[T]()
+ /** Reads the object representing the value of a key-value pair. */
+ def readValue[T: ClassTag](): T = readObject[T]()
def close(): Unit
/**
@@ -141,4 +151,25 @@ abstract class DeserializationStream {
DeserializationStream.this.close()
}
}
+
+ /**
+ * Read the elements of this stream through an iterator over key-value pairs. This can only be
+ * called once, as reading each element will consume data from the input source.
+ */
+ def asKeyValueIterator: Iterator[(Any, Any)] = new NextIterator[(Any, Any)] {
+ override protected def getNext() = {
+ try {
+ (readKey[Any](), readValue[Any]())
+ } catch {
+ case eof: EOFException => {
+ finished = true
+ null
+ }
+ }
+ }
+
+ override protected def close() {
+ DeserializationStream.this.close()
+ }
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index 755f17d6aa..cd27c9e07a 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -63,7 +63,7 @@ private[spark] class HashShuffleWriter[K, V](
for (elem <- iter) {
val bucketId = dep.partitioner.getPartition(elem._1)
- shuffle.writers(bucketId).write(elem)
+ shuffle.writers(bucketId).write(elem._1, elem._2)
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 14833791f7..499dd97c06 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -33,7 +33,7 @@ import org.apache.spark.util.Utils
* This interface does not support concurrent writes. Also, once the writer has
* been opened, it cannot be reopened again.
*/
-private[spark] abstract class BlockObjectWriter(val blockId: BlockId) {
+private[spark] abstract class BlockObjectWriter(val blockId: BlockId) extends OutputStream {
def open(): BlockObjectWriter
@@ -54,9 +54,14 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) {
def revertPartialWritesAndClose()
/**
- * Writes an object.
+ * Writes a key-value pair.
*/
- def write(value: Any)
+ def write(key: Any, value: Any)
+
+ /**
+ * Notify the writer that a record worth of bytes has been written with writeBytes.
+ */
+ def recordWritten()
/**
* Returns the file segment of committed data that this Writer has written.
@@ -203,12 +208,32 @@ private[spark] class DiskBlockObjectWriter(
}
}
- override def write(value: Any) {
+ override def write(key: Any, value: Any) {
+ if (!initialized) {
+ open()
+ }
+
+ objOut.writeKey(key)
+ objOut.writeValue(value)
+ numRecordsWritten += 1
+ writeMetrics.incShuffleRecordsWritten(1)
+
+ if (numRecordsWritten % 32 == 0) {
+ updateBytesWritten()
+ }
+ }
+
+ override def write(b: Int): Unit = throw new UnsupportedOperationException()
+
+ override def write(kvBytes: Array[Byte], offs: Int, len: Int): Unit = {
if (!initialized) {
open()
}
- objOut.writeObject(value)
+ bs.write(kvBytes, offs, len)
+ }
+
+ override def recordWritten(): Unit = {
numRecordsWritten += 1
writeMetrics.incShuffleRecordsWritten(1)
@@ -238,7 +263,7 @@ private[spark] class DiskBlockObjectWriter(
}
// For testing
- private[spark] def flush() {
+ private[spark] override def flush() {
objOut.flush()
bs.flush()
}
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index f3379521d5..d0faab62c9 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -17,14 +17,12 @@
package org.apache.spark.storage
-import java.io.{InputStream, IOException}
import java.util.concurrent.LinkedBlockingQueue
import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
-import scala.util.{Failure, Success, Try}
+import scala.util.{Failure, Try}
import org.apache.spark.{Logging, TaskContext}
-import org.apache.spark.network.BlockTransferService
import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.serializer.{SerializerInstance, Serializer}
@@ -301,7 +299,7 @@ final class ShuffleBlockFetcherIterator(
// the scheduler gets a FetchFailedException.
Try(buf.createInputStream()).map { is0 =>
val is = blockManager.wrapForCompression(blockId, is0)
- val iter = serializerInstance.deserializeStream(is).asIterator
+ val iter = serializerInstance.deserializeStream(is).asKeyValueIterator
CompletionIterator[Any, Iterator[Any]](iter, {
// Once the iterator is exhausted, release the buffer and set currentResult to null
// so we don't release it again in cleanup.
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala
new file mode 100644
index 0000000000..a60bffe611
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.io.OutputStream
+
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * A logical byte buffer that wraps a list of byte arrays. All the byte arrays have equal size. The
+ * advantage of this over a standard ArrayBuffer is that it can grow without claiming large amounts
+ * of memory and needing to copy the full contents. The disadvantage is that the contents don't
+ * occupy a contiguous segment of memory.
+ */
+private[spark] class ChainedBuffer(chunkSize: Int) {
+ private val chunkSizeLog2 = (math.log(chunkSize) / math.log(2)).toInt
+ assert(math.pow(2, chunkSizeLog2).toInt == chunkSize,
+ s"ChainedBuffer chunk size $chunkSize must be a power of two")
+ private val chunks: ArrayBuffer[Array[Byte]] = new ArrayBuffer[Array[Byte]]()
+ private var _size: Int = _
+
+ /**
+ * Feed bytes from this buffer into a BlockObjectWriter.
+ *
+ * @param pos Offset in the buffer to read from.
+ * @param os OutputStream to read into.
+ * @param len Number of bytes to read.
+ */
+ def read(pos: Int, os: OutputStream, len: Int): Unit = {
+ if (pos + len > _size) {
+ throw new IndexOutOfBoundsException(
+ s"Read of $len bytes at position $pos would go past size ${_size} of buffer")
+ }
+ var chunkIndex = pos >> chunkSizeLog2
+ var posInChunk = pos - (chunkIndex << chunkSizeLog2)
+ var written = 0
+ while (written < len) {
+ val toRead = math.min(len - written, chunkSize - posInChunk)
+ os.write(chunks(chunkIndex), posInChunk, toRead)
+ written += toRead
+ chunkIndex += 1
+ posInChunk = 0
+ }
+ }
+
+ /**
+ * Read bytes from this buffer into a byte array.
+ *
+ * @param pos Offset in the buffer to read from.
+ * @param bytes Byte array to read into.
+ * @param offs Offset in the byte array to read to.
+ * @param len Number of bytes to read.
+ */
+ def read(pos: Int, bytes: Array[Byte], offs: Int, len: Int): Unit = {
+ if (pos + len > _size) {
+ throw new IndexOutOfBoundsException(
+ s"Read of $len bytes at position $pos would go past size of buffer")
+ }
+ var chunkIndex = pos >> chunkSizeLog2
+ var posInChunk = pos - (chunkIndex << chunkSizeLog2)
+ var written = 0
+ while (written < len) {
+ val toRead = math.min(len - written, chunkSize - posInChunk)
+ System.arraycopy(chunks(chunkIndex), posInChunk, bytes, offs + written, toRead)
+ written += toRead
+ chunkIndex += 1
+ posInChunk = 0
+ }
+ }
+
+ /**
+ * Write bytes from a byte array into this buffer.
+ *
+ * @param pos Offset in the buffer to write to.
+ * @param bytes Byte array to write from.
+ * @param offs Offset in the byte array to write from.
+ * @param len Number of bytes to write.
+ */
+ def write(pos: Int, bytes: Array[Byte], offs: Int, len: Int): Unit = {
+ if (pos > _size) {
+ throw new IndexOutOfBoundsException(
+ s"Write at position $pos starts after end of buffer ${_size}")
+ }
+ // Grow if needed
+ val endChunkIndex = (pos + len - 1) >> chunkSizeLog2
+ while (endChunkIndex >= chunks.length) {
+ chunks += new Array[Byte](chunkSize)
+ }
+
+ var chunkIndex = pos >> chunkSizeLog2
+ var posInChunk = pos - (chunkIndex << chunkSizeLog2)
+ var written = 0
+ while (written < len) {
+ val toWrite = math.min(len - written, chunkSize - posInChunk)
+ System.arraycopy(bytes, offs + written, chunks(chunkIndex), posInChunk, toWrite)
+ written += toWrite
+ chunkIndex += 1
+ posInChunk = 0
+ }
+
+ _size = math.max(_size, pos + len)
+ }
+
+ /**
+ * Total size of buffer that can be written to without allocating additional memory.
+ */
+ def capacity: Int = chunks.size * chunkSize
+
+ /**
+ * Size of the logical buffer.
+ */
+ def size: Int = _size
+}
+
+/**
+ * Output stream that writes to a ChainedBuffer.
+ */
+private[spark] class ChainedBufferOutputStream(chainedBuffer: ChainedBuffer) extends OutputStream {
+ private var pos = 0
+
+ override def write(b: Int): Unit = {
+ throw new UnsupportedOperationException()
+ }
+
+ override def write(bytes: Array[Byte], offs: Int, len: Int): Unit = {
+ chainedBuffer.write(pos, bytes, offs, len)
+ pos += len
+ }
+}
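
A rough usage sketch of the ChainedBuffer added above, assuming the API exactly as shown in the hunk; `ChainedBufferSketch` is a hypothetical name, and it sits in the collection package only because the class is private[spark]. Not part of the patch.

package org.apache.spark.util.collection

// Sketch only: growing appends whole new chunks instead of copying existing bytes.
object ChainedBufferSketch {
  def main(args: Array[String]): Unit = {
    val buf = new ChainedBuffer(4096)             // chunk size must be a power of two
    val payload = Array.tabulate[Byte](10000)(_.toByte)
    buf.write(0, payload, 0, payload.length)      // spans three 4 KB chunks
    assert(buf.size == 10000)
    assert(buf.capacity == 3 * 4096)              // whole chunks are allocated
    val out = new Array[Byte](10000)
    buf.read(0, out, 0, out.length)               // read back across chunk boundaries
    assert(out.sameElements(payload))
  }
}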
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index f912049563..b850973145 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -174,7 +174,7 @@ class ExternalAppendOnlyMap[K, V, C](
val it = currentMap.destructiveSortedIterator(keyComparator)
while (it.hasNext) {
val kv = it.next()
- writer.write(kv)
+ writer.write(kv._1, kv._2)
objectsWritten += 1
if (objectsWritten == serializerBatchSize) {
@@ -435,7 +435,9 @@ class ExternalAppendOnlyMap[K, V, C](
*/
private def readNextItem(): (K, C) = {
try {
- val item = deserializeStream.readObject().asInstanceOf[(K, C)]
+ val k = deserializeStream.readKey().asInstanceOf[K]
+ val c = deserializeStream.readValue().asInstanceOf[C]
+ val item = (k, c)
objectsRead += 1
if (objectsRead == serializerBatchSize) {
objectsRead = 0
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 4ed8a740f9..b7306cd551 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -26,7 +26,7 @@ import scala.collection.mutable
import com.google.common.io.ByteStreams
import org.apache.spark._
-import org.apache.spark.serializer.{DeserializationStream, Serializer}
+import org.apache.spark.serializer._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.storage.{BlockObjectWriter, BlockId}
@@ -66,10 +66,11 @@ import org.apache.spark.storage.{BlockObjectWriter, BlockId}
*
* At a high level, this class works internally as follows:
*
- * - We repeatedly fill up buffers of in-memory data, using either a SizeTrackingAppendOnlyMap if
- * we want to combine by key, or an simple SizeTrackingBuffer if we don't. Inside these buffers,
- * we sort elements of type ((Int, K), C) where the Int is the partition ID. This is done to
- * avoid calling the partitioner multiple times on the same key (e.g. for RangePartitioner).
+ * - We repeatedly fill up buffers of in-memory data, using either a PartitionedAppendOnlyMap if
+ * we want to combine by key, or a PartitionedSerializedPairBuffer or PartitionedPairBuffer if we
+ * don't. Inside these buffers, we sort elements by partition ID and then possibly also by key.
+ * To avoid calling the partitioner multiple times with each key, we store the partition ID
+ * alongside each record.
*
* - When each buffer reaches our memory limit, we spill it to a file. This file is sorted first
* by partition ID and possibly second by key or by hash code of the key, if we want to do
@@ -96,7 +97,7 @@ private[spark] class ExternalSorter[K, V, C](
partitioner: Option[Partitioner] = None,
ordering: Option[Ordering[K]] = None,
serializer: Option[Serializer] = None)
- extends Logging with Spillable[SizeTrackingPairCollection[(Int, K), C]] {
+ extends Logging with Spillable[WritablePartitionedPairCollection[K, C]] {
private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1)
private val shouldPartition = numPartitions > 1
@@ -126,11 +127,22 @@ private[spark] class ExternalSorter[K, V, C](
if (shouldPartition) partitioner.get.getPartition(key) else 0
}
+ private val metaInitialRecords = 256
+ private val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB
+ private val useSerializedPairBuffer =
+ !ordering.isDefined && conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) &&
+ ser.isInstanceOf[KryoSerializer] &&
+ serInstance.asInstanceOf[KryoSerializerInstance].getAutoReset
+
// Data structures to store in-memory objects before we spill. Depending on whether we have an
// Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we
// store them in an array buffer.
- private var map = new SizeTrackingAppendOnlyMap[(Int, K), C]
- private var buffer = new SizeTrackingPairBuffer[(Int, K), C]
+ private var map = new PartitionedAppendOnlyMap[K, C]
+ private var buffer = if (useSerializedPairBuffer) {
+ new PartitionedSerializedPairBuffer[K, C](metaInitialRecords, kvChunkSize, serInstance)
+ } else {
+ new PartitionedPairBuffer[K, C]
+ }
// Total spilling statistics
private var _diskBytesSpilled = 0L
@@ -163,33 +175,6 @@ private[spark] class ExternalSorter[K, V, C](
}
})
- // A comparator for (Int, K) pairs that orders them by only their partition ID
- private val partitionComparator: Comparator[(Int, K)] = new Comparator[(Int, K)] {
- override def compare(a: (Int, K), b: (Int, K)): Int = {
- a._1 - b._1
- }
- }
-
- // A comparator that orders (Int, K) pairs by partition ID and then possibly by key
- private val partitionKeyComparator: Comparator[(Int, K)] = {
- if (ordering.isDefined || aggregator.isDefined) {
- // Sort by partition ID then key comparator
- new Comparator[(Int, K)] {
- override def compare(a: (Int, K), b: (Int, K)): Int = {
- val partitionDiff = a._1 - b._1
- if (partitionDiff != 0) {
- partitionDiff
- } else {
- keyComparator.compare(a._2, b._2)
- }
- }
- }
- } else {
- // Just sort it by partition ID
- partitionComparator
- }
- }
-
// Information about a spilled file. Includes sizes in bytes of "batches" written by the
// serializer as we periodically reset its stream, as well as number of elements in each
// partition, used to efficiently keep track of partitions when merging.
@@ -221,16 +206,18 @@ private[spark] class ExternalSorter[K, V, C](
} else if (bypassMergeSort) {
// SPARK-4479: Also bypass buffering if merge sort is bypassed to avoid defensive copies
if (records.hasNext) {
- spillToPartitionFiles(records.map { kv =>
- ((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C])
- })
+ spillToPartitionFiles(
+ WritablePartitionedIterator.fromIterator(records.map { kv =>
+ ((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C])
+ })
+ )
}
} else {
// Stick values into our buffer
while (records.hasNext) {
addElementsRead()
val kv = records.next()
- buffer.insert((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C])
+ buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
maybeSpillCollection(usingMap = false)
}
}
@@ -248,11 +235,15 @@ private[spark] class ExternalSorter[K, V, C](
if (usingMap) {
if (maybeSpill(map, map.estimateSize())) {
- map = new SizeTrackingAppendOnlyMap[(Int, K), C]
+ map = new PartitionedAppendOnlyMap[K, C]
}
} else {
if (maybeSpill(buffer, buffer.estimateSize())) {
- buffer = new SizeTrackingPairBuffer[(Int, K), C]
+ buffer = if (useSerializedPairBuffer) {
+ new PartitionedSerializedPairBuffer[K, C](metaInitialRecords, kvChunkSize, serInstance)
+ } else {
+ new PartitionedPairBuffer[K, C]
+ }
}
}
}
@@ -260,7 +251,7 @@ private[spark] class ExternalSorter[K, V, C](
/**
* Spill the current in-memory collection to disk, adding a new file to spills, and clear it.
*/
- override protected[this] def spill(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = {
+ override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
if (bypassMergeSort) {
spillToPartitionFiles(collection)
} else {
@@ -277,7 +268,7 @@ private[spark] class ExternalSorter[K, V, C](
*
* @param collection whichever collection we're using (map or buffer)
*/
- private def spillToMergeableFile(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = {
+ private def spillToMergeableFile(collection: WritablePartitionedPairCollection[K, C]): Unit = {
assert(!bypassMergeSort)
// Because these files may be read during shuffle, their compression must be controlled by
@@ -308,14 +299,10 @@ private[spark] class ExternalSorter[K, V, C](
var success = false
try {
- val it = collection.destructiveSortedIterator(partitionKeyComparator)
+ val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
while (it.hasNext) {
- val elem = it.next()
- val partitionId = elem._1._1
- val key = elem._1._2
- val value = elem._2
- writer.write(key)
- writer.write(value)
+ val partitionId = it.nextPartition()
+ it.writeNext(writer)
elementsPerPartition(partitionId) += 1
objectsWritten += 1
@@ -357,11 +344,11 @@ private[spark] class ExternalSorter[K, V, C](
*
* @param collection whichever collection we're using (map or buffer)
*/
- private def spillToPartitionFiles(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = {
- spillToPartitionFiles(collection.iterator)
+ private def spillToPartitionFiles(collection: WritablePartitionedPairCollection[K, C]): Unit = {
+ spillToPartitionFiles(collection.writablePartitionedIterator())
}
- private def spillToPartitionFiles(iterator: Iterator[((Int, K), C)]): Unit = {
+ private def spillToPartitionFiles(iterator: WritablePartitionedIterator): Unit = {
assert(bypassMergeSort)
// Create our file writers if we haven't done so yet
@@ -385,11 +372,8 @@ private[spark] class ExternalSorter[K, V, C](
// No need to sort stuff, just write each element out
while (iterator.hasNext) {
- val elem = iterator.next()
- val partitionId = elem._1._1
- val key = elem._1._2
- val value = elem._2
- partitionWriters(partitionId).write((key, value))
+ val partitionId = iterator.nextPartition()
+ iterator.writeNext(partitionWriters(partitionId))
}
}
@@ -618,8 +602,8 @@ private[spark] class ExternalSorter[K, V, C](
if (finished || deserializeStream == null) {
return null
}
- val k = deserializeStream.readObject().asInstanceOf[K]
- val c = deserializeStream.readObject().asInstanceOf[C]
+ val k = deserializeStream.readKey().asInstanceOf[K]
+ val c = deserializeStream.readValue().asInstanceOf[C]
lastPartitionId = partitionId
// Start reading the next batch if we're done with this one
indexInBatch += 1
@@ -695,27 +679,27 @@ private[spark] class ExternalSorter[K, V, C](
*/
def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
val usingMap = aggregator.isDefined
- val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer
+ val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer
if (spills.isEmpty && partitionWriters == null) {
// Special case: if we have only in-memory data, we don't need to merge streams, and perhaps
// we don't even need to sort by anything other than partition ID
if (!ordering.isDefined) {
// The user hasn't requested sorted keys, so only sort by partition ID, not key
- groupByPartition(collection.destructiveSortedIterator(partitionComparator))
+ groupByPartition(collection.partitionedDestructiveSortedIterator(None))
} else {
// We do need to sort by both partition ID and key
- groupByPartition(collection.destructiveSortedIterator(partitionKeyComparator))
+ groupByPartition(collection.partitionedDestructiveSortedIterator(Some(keyComparator)))
}
} else if (bypassMergeSort) {
// Read data from each partition file and merge it together with the data in memory;
// note that there's no ordering or aggregator in this case -- we just partition objects
- val collIter = groupByPartition(collection.destructiveSortedIterator(partitionComparator))
+ val collIter = groupByPartition(collection.partitionedDestructiveSortedIterator(None))
collIter.map { case (partitionId, values) =>
(partitionId, values ++ readPartitionFile(partitionWriters(partitionId)))
}
} else {
// Merge spilled and in-memory data
- merge(spills, collection.destructiveSortedIterator(partitionKeyComparator))
+ merge(spills, collection.partitionedDestructiveSortedIterator(comparator))
}
}
@@ -762,15 +746,29 @@ private[spark] class ExternalSorter[K, V, C](
context.taskMetrics.shuffleWriteMetrics.foreach(
_.incShuffleWriteTime(System.nanoTime - writeStartTime))
}
+ } else if (spills.isEmpty && partitionWriters == null) {
+ // Case where we only have in-memory data
+ val collection = if (aggregator.isDefined) map else buffer
+ val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
+ while (it.hasNext) {
+ val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
+ context.taskMetrics.shuffleWriteMetrics.get)
+ val partitionId = it.nextPartition()
+ while (it.hasNext && it.nextPartition() == partitionId) {
+ it.writeNext(writer)
+ }
+ writer.commitAndClose()
+ val segment = writer.fileSegment()
+ lengths(partitionId) = segment.length
+ }
} else {
- // Either we're not bypassing merge-sort or we have only in-memory data; get an iterator by
- // partition and just write everything directly.
+ // Not bypassing merge-sort; get an iterator by partition and just write everything directly.
for ((id, elements) <- this.partitionedIterator) {
if (elements.hasNext) {
val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
context.taskMetrics.shuffleWriteMetrics.get)
for (elem <- elements) {
- writer.write(elem)
+ writer.write(elem._1, elem._2)
}
writer.commitAndClose()
val segment = writer.fileSegment()
@@ -799,7 +797,7 @@ private[spark] class ExternalSorter[K, V, C](
if (writer.isOpen) {
writer.commitAndClose()
}
- blockManager.diskStore.getValues(writer.blockId, ser).get.asInstanceOf[Iterator[Product2[K, C]]]
+ new PairIterator[K, C](blockManager.diskStore.getValues(writer.blockId, ser).get)
}
def stop(): Unit = {
@@ -829,6 +827,14 @@ private[spark] class ExternalSorter[K, V, C](
(0 until numPartitions).iterator.map(p => (p, new IteratorForPartition(p, buffered)))
}
+ private def comparator: Option[Comparator[K]] = {
+ if (ordering.isDefined || aggregator.isDefined) {
+ Some(keyComparator)
+ } else {
+ None
+ }
+ }
+
/**
* An iterator that reads only the elements for a given partition ID from an underlying buffered
* stream, assuming this partition is the next one to be read. Used to make it easier to return
diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala
index faa4e2b12d..d75959f480 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala
@@ -17,18 +17,8 @@
package org.apache.spark.util.collection
-import java.util.Comparator
+private[spark] class PairIterator[K, V](iter: Iterator[Any]) extends Iterator[(K, V)] {
+ def hasNext: Boolean = iter.hasNext
-/**
- * A common interface for our size-tracking collections of key-value pairs, which are used in
- * external operations. These all support estimating the size and obtaining a memory-efficient
- * sorted iterator.
- */
-// TODO: should extend Iterable[Product2[K, V]] instead of (K, V)
-private[spark] trait SizeTrackingPairCollection[K, V] extends Iterable[(K, V)] {
- /** Estimate the collection's current memory usage in bytes. */
- def estimateSize(): Long
-
- /** Iterate through the data in a given key order. This may destroy the underlying collection. */
- def destructiveSortedIterator(keyComparator: Comparator[K]): Iterator[(K, V)]
+ def next(): (K, V) = (iter.next().asInstanceOf[K], iter.next().asInstanceOf[V])
}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala
new file mode 100644
index 0000000000..e2e2f1faae
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.util.Comparator
+
+import org.apache.spark.util.collection.WritablePartitionedPairCollection._
+
+/**
+ * Implementation of WritablePartitionedPairCollection that wraps a map in which the keys are tuples
+ * of (partition ID, K)
+ */
+private[spark] class PartitionedAppendOnlyMap[K, V]
+ extends SizeTrackingAppendOnlyMap[(Int, K), V] with WritablePartitionedPairCollection[K, V] {
+
+ def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]])
+ : Iterator[((Int, K), V)] = {
+ val comparator = keyComparator.map(partitionKeyComparator).getOrElse(partitionComparator)
+ destructiveSortedIterator(comparator)
+ }
+
+ def writablePartitionedIterator(): WritablePartitionedIterator = {
+ WritablePartitionedIterator.fromIterator(super.iterator)
+ }
+
+ def insert(partition: Int, key: K, value: V): Unit = {
+ update((partition, key), value)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
index 9e9c16c5a2..e8332e1a87 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
@@ -19,11 +19,15 @@ package org.apache.spark.util.collection
import java.util.Comparator
+import org.apache.spark.storage.BlockObjectWriter
+import org.apache.spark.util.collection.WritablePartitionedPairCollection._
+
/**
- * Append-only buffer of key-value pairs that keeps track of its estimated size in bytes.
+ * Append-only buffer of key-value pairs, each with a corresponding partition ID, that keeps track
+ * of its estimated size in bytes.
*/
-private[spark] class SizeTrackingPairBuffer[K, V](initialCapacity: Int = 64)
- extends SizeTracker with SizeTrackingPairCollection[K, V]
+private[spark] class PartitionedPairBuffer[K, V](initialCapacity: Int = 64)
+ extends WritablePartitionedPairCollection[K, V] with SizeTracker
{
require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
require(initialCapacity >= 1, "Invalid initial capacity")
@@ -35,35 +39,16 @@ private[spark] class SizeTrackingPairBuffer[K, V](initialCapacity: Int = 64)
private var data = new Array[AnyRef](2 * initialCapacity)
/** Add an element into the buffer */
- def insert(key: K, value: V): Unit = {
+ def insert(partition: Int, key: K, value: V): Unit = {
if (curSize == capacity) {
growArray()
}
- data(2 * curSize) = key.asInstanceOf[AnyRef]
+ data(2 * curSize) = (partition, key.asInstanceOf[AnyRef])
data(2 * curSize + 1) = value.asInstanceOf[AnyRef]
curSize += 1
afterUpdate()
}
- /** Total number of elements in buffer */
- override def size: Int = curSize
-
- /** Iterate over the elements of the buffer */
- override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] {
- var pos = 0
-
- override def hasNext: Boolean = pos < curSize
-
- override def next(): (K, V) = {
- if (!hasNext) {
- throw new NoSuchElementException
- }
- val pair = (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V])
- pos += 1
- pair
- }
- }
-
/** Double the size of the array because we've reached capacity */
private def growArray(): Unit = {
if (capacity == (1 << 29)) {
@@ -79,8 +64,29 @@ private[spark] class SizeTrackingPairBuffer[K, V](initialCapacity: Int = 64)
}
/** Iterate through the data in a given order. For this class this is not really destructive. */
- override def destructiveSortedIterator(keyComparator: Comparator[K]): Iterator[(K, V)] = {
- new Sorter(new KVArraySortDataFormat[K, AnyRef]).sort(data, 0, curSize, keyComparator)
+ override def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]])
+ : Iterator[((Int, K), V)] = {
+ val comparator = keyComparator.map(partitionKeyComparator).getOrElse(partitionComparator)
+ new Sorter(new KVArraySortDataFormat[(Int, K), AnyRef]).sort(data, 0, curSize, comparator)
iterator
}
+
+ override def writablePartitionedIterator(): WritablePartitionedIterator = {
+ WritablePartitionedIterator.fromIterator(iterator)
+ }
+
+ private def iterator(): Iterator[((Int, K), V)] = new Iterator[((Int, K), V)] {
+ var pos = 0
+
+ override def hasNext: Boolean = pos < curSize
+
+ override def next(): ((Int, K), V) = {
+ if (!hasNext) {
+ throw new NoSuchElementException
+ }
+ val pair = (data(2 * pos).asInstanceOf[(Int, K)], data(2 * pos + 1).asInstanceOf[V])
+ pos += 1
+ pair
+ }
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
new file mode 100644
index 0000000000..b5ca0c62a0
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
@@ -0,0 +1,254 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.io.InputStream
+import java.nio.IntBuffer
+import java.util.Comparator
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.serializer.{JavaSerializerInstance, SerializerInstance}
+import org.apache.spark.storage.BlockObjectWriter
+import org.apache.spark.util.collection.PartitionedSerializedPairBuffer._
+
+/**
+ * Append-only buffer of key-value pairs, each with a corresponding partition ID, that serializes
+ * its records upon insert and stores them as raw bytes.
+ *
+ * We use two data-structures to store the contents. The serialized records are stored in a
+ * ChainedBuffer that can expand gracefully as records are added. This buffer is accompanied by a
+ * metadata buffer that stores pointers into the data buffer as well as the partition ID of each
+ * record. Each entry in the metadata buffer takes up a fixed amount of space.
+ *
+ * Sorting the collection means swapping entries in the metadata buffer - the record buffer need not
+ * be modified at all. Storing the partition IDs in the metadata buffer means that comparisons can
+ * happen without following any pointers, which should minimize cache misses.
+ *
+ * Currently, only sorting by partition is supported.
+ *
+ * @param metaInitialRecords The initial number of entries in the metadata buffer.
+ * @param kvBlockSize The size of each byte buffer in the ChainedBuffer used to store the records.
+ * @param serializerInstance the serializer used for serializing inserted records.
+ */
+private[spark] class PartitionedSerializedPairBuffer[K, V](
+ metaInitialRecords: Int,
+ kvBlockSize: Int,
+ serializerInstance: SerializerInstance)
+ extends WritablePartitionedPairCollection[K, V] with SizeTracker {
+
+ if (serializerInstance.isInstanceOf[JavaSerializerInstance]) {
+ throw new IllegalArgumentException("PartitionedSerializedPairBuffer does not support" +
+ " Java-serialized objects.")
+ }
+
+ private var metaBuffer = IntBuffer.allocate(metaInitialRecords * RECORD_SIZE)
+
+ private val kvBuffer: ChainedBuffer = new ChainedBuffer(kvBlockSize)
+ private val kvOutputStream = new ChainedBufferOutputStream(kvBuffer)
+ private val kvSerializationStream = serializerInstance.serializeStream(kvOutputStream)
+
+ def insert(partition: Int, key: K, value: V): Unit = {
+ if (metaBuffer.position == metaBuffer.capacity) {
+ growMetaBuffer()
+ }
+
+ val keyStart = kvBuffer.size
+ if (keyStart < 0) {
+ throw new Exception(s"Can't grow buffer beyond ${1 << 31} bytes")
+ }
+ kvSerializationStream.writeObject[Any](key)
+ kvSerializationStream.flush()
+ val valueStart = kvBuffer.size
+ kvSerializationStream.writeObject[Any](value)
+ kvSerializationStream.flush()
+ val valueEnd = kvBuffer.size
+
+ metaBuffer.put(keyStart)
+ metaBuffer.put(valueStart)
+ metaBuffer.put(valueEnd)
+ metaBuffer.put(partition)
+ }
+
+ /** Double the size of the array because we've reached capacity */
+ private def growMetaBuffer(): Unit = {
+ if (metaBuffer.capacity.toLong * 2 > Int.MaxValue) {
+ // Doubling the capacity would create an array bigger than Int.MaxValue, so don't
+ throw new Exception(s"Can't grow buffer beyond ${Int.MaxValue} bytes")
+ }
+ val newMetaBuffer = IntBuffer.allocate(metaBuffer.capacity * 2)
+ newMetaBuffer.put(metaBuffer.array)
+ metaBuffer = newMetaBuffer
+ }
+
+ /** Iterate through the data in a given order. For this class this is not really destructive. */
+ override def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]])
+ : Iterator[((Int, K), V)] = {
+ sort(keyComparator)
+ val is = orderedInputStream
+ val deserStream = serializerInstance.deserializeStream(is)
+ new Iterator[((Int, K), V)] {
+ var metaBufferPos = 0
+ def hasNext: Boolean = metaBufferPos < metaBuffer.position
+ def next(): ((Int, K), V) = {
+ val key = deserStream.readKey[Any]().asInstanceOf[K]
+ val value = deserStream.readValue[Any]().asInstanceOf[V]
+ val partition = metaBuffer.get(metaBufferPos + PARTITION)
+ metaBufferPos += RECORD_SIZE
+ ((partition, key), value)
+ }
+ }
+ }
+
+ override def estimateSize: Long = metaBuffer.capacity * 4 + kvBuffer.capacity
+
+ override def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]])
+ : WritablePartitionedIterator = {
+ sort(keyComparator)
+ writablePartitionedIterator
+ }
+
+ override def writablePartitionedIterator(): WritablePartitionedIterator = {
+ new WritablePartitionedIterator {
+ // current position in the meta buffer in ints
+ var pos = 0
+
+ def writeNext(writer: BlockObjectWriter): Unit = {
+ val keyStart = metaBuffer.get(pos + KEY_START)
+ val valueEnd = metaBuffer.get(pos + VAL_END)
+ pos += RECORD_SIZE
+ kvBuffer.read(keyStart, writer, valueEnd - keyStart)
+ writer.recordWritten()
+ }
+ def nextPartition(): Int = metaBuffer.get(pos + PARTITION)
+ def hasNext(): Boolean = pos < metaBuffer.position
+ }
+ }
+
+ // Visible for testing
+ def orderedInputStream: OrderedInputStream = {
+ new OrderedInputStream(metaBuffer, kvBuffer)
+ }
+
+ private def sort(keyComparator: Option[Comparator[K]]): Unit = {
+ val comparator = if (keyComparator.isEmpty) {
+ new Comparator[Int]() {
+ def compare(partition1: Int, partition2: Int): Int = {
+ partition1 - partition2
+ }
+ }
+ } else {
+ throw new UnsupportedOperationException()
+ }
+
+ val sorter = new Sorter(new SerializedSortDataFormat)
+ sorter.sort(metaBuffer, 0, metaBuffer.position / RECORD_SIZE, comparator)
+ }
+}
+
+private[spark] class OrderedInputStream(metaBuffer: IntBuffer, kvBuffer: ChainedBuffer)
+ extends InputStream {
+
+ private var metaBufferPos = 0
+ private var kvBufferPos =
+ if (metaBuffer.position > 0) metaBuffer.get(metaBufferPos + KEY_START) else 0
+
+ override def read(bytes: Array[Byte]): Int = read(bytes, 0, bytes.length)
+
+ override def read(bytes: Array[Byte], offs: Int, len: Int): Int = {
+ if (metaBufferPos >= metaBuffer.position) {
+ return -1
+ }
+ val bytesRemainingInRecord = metaBuffer.get(metaBufferPos + VAL_END) - kvBufferPos
+ val toRead = math.min(bytesRemainingInRecord, len)
+ kvBuffer.read(kvBufferPos, bytes, offs, toRead)
+ if (toRead == bytesRemainingInRecord) {
+ metaBufferPos += RECORD_SIZE
+ if (metaBufferPos < metaBuffer.position) {
+ kvBufferPos = metaBuffer.get(metaBufferPos + KEY_START)
+ }
+ } else {
+ kvBufferPos += toRead
+ }
+ toRead
+ }
+
+ override def read(): Int = {
+ throw new UnsupportedOperationException()
+ }
+}
+
+private[spark] class SerializedSortDataFormat extends SortDataFormat[Int, IntBuffer] {
+
+ private val META_BUFFER_TMP = new Array[Int](RECORD_SIZE)
+
+ /** Return the sort key for the element at the given index. */
+ override protected def getKey(metaBuffer: IntBuffer, pos: Int): Int = {
+ metaBuffer.get(pos * RECORD_SIZE + PARTITION)
+ }
+
+ /** Swap two elements. */
+ override def swap(metaBuffer: IntBuffer, pos0: Int, pos1: Int): Unit = {
+ val iOff = pos0 * RECORD_SIZE
+ val jOff = pos1 * RECORD_SIZE
+ System.arraycopy(metaBuffer.array, iOff, META_BUFFER_TMP, 0, RECORD_SIZE)
+ System.arraycopy(metaBuffer.array, jOff, metaBuffer.array, iOff, RECORD_SIZE)
+ System.arraycopy(META_BUFFER_TMP, 0, metaBuffer.array, jOff, RECORD_SIZE)
+ }
+
+ /** Copy a single element from src(srcPos) to dst(dstPos). */
+ override def copyElement(
+ src: IntBuffer,
+ srcPos: Int,
+ dst: IntBuffer,
+ dstPos: Int): Unit = {
+ val srcOff = srcPos * RECORD_SIZE
+ val dstOff = dstPos * RECORD_SIZE
+ System.arraycopy(src.array, srcOff, dst.array, dstOff, RECORD_SIZE)
+ }
+
+ /**
+ * Copy a range of elements starting at src(srcPos) to dst, starting at dstPos.
+ * Overlapping ranges are allowed.
+ */
+ override def copyRange(
+ src: IntBuffer,
+ srcPos: Int,
+ dst: IntBuffer,
+ dstPos: Int,
+ length: Int): Unit = {
+ val srcOff = srcPos * RECORD_SIZE
+ val dstOff = dstPos * RECORD_SIZE
+ System.arraycopy(src.array, srcOff, dst.array, dstOff, RECORD_SIZE * length)
+ }
+
+ /**
+ * Allocates a Buffer that can hold up to 'length' elements.
+ * All elements of the buffer should be considered invalid until data is explicitly copied in.
+ */
+ override def allocate(length: Int): IntBuffer = {
+ IntBuffer.allocate(length * RECORD_SIZE)
+ }
+}
+
+private[spark] object PartitionedSerializedPairBuffer {
+ val KEY_START = 0
+ val VAL_START = 1
+ val VAL_END = 2
+ val PARTITION = 3
+ val RECORD_SIZE = Seq(KEY_START, VAL_START, VAL_END, PARTITION).size // num ints of metadata
+}
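
A sketch of how the 4-int metadata records above are laid out; `MetaRecordSketch` and `describe` are hypothetical names, not part of the patch. Sorting only permutes these entries, while the serialized bytes in the ChainedBuffer never move.

package org.apache.spark.util.collection

import java.nio.IntBuffer

import org.apache.spark.util.collection.PartitionedSerializedPairBuffer._

// Sketch only: decode one metadata entry using the constants defined above.
object MetaRecordSketch {
  def describe(meta: IntBuffer, recordIndex: Int): String = {
    val base = recordIndex * RECORD_SIZE
    val keyStart  = meta.get(base + KEY_START)    // offset of the serialized key
    val valStart  = meta.get(base + VAL_START)    // offset of the serialized value
    val valEnd    = meta.get(base + VAL_END)      // end offset of the serialized value
    val partition = meta.get(base + PARTITION)    // reducer partition ID
    s"partition=$partition key=[$keyStart,$valStart) value=[$valStart,$valEnd)"
  }
}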
diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala
index eb4de41386..722f78bd15 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala
@@ -21,7 +21,7 @@ package org.apache.spark.util.collection
* An append-only map that keeps track of its estimated size in bytes.
*/
private[spark] class SizeTrackingAppendOnlyMap[K, V]
- extends AppendOnlyMap[K, V] with SizeTracker with SizeTrackingPairCollection[K, V]
+ extends AppendOnlyMap[K, V] with SizeTracker
{
override def update(key: K, value: V): Unit = {
super.update(key, value)
diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
new file mode 100644
index 0000000000..f26d1618c9
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.util.Comparator
+
+import org.apache.spark.storage.BlockObjectWriter
+
+/**
+ * A common interface for size-tracking collections of key-value pairs that
+ * - Have an associated partition for each key-value pair.
+ * - Support a memory-efficient sorted iterator
+ * - Support a WritablePartitionedIterator for writing the contents directly as bytes.
+ */
+private[spark] trait WritablePartitionedPairCollection[K, V] {
+ /**
+ * Insert a key-value pair with a partition into the collection
+ */
+ def insert(partition: Int, key: K, value: V): Unit
+
+ /**
+ * Iterate through the data in order of partition ID and then the given comparator. This may
+ * destroy the underlying collection.
+ */
+ def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]])
+ : Iterator[((Int, K), V)]
+
+ /**
+ * Iterate through the data and write out the elements instead of returning them. Records are
+ * returned in order of their partition ID and then the given comparator.
+ * This may destroy the underlying collection.
+ */
+ def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]])
+ : WritablePartitionedIterator = {
+ WritablePartitionedIterator.fromIterator(partitionedDestructiveSortedIterator(keyComparator))
+ }
+
+ /**
+ * Iterate through the data and write out the elements instead of returning them.
+ */
+ def writablePartitionedIterator(): WritablePartitionedIterator
+}
+
+private[spark] object WritablePartitionedPairCollection {
+ /**
+ * A comparator for (Int, K) pairs that orders them by only their partition ID.
+ */
+ def partitionComparator[K]: Comparator[(Int, K)] = new Comparator[(Int, K)] {
+ override def compare(a: (Int, K), b: (Int, K)): Int = {
+ a._1 - b._1
+ }
+ }
+
+ /**
+ * A comparator for (Int, K) pairs that orders them both by their partition ID and a key ordering.
+ */
+ def partitionKeyComparator[K](keyComparator: Comparator[K]): Comparator[(Int, K)] = {
+ new Comparator[(Int, K)] {
+ override def compare(a: (Int, K), b: (Int, K)): Int = {
+ val partitionDiff = a._1 - b._1
+ if (partitionDiff != 0) {
+ partitionDiff
+ } else {
+ keyComparator.compare(a._2, b._2)
+ }
+ }
+ }
+ }
+}
+
+/**
+ * Iterator that writes elements to a BlockObjectWriter instead of returning them. Each element
+ * has an associated partition.
+ */
+private[spark] trait WritablePartitionedIterator {
+ def writeNext(writer: BlockObjectWriter): Unit
+
+ def hasNext(): Boolean
+
+ def nextPartition(): Int
+}
+
+private[spark] object WritablePartitionedIterator {
+ def fromIterator(it: Iterator[((Int, _), _)]): WritablePartitionedIterator = {
+ new WritablePartitionedIterator {
+ var cur = if (it.hasNext) it.next() else null
+
+ def writeNext(writer: BlockObjectWriter): Unit = {
+ writer.write(cur._1._2, cur._2)
+ cur = if (it.hasNext) it.next() else null
+ }
+
+ def hasNext(): Boolean = cur != null
+
+ def nextPartition(): Int = cur._1._1
+ }
+ }
+}
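
A sketch of how a caller drives the new interface, roughly what the ExternalSorter spill path above does with a key comparator of None (partition-only ordering). `SpillSketch`, `spillSorted`, and `writerFor` are hypothetical names; not part of the patch.

package org.apache.spark.util.collection

import org.apache.spark.storage.BlockObjectWriter

// Sketch only: iterate in partition order and stream each pair to a writer.
object SpillSketch {
  def spillSorted[K, V](
      collection: WritablePartitionedPairCollection[K, V],
      writerFor: Int => BlockObjectWriter): Unit = {
    val it = collection.destructiveSortedWritablePartitionedIterator(None)
    while (it.hasNext()) {
      val partitionId = it.nextPartition()
      it.writeNext(writerFor(partitionId))  // writes the pair straight to the writer
    }
  }
}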
diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
index 1b13559e77..778a7eee73 100644
--- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
@@ -280,6 +280,15 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext {
val thrown = intercept[SparkException](ser.serialize(largeObject))
assert(thrown.getMessage.contains(kryoBufferMaxProperty))
}
+
+ test("getAutoReset") {
+ val ser = new KryoSerializer(new SparkConf).newInstance().asInstanceOf[KryoSerializerInstance]
+ assert(ser.getAutoReset)
+ val conf = new SparkConf().set("spark.kryo.registrator",
+ classOf[RegistratorWithoutAutoReset].getName)
+ val ser2 = new KryoSerializer(conf).newInstance().asInstanceOf[KryoSerializerInstance]
+ assert(!ser2.getAutoReset)
+ }
}
@@ -313,4 +322,10 @@ object KryoTest {
k.register(classOf[java.util.HashMap[_, _]])
}
}
+
+ class RegistratorWithoutAutoReset extends KryoRegistrator {
+ override def registerClasses(k: Kryo) {
+ k.setAutoReset(false)
+ }
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala
index 963264cef3..86fcf44728 100644
--- a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala
@@ -24,7 +24,7 @@ import scala.reflect.ClassTag
/**
- * A serializer implementation that always return a single element in a deserialization stream.
+ * A serializer implementation that always returns two elements in a deserialization stream.
*/
class TestSerializer extends Serializer {
override def newInstance(): TestSerializerInstance = new TestSerializerInstance
@@ -51,7 +51,7 @@ class TestDeserializationStream extends DeserializationStream {
override def readObject[T: ClassTag](): T = {
count += 1
- if (count == 2) {
+ if (count == 3) {
throw new EOFException
}
new Object().asInstanceOf[T]
diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
index 7d76435cd7..84384bb489 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
@@ -59,8 +59,8 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext {
val shuffle1 = shuffleBlockManager.forMapTask(1, 1, 1, new JavaSerializer(conf),
new ShuffleWriteMetrics)
for (writer <- shuffle1.writers) {
- writer.write("test1")
- writer.write("test2")
+ writer.write("test1", "value")
+ writer.write("test2", "value")
}
for (writer <- shuffle1.writers) {
writer.commitAndClose()
@@ -73,8 +73,8 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext {
new ShuffleWriteMetrics)
for (writer <- shuffle2.writers) {
- writer.write("test3")
- writer.write("test4")
+ writer.write("test3", "value")
+ writer.write("test4", "vlue")
}
for (writer <- shuffle2.writers) {
writer.commitAndClose()
@@ -91,8 +91,8 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext {
val shuffle3 = shuffleBlockManager.forMapTask(1, 3, 1, new JavaSerializer(testConf),
new ShuffleWriteMetrics)
for (writer <- shuffle3.writers) {
- writer.write("test3")
- writer.write("test4")
+ writer.write("test3", "value")
+ writer.write("test4", "value")
}
for (writer <- shuffle3.writers) {
writer.commitAndClose()
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
index 003a728cb8..43ef469c1f 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
@@ -32,7 +32,7 @@ class BlockObjectWriterSuite extends FunSuite {
val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
- writer.write(Long.box(20))
+ writer.write(Long.box(20), Long.box(30))
// Record metrics update on every write
assert(writeMetrics.shuffleRecordsWritten === 1)
// Metrics don't update on every write
@@ -40,7 +40,7 @@ class BlockObjectWriterSuite extends FunSuite {
// After 32 writes, metrics should update
for (i <- 0 until 32) {
writer.flush()
- writer.write(Long.box(i))
+ writer.write(Long.box(i), Long.box(i))
}
assert(writeMetrics.shuffleBytesWritten > 0)
assert(writeMetrics.shuffleRecordsWritten === 33)
@@ -54,7 +54,7 @@ class BlockObjectWriterSuite extends FunSuite {
val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
- writer.write(Long.box(20))
+ writer.write(Long.box(20), Long.box(30))
// Record metrics update on every write
assert(writeMetrics.shuffleRecordsWritten === 1)
// Metrics don't update on every write
@@ -62,7 +62,7 @@ class BlockObjectWriterSuite extends FunSuite {
// After 32 writes, metrics should update
for (i <- 0 until 32) {
writer.flush()
- writer.write(Long.box(i))
+ writer.write(Long.box(i), Long.box(i))
}
assert(writeMetrics.shuffleBytesWritten > 0)
assert(writeMetrics.shuffleRecordsWritten === 33)
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala
new file mode 100644
index 0000000000..c0c38cd4ac
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.nio.ByteBuffer
+
+import org.scalatest.FunSuite
+import org.scalatest.Matchers._
+
+class ChainedBufferSuite extends FunSuite {
+ test("write and read at start") {
+ // write from start of source array
+ val buffer = new ChainedBuffer(8)
+ buffer.capacity should be (0)
+ verifyWriteAndRead(buffer, 0, 0, 0, 4)
+ buffer.capacity should be (8)
+
+ // write from middle of source array
+ verifyWriteAndRead(buffer, 0, 5, 0, 4)
+ buffer.capacity should be (8)
+
+ // read to middle of target array
+ verifyWriteAndRead(buffer, 0, 0, 5, 4)
+ buffer.capacity should be (8)
+
+ // write up to border
+ verifyWriteAndRead(buffer, 0, 0, 0, 8)
+ buffer.capacity should be (8)
+
+ // expand into second buffer
+ verifyWriteAndRead(buffer, 0, 0, 0, 12)
+ buffer.capacity should be (16)
+
+ // expand into multiple buffers
+ verifyWriteAndRead(buffer, 0, 0, 0, 28)
+ buffer.capacity should be (32)
+ }
+
+ test("write and read at middle") {
+ val buffer = new ChainedBuffer(8)
+
+ // fill to a middle point
+ verifyWriteAndRead(buffer, 0, 0, 0, 3)
+
+ // write from start of source array
+ verifyWriteAndRead(buffer, 3, 0, 0, 4)
+ buffer.capacity should be (8)
+
+ // write from middle of source array
+ verifyWriteAndRead(buffer, 3, 5, 0, 4)
+ buffer.capacity should be (8)
+
+ // read to middle of target array
+ verifyWriteAndRead(buffer, 3, 0, 5, 4)
+ buffer.capacity should be (8)
+
+ // write up to border
+ verifyWriteAndRead(buffer, 3, 0, 0, 5)
+ buffer.capacity should be (8)
+
+ // expand into second buffer
+ verifyWriteAndRead(buffer, 3, 0, 0, 12)
+ buffer.capacity should be (16)
+
+ // expand into multiple buffers
+ verifyWriteAndRead(buffer, 3, 0, 0, 28)
+ buffer.capacity should be (32)
+ }
+
+ test("write and read at later buffer") {
+ val buffer = new ChainedBuffer(8)
+
+ // fill to a middle point
+ verifyWriteAndRead(buffer, 0, 0, 0, 11)
+
+ // write from start of source array
+ verifyWriteAndRead(buffer, 11, 0, 0, 4)
+ buffer.capacity should be (16)
+
+ // write from middle of source array
+ verifyWriteAndRead(buffer, 11, 5, 0, 4)
+ buffer.capacity should be (16)
+
+ // read to middle of target array
+ verifyWriteAndRead(buffer, 11, 0, 5, 4)
+ buffer.capacity should be (16)
+
+ // write up to border
+ verifyWriteAndRead(buffer, 11, 0, 0, 5)
+ buffer.capacity should be (16)
+
+ // expand into second buffer
+ verifyWriteAndRead(buffer, 11, 0, 0, 12)
+ buffer.capacity should be (24)
+
+ // expand into multiple buffers
+ verifyWriteAndRead(buffer, 11, 0, 0, 28)
+ buffer.capacity should be (40)
+ }
+
+
+ // Used to make sure we're writing different bytes each time
+ var rangeStart = 0
+
+ /**
+ * @param buffer The buffer to write to and read from.
+ * @param offsetInBuffer The offset to write to in the buffer.
+ * @param offsetInSource The offset in the array that the bytes are written from.
+ * @param offsetInTarget The offset in the array to read the bytes into.
+ * @param length The number of bytes to read and write
+ */
+ def verifyWriteAndRead(
+ buffer: ChainedBuffer,
+ offsetInBuffer: Int,
+ offsetInSource: Int,
+ offsetInTarget: Int,
+ length: Int): Unit = {
+ val source = new Array[Byte](offsetInSource + length)
+ (rangeStart until rangeStart + length).map(_.toByte).copyToArray(source, offsetInSource)
+ buffer.write(offsetInBuffer, source, offsetInSource, length)
+ val target = new Array[Byte](offsetInTarget + length)
+ buffer.read(offsetInBuffer, target, offsetInTarget, length)
+    ByteBuffer.wrap(source, offsetInSource, length) should be (
+      ByteBuffer.wrap(target, offsetInTarget, length))
+
+ rangeStart += 100
+ }
+}
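
The suite above pins down the growth behavior of the new `ChainedBuffer`: capacity starts at 0 and expands in chunk-sized increments as writes extend past the current end. A small sketch of that behavior, assuming compilation alongside spark-core (where `ChainedBuffer` is visible); the expected values in the comments follow the assertions in the suite:

package org.apache.spark.util.collection

// Sketch only: shows the growth behavior that ChainedBufferSuite asserts above.
object ChainedBufferExample {
  def main(args: Array[String]): Unit = {
    val buffer = new ChainedBuffer(8)               // 8-byte chunks, capacity starts at 0
    val src = Array.tabulate[Byte](12)(_.toByte)

    buffer.write(0, src, 0, src.length)             // writing 12 bytes at position 0...
    println(buffer.capacity)                        // ...expands capacity to 16 (two chunks)

    val dst = new Array[Byte](12)
    buffer.read(0, dst, 0, dst.length)              // read the same range back
    println(dst.sameElements(src))                  // true
  }
}
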
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index de26aa351b..20fd22b78e 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -19,19 +19,24 @@ package org.apache.spark.util.collection
import scala.collection.mutable.ArrayBuffer
-import org.scalatest.{PrivateMethodTester, FunSuite}
-
-import org.apache.spark._
+import org.scalatest.{FunSuite, PrivateMethodTester}
import scala.util.Random
+import org.apache.spark._
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
+
class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMethodTester {
- private def createSparkConf(loadDefaults: Boolean): SparkConf = {
+ private def createSparkConf(loadDefaults: Boolean, kryo: Boolean): SparkConf = {
val conf = new SparkConf(loadDefaults)
- // Make the Java serializer write a reset instruction (TC_RESET) after each object to test
- // for a bug we had with bytes written past the last object in a batch (SPARK-2792)
- conf.set("spark.serializer.objectStreamReset", "1")
- conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer")
+ if (kryo) {
+ conf.set("spark.serializer", classOf[KryoSerializer].getName)
+ } else {
+ // Make the Java serializer write a reset instruction (TC_RESET) after each object to test
+ // for a bug we had with bytes written past the last object in a batch (SPARK-2792)
+ conf.set("spark.serializer.objectStreamReset", "1")
+ conf.set("spark.serializer", classOf[JavaSerializer].getName)
+ }
// Ensure that we actually have multiple batches per spill file
conf.set("spark.shuffle.spill.batchSize", "10")
conf
@@ -47,8 +52,15 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
assert(!sorter.invokePrivate(bypassMergeSort()), "sorter bypassed merge-sort")
}
- test("empty data stream") {
- val conf = new SparkConf(false)
+ test("empty data stream with kryo ser") {
+ emptyDataStream(createSparkConf(false, true))
+ }
+
+ test("empty data stream with java ser") {
+ emptyDataStream(createSparkConf(false, false))
+ }
+
+ def emptyDataStream(conf: SparkConf) {
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
sc = new SparkContext("local", "test", conf)
@@ -81,8 +93,15 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
sorter4.stop()
}
- test("few elements per partition") {
- val conf = createSparkConf(false)
+ test("few elements per partition with kryo ser") {
+ fewElementsPerPartition(createSparkConf(false, true))
+ }
+
+ test("few elements per partition with java ser") {
+ fewElementsPerPartition(createSparkConf(false, false))
+ }
+
+ def fewElementsPerPartition(conf: SparkConf) {
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
sc = new SparkContext("local", "test", conf)
@@ -123,8 +142,15 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
sorter4.stop()
}
- test("empty partitions with spilling") {
- val conf = createSparkConf(false)
+ test("empty partitions with spilling with kryo ser") {
+ emptyPartitionsWithSpilling(createSparkConf(false, true))
+ }
+
+ test("empty partitions with spilling with java ser") {
+ emptyPartitionsWithSpilling(createSparkConf(false, false))
+ }
+
+ def emptyPartitionsWithSpilling(conf: SparkConf) {
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.spill.initialMemoryThreshold", "512")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
@@ -149,8 +175,15 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
sorter.stop()
}
- test("empty partitions with spilling, bypass merge-sort") {
- val conf = createSparkConf(false)
+ test("empty partitions with spilling, bypass merge-sort with kryo ser") {
+    emptyPartitionsWithSpillingBypassMergeSort(createSparkConf(false, true))
+ }
+
+ test("empty partitions with spilling, bypass merge-sort with java ser") {
+    emptyPartitionsWithSpillingBypassMergeSort(createSparkConf(false, false))
+ }
+
+  def emptyPartitionsWithSpillingBypassMergeSort(conf: SparkConf) {
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.spill.initialMemoryThreshold", "512")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
@@ -174,8 +207,17 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
sorter.stop()
}
- test("spilling in local cluster") {
- val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found
+ test("spilling in local cluster with kryo ser") {
+ // Load defaults, otherwise SPARK_HOME is not found
+ testSpillingInLocalCluster(createSparkConf(true, true))
+ }
+
+ test("spilling in local cluster with java ser") {
+ // Load defaults, otherwise SPARK_HOME is not found
+ testSpillingInLocalCluster(createSparkConf(true, false))
+ }
+
+ def testSpillingInLocalCluster(conf: SparkConf) {
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
@@ -245,8 +287,15 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
assert(resultE === (0 until 100000).map(i => (i/4, i)).toSeq)
}
- test("spilling in local cluster with many reduce tasks") {
- val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found
+ test("spilling in local cluster with many reduce tasks with kryo ser") {
+ spillingInLocalClusterWithManyReduceTasks(createSparkConf(true, true))
+ }
+
+ test("spilling in local cluster with many reduce tasks with java ser") {
+ spillingInLocalClusterWithManyReduceTasks(createSparkConf(true, false))
+ }
+
+ def spillingInLocalClusterWithManyReduceTasks(conf: SparkConf) {
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
@@ -317,7 +366,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
}
test("cleanup of intermediate files in sorter") {
- val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found
+ val conf = createSparkConf(true, false) // Load defaults, otherwise SPARK_HOME is not found
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
sc = new SparkContext("local", "test", conf)
@@ -344,7 +393,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
}
test("cleanup of intermediate files in sorter, bypass merge-sort") {
- val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found
+ val conf = createSparkConf(true, false) // Load defaults, otherwise SPARK_HOME is not found
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
sc = new SparkContext("local", "test", conf)
@@ -367,7 +416,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
}
test("cleanup of intermediate files in sorter if there are errors") {
- val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found
+ val conf = createSparkConf(true, false) // Load defaults, otherwise SPARK_HOME is not found
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
sc = new SparkContext("local", "test", conf)
@@ -392,7 +441,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
}
test("cleanup of intermediate files in sorter if there are errors, bypass merge-sort") {
- val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found
+ val conf = createSparkConf(true, false) // Load defaults, otherwise SPARK_HOME is not found
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
sc = new SparkContext("local", "test", conf)
@@ -414,7 +463,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
}
test("cleanup of intermediate files in shuffle") {
- val conf = createSparkConf(false)
+ val conf = createSparkConf(false, false)
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
sc = new SparkContext("local", "test", conf)
@@ -429,7 +478,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
}
test("cleanup of intermediate files in shuffle with errors") {
- val conf = createSparkConf(false)
+ val conf = createSparkConf(false, false)
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
sc = new SparkContext("local", "test", conf)
@@ -450,8 +499,15 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
assert(diskBlockManager.getAllFiles().length === 2)
}
- test("no partial aggregation or sorting") {
- val conf = createSparkConf(false)
+ test("no partial aggregation or sorting with kryo ser") {
+ noPartialAggregationOrSorting(createSparkConf(false, true))
+ }
+
+ test("no partial aggregation or sorting with java ser") {
+ noPartialAggregationOrSorting(createSparkConf(false, false))
+ }
+
+ def noPartialAggregationOrSorting(conf: SparkConf) {
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
sc = new SparkContext("local", "test", conf)
@@ -465,8 +521,15 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
assert(results === expected)
}
- test("partial aggregation without spill") {
- val conf = createSparkConf(false)
+ test("partial aggregation without spill with kryo ser") {
+ partialAggregationWithoutSpill(createSparkConf(false, true))
+ }
+
+ test("partial aggregation without spill with java ser") {
+ partialAggregationWithoutSpill(createSparkConf(false, false))
+ }
+
+ def partialAggregationWithoutSpill(conf: SparkConf) {
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
sc = new SparkContext("local", "test", conf)
@@ -481,8 +544,15 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
assert(results === expected)
}
- test("partial aggregation with spill, no ordering") {
- val conf = createSparkConf(false)
+ test("partial aggregation with spill, no ordering with kryo ser") {
+    partialAggregationWithSpillNoOrdering(createSparkConf(false, true))
+ }
+
+ test("partial aggregation with spill, no ordering with java ser") {
+    partialAggregationWithSpillNoOrdering(createSparkConf(false, false))
+ }
+
+  def partialAggregationWithSpillNoOrdering(conf: SparkConf) {
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
sc = new SparkContext("local", "test", conf)
@@ -497,8 +567,16 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
assert(results === expected)
}
- test("partial aggregation with spill, with ordering") {
- val conf = createSparkConf(false)
+ test("partial aggregation with spill, with ordering with kryo ser") {
+ partialAggregationWithSpillWithOrdering(createSparkConf(false, true))
+ }
+
+
+ test("partial aggregation with spill, with ordering with java ser") {
+ partialAggregationWithSpillWithOrdering(createSparkConf(false, false))
+ }
+
+ def partialAggregationWithSpillWithOrdering(conf: SparkConf) {
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
sc = new SparkContext("local", "test", conf)
@@ -517,8 +595,15 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
assert(results === expected)
}
- test("sorting without aggregation, no spill") {
- val conf = createSparkConf(false)
+ test("sorting without aggregation, no spill with kryo ser") {
+ sortingWithoutAggregationNoSpill(createSparkConf(false, true))
+ }
+
+ test("sorting without aggregation, no spill with java ser") {
+ sortingWithoutAggregationNoSpill(createSparkConf(false, false))
+ }
+
+ def sortingWithoutAggregationNoSpill(conf: SparkConf) {
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
sc = new SparkContext("local", "test", conf)
@@ -534,8 +619,15 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
assert(results === expected)
}
- test("sorting without aggregation, with spill") {
- val conf = createSparkConf(false)
+ test("sorting without aggregation, with spill with kryo ser") {
+ sortingWithoutAggregationWithSpill(createSparkConf(false, true))
+ }
+
+ test("sorting without aggregation, with spill with java ser") {
+ sortingWithoutAggregationWithSpill(createSparkConf(false, false))
+ }
+
+ def sortingWithoutAggregationWithSpill(conf: SparkConf) {
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
sc = new SparkContext("local", "test", conf)
@@ -552,7 +644,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
}
test("spilling with hash collisions") {
- val conf = createSparkConf(true)
+ val conf = createSparkConf(true, false)
conf.set("spark.shuffle.memoryFraction", "0.001")
sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
@@ -609,7 +701,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
}
test("spilling with many hash collisions") {
- val conf = createSparkConf(true)
+ val conf = createSparkConf(true, false)
conf.set("spark.shuffle.memoryFraction", "0.0001")
sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
@@ -632,7 +724,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
}
test("spilling with hash collisions using the Int.MaxValue key") {
- val conf = createSparkConf(true)
+ val conf = createSparkConf(true, false)
conf.set("spark.shuffle.memoryFraction", "0.001")
sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
@@ -656,7 +748,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
}
test("spilling with null keys and values") {
- val conf = createSparkConf(true)
+ val conf = createSparkConf(true, false)
conf.set("spark.shuffle.memoryFraction", "0.001")
sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
@@ -685,7 +777,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
}
test("conditions for bypassing merge-sort") {
- val conf = createSparkConf(false)
+ val conf = createSparkConf(false, false)
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
sc = new SparkContext("local", "test", conf)
@@ -718,8 +810,15 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
assertDidNotBypassMergeSort(sorter4)
}
- test("sort without breaking sorting contracts") {
- val conf = createSparkConf(true)
+ test("sort without breaking sorting contracts with kryo ser") {
+ sortWithoutBreakingSortingContracts(createSparkConf(true, true))
+ }
+
+ test("sort without breaking sorting contracts with java ser") {
+ sortWithoutBreakingSortingContracts(createSparkConf(true, false))
+ }
+
+ def sortWithoutBreakingSortingContracts(conf: SparkConf) {
conf.set("spark.shuffle.memoryFraction", "0.01")
conf.set("spark.shuffle.manager", "sort")
sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
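
Each scenario above now runs twice, with `createSparkConf(loadDefaults, kryo)` selecting either `KryoSerializer` or `JavaSerializer`. A hedged sketch of a standalone job configured the same way as the Kryo variants of these tests, so the sort-based shuffle path they exercise is used; the `reduceByKey` job and all values other than the two class names are illustrative, not part of the patch:

import org.apache.spark.{SparkConf, SparkContext}

// Sketch only: configuration mirroring the Kryo variants of the tests above.
object KryoSortShuffleExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("kryo-sort-shuffle-example")
      .setMaster("local")
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
    val sc = new SparkContext(conf)
    try {
      val counts = sc.parallelize(0 until 1000)
        .map(i => (i % 10, 1))
        .reduceByKey(_ + _)      // exercises the sort-based shuffle
        .collect()
      counts.sortBy(_._1).foreach(println)
    } finally {
      sc.stop()
    }
  }
}
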
diff --git a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala
new file mode 100644
index 0000000000..b5a2d9ef72
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream}
+
+import com.google.common.io.ByteStreams
+
+import org.scalatest.FunSuite
+import org.scalatest.Matchers._
+
+import org.apache.spark.SparkConf
+import org.apache.spark.serializer.KryoSerializer
+import org.apache.spark.storage.{FileSegment, BlockObjectWriter}
+
+class PartitionedSerializedPairBufferSuite extends FunSuite {
+ test("OrderedInputStream single record") {
+ val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
+
+ val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
+ val struct = SomeStruct("something", 5)
+ buffer.insert(4, 10, struct)
+
+ val bytes = ByteStreams.toByteArray(buffer.orderedInputStream)
+
+ val baos = new ByteArrayOutputStream()
+ val stream = serializerInstance.serializeStream(baos)
+ stream.writeObject(10)
+ stream.writeObject(struct)
+ stream.close()
+
+ baos.toByteArray should be (bytes)
+ }
+
+ test("insert single record") {
+ val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
+ val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
+ val struct = SomeStruct("something", 5)
+ buffer.insert(4, 10, struct)
+ val elements = buffer.partitionedDestructiveSortedIterator(None).toArray
+ elements.size should be (1)
+ elements.head should be (((4, 10), struct))
+ }
+
+ test("insert multiple records") {
+ val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
+ val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
+ val struct1 = SomeStruct("something1", 8)
+ buffer.insert(6, 1, struct1)
+ val struct2 = SomeStruct("something2", 9)
+ buffer.insert(4, 2, struct2)
+ val struct3 = SomeStruct("something3", 10)
+ buffer.insert(5, 3, struct3)
+
+ val elements = buffer.partitionedDestructiveSortedIterator(None).toArray
+ elements.size should be (3)
+ elements(0) should be (((4, 2), struct2))
+ elements(1) should be (((5, 3), struct3))
+ elements(2) should be (((6, 1), struct1))
+ }
+
+ test("write single record") {
+ val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
+ val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
+ val struct = SomeStruct("something", 5)
+ buffer.insert(4, 10, struct)
+ val it = buffer.destructiveSortedWritablePartitionedIterator(None)
+ val writer = new SimpleBlockObjectWriter
+ assert(it.hasNext)
+ it.nextPartition should be (4)
+ it.writeNext(writer)
+ assert(!it.hasNext)
+
+ val stream = serializerInstance.deserializeStream(writer.getInputStream)
+ stream.readObject[AnyRef]() should be (10)
+ stream.readObject[AnyRef]() should be (struct)
+ }
+
+ test("write multiple records") {
+ val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
+ val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
+ val struct1 = SomeStruct("something1", 8)
+ buffer.insert(6, 1, struct1)
+ val struct2 = SomeStruct("something2", 9)
+ buffer.insert(4, 2, struct2)
+ val struct3 = SomeStruct("something3", 10)
+ buffer.insert(5, 3, struct3)
+
+ val it = buffer.destructiveSortedWritablePartitionedIterator(None)
+ val writer = new SimpleBlockObjectWriter
+ assert(it.hasNext)
+ it.nextPartition should be (4)
+ it.writeNext(writer)
+ assert(it.hasNext)
+ it.nextPartition should be (5)
+ it.writeNext(writer)
+ assert(it.hasNext)
+ it.nextPartition should be (6)
+ it.writeNext(writer)
+ assert(!it.hasNext)
+
+ val stream = serializerInstance.deserializeStream(writer.getInputStream)
+ val iter = stream.asIterator
+ iter.next() should be (2)
+ iter.next() should be (struct2)
+ iter.next() should be (3)
+ iter.next() should be (struct3)
+ iter.next() should be (1)
+ iter.next() should be (struct1)
+ assert(!iter.hasNext)
+ }
+}
+
+case class SomeStruct(str: String, num: Int)
+
+class SimpleBlockObjectWriter extends BlockObjectWriter(null) {
+ val baos = new ByteArrayOutputStream()
+
+ override def write(bytes: Array[Byte], offs: Int, len: Int): Unit = {
+ baos.write(bytes, offs, len)
+ }
+
+ def getInputStream(): InputStream = new ByteArrayInputStream(baos.toByteArray)
+
+ override def open(): BlockObjectWriter = this
+ override def close(): Unit = { }
+ override def isOpen: Boolean = true
+ override def commitAndClose(): Unit = { }
+ override def revertPartialWritesAndClose(): Unit = { }
+ override def fileSegment(): FileSegment = null
+ override def write(key: Any, value: Any): Unit = { }
+ override def recordWritten(): Unit = { }
+ override def write(b: Int): Unit = { }
+}
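
Taken together, the new suite covers the basic lifecycle of `PartitionedSerializedPairBuffer`: insert records tagged with a partition id, then read them back in partition order, either as deserialized pairs or streamed to a `BlockObjectWriter`. A condensed sketch of that lifecycle, assuming compilation alongside spark-core (where the class is visible); the constructor arguments `(4, 32)` are copied from the suite, and reading them as an initial metadata capacity and a key/value block size is an assumption:

package org.apache.spark.util.collection

import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoSerializer

// Sketch only: the insert-then-iterate lifecycle exercised by the suite above.
object SerializedPairBufferExample {
  def main(args: Array[String]): Unit = {
    val ser = new KryoSerializer(new SparkConf()).newInstance()
    val buffer = new PartitionedSerializedPairBuffer[Int, String](4, 32, ser)

    buffer.insert(2, 10, "b")   // (partitionId, key, value)
    buffer.insert(0, 20, "a")
    buffer.insert(1, 30, "c")

    // Records come back grouped by partition id, in ascending partition order.
    buffer.partitionedDestructiveSortedIterator(None).foreach {
      case ((partition, key), value) => println(s"$partition -> ($key, $value)")
    }
  }
}
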