From 25164a89dd32eef58d9b6823ae259439f796e81a Mon Sep 17 00:00:00 2001 From: Jim Lim Date: Sun, 28 Sep 2014 19:04:24 -0700 Subject: SPARK-2761 refactor #maybeSpill into Spillable Moved `#maybeSpill` in ExternalSorter and EAOM into `Spillable`. Author: Jim Lim Closes #2416 from jimjh/SPARK-2761 and squashes the following commits: cf8be9a [Jim Lim] SPARK-2761 fix documentation, reorder code f94d522 [Jim Lim] SPARK-2761 refactor Spillable to simplify sig e75a24e [Jim Lim] SPARK-2761 use protected over protected[this] 7270e0d [Jim Lim] SPARK-2761 refactor #maybeSpill into Spillable --- .../util/collection/ExternalAppendOnlyMap.scala | 46 ++------- .../spark/util/collection/ExternalSorter.scala | 68 +++---------- .../apache/spark/util/collection/Spillable.scala | 111 +++++++++++++++++++++ 3 files changed, 133 insertions(+), 92 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/collection/Spillable.scala diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 8a015c1d26..0c088da46a 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -66,23 +66,19 @@ class ExternalAppendOnlyMap[K, V, C]( mergeCombiners: (C, C) => C, serializer: Serializer = SparkEnv.get.serializer, blockManager: BlockManager = SparkEnv.get.blockManager) - extends Iterable[(K, C)] with Serializable with Logging { + extends Iterable[(K, C)] + with Serializable + with Logging + with Spillable[SizeTracker] { private var currentMap = new SizeTrackingAppendOnlyMap[K, C] private val spilledMaps = new ArrayBuffer[DiskMapIterator] private val sparkConf = SparkEnv.get.conf private val diskBlockManager = blockManager.diskBlockManager - private val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager // Number of pairs inserted since last spill; note that we count them even if a value is merged // with a previous key in case we're doing something like groupBy where the result grows - private var elementsRead = 0L - - // Number of in-memory pairs inserted before tracking the map's shuffle memory usage - private val trackMemoryThreshold = 1000 - - // How much of the shared memory pool this collection has claimed - private var myMemoryThreshold = 0L + protected[this] var elementsRead = 0L /** * Size of object batches when reading/writing from serializers. @@ -95,11 +91,7 @@ class ExternalAppendOnlyMap[K, V, C]( */ private val serializerBatchSize = sparkConf.getLong("spark.shuffle.spill.batchSize", 10000) - // How many times we have spilled so far - private var spillCount = 0 - // Number of bytes spilled in total - private var _memoryBytesSpilled = 0L private var _diskBytesSpilled = 0L private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 @@ -136,19 +128,8 @@ class ExternalAppendOnlyMap[K, V, C]( while (entries.hasNext) { curEntry = entries.next() - if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 && - currentMap.estimateSize() >= myMemoryThreshold) - { - // Claim up to double our current memory from the shuffle memory pool - val currentMemory = currentMap.estimateSize() - val amountToRequest = 2 * currentMemory - myMemoryThreshold - val granted = shuffleMemoryManager.tryToAcquire(amountToRequest) - myMemoryThreshold += granted - if (myMemoryThreshold <= currentMemory) { - // We were granted too little memory to grow further (either tryToAcquire returned 0, - // or we already had more memory than myMemoryThreshold); spill the current collection - spill(currentMemory) // Will also release memory back to ShuffleMemoryManager - } + if (maybeSpill(currentMap, currentMap.estimateSize())) { + currentMap = new SizeTrackingAppendOnlyMap[K, C] } currentMap.changeValue(curEntry._1, update) elementsRead += 1 @@ -171,11 +152,7 @@ class ExternalAppendOnlyMap[K, V, C]( /** * Sort the existing contents of the in-memory map and spill them to a temporary file on disk. */ - private def spill(mapSize: Long): Unit = { - spillCount += 1 - val threadId = Thread.currentThread().getId - logInfo("Thread %d spilling in-memory map of %d MB to disk (%d time%s so far)" - .format(threadId, mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else "")) + override protected[this] def spill(collection: SizeTracker): Unit = { val (blockId, file) = diskBlockManager.createTempBlock() curWriteMetrics = new ShuffleWriteMetrics() var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize, @@ -231,18 +208,11 @@ class ExternalAppendOnlyMap[K, V, C]( } } - currentMap = new SizeTrackingAppendOnlyMap[K, C] spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes)) - // Release our memory back to the shuffle pool so that other threads can grab it - shuffleMemoryManager.release(myMemoryThreshold) - myMemoryThreshold = 0L - elementsRead = 0 - _memoryBytesSpilled += mapSize } - def memoryBytesSpilled: Long = _memoryBytesSpilled def diskBytesSpilled: Long = _diskBytesSpilled /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 782b979e2e..0a152cb97a 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -79,14 +79,14 @@ private[spark] class ExternalSorter[K, V, C]( aggregator: Option[Aggregator[K, V, C]] = None, partitioner: Option[Partitioner] = None, ordering: Option[Ordering[K]] = None, - serializer: Option[Serializer] = None) extends Logging { + serializer: Option[Serializer] = None) + extends Logging with Spillable[SizeTrackingPairCollection[(Int, K), C]] { private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1) private val shouldPartition = numPartitions > 1 private val blockManager = SparkEnv.get.blockManager private val diskBlockManager = blockManager.diskBlockManager - private val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager private val ser = Serializer.getSerializer(serializer) private val serInstance = ser.newInstance() @@ -115,22 +115,14 @@ private[spark] class ExternalSorter[K, V, C]( // Number of pairs read from input since last spill; note that we count them even if a value is // merged with a previous key in case we're doing something like groupBy where the result grows - private var elementsRead = 0L - - // What threshold of elementsRead we start estimating map size at. - private val trackMemoryThreshold = 1000 + protected[this] var elementsRead = 0L // Total spilling statistics - private var spillCount = 0 - private var _memoryBytesSpilled = 0L private var _diskBytesSpilled = 0L // Write metrics for current spill private var curWriteMetrics: ShuffleWriteMetrics = _ - // How much of the shared memory pool this collection has claimed - private var myMemoryThreshold = 0L - // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't need // local aggregation and sorting, write numPartitions files directly and just concatenate them // at the end. This avoids doing serialization and deserialization twice to merge together the @@ -209,7 +201,7 @@ private[spark] class ExternalSorter[K, V, C]( elementsRead += 1 kv = records.next() map.changeValue((getPartition(kv._1), kv._1), update) - maybeSpill(usingMap = true) + maybeSpillCollection(usingMap = true) } } else { // Stick values into our buffer @@ -217,7 +209,7 @@ private[spark] class ExternalSorter[K, V, C]( elementsRead += 1 val kv = records.next() buffer.insert((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C]) - maybeSpill(usingMap = false) + maybeSpillCollection(usingMap = false) } } } @@ -227,61 +219,31 @@ private[spark] class ExternalSorter[K, V, C]( * * @param usingMap whether we're using a map or buffer as our current in-memory collection */ - private def maybeSpill(usingMap: Boolean): Unit = { + private def maybeSpillCollection(usingMap: Boolean): Unit = { if (!spillingEnabled) { return } - val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer - - // TODO: factor this out of both here and ExternalAppendOnlyMap - if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 && - collection.estimateSize() >= myMemoryThreshold) - { - // Claim up to double our current memory from the shuffle memory pool - val currentMemory = collection.estimateSize() - val amountToRequest = 2 * currentMemory - myMemoryThreshold - val granted = shuffleMemoryManager.tryToAcquire(amountToRequest) - myMemoryThreshold += granted - if (myMemoryThreshold <= currentMemory) { - // We were granted too little memory to grow further (either tryToAcquire returned 0, - // or we already had more memory than myMemoryThreshold); spill the current collection - spill(currentMemory, usingMap) // Will also release memory back to ShuffleMemoryManager + if (usingMap) { + if (maybeSpill(map, map.estimateSize())) { + map = new SizeTrackingAppendOnlyMap[(Int, K), C] + } + } else { + if (maybeSpill(buffer, buffer.estimateSize())) { + buffer = new SizeTrackingPairBuffer[(Int, K), C] } } } /** * Spill the current in-memory collection to disk, adding a new file to spills, and clear it. - * - * @param usingMap whether we're using a map or buffer as our current in-memory collection */ - private def spill(memorySize: Long, usingMap: Boolean): Unit = { - val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer - val memorySize = collection.estimateSize() - - spillCount += 1 - val threadId = Thread.currentThread().getId - logInfo("Thread %d spilling in-memory batch of %d MB to disk (%d spill%s so far)" - .format(threadId, memorySize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else "")) - + override protected[this] def spill(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = { if (bypassMergeSort) { spillToPartitionFiles(collection) } else { spillToMergeableFile(collection) } - - if (usingMap) { - map = new SizeTrackingAppendOnlyMap[(Int, K), C] - } else { - buffer = new SizeTrackingPairBuffer[(Int, K), C] - } - - // Release our memory back to the shuffle pool so that other threads can grab it - shuffleMemoryManager.release(myMemoryThreshold) - myMemoryThreshold = 0 - - _memoryBytesSpilled += memorySize } /** @@ -804,8 +766,6 @@ private[spark] class ExternalSorter[K, V, C]( } } - def memoryBytesSpilled: Long = _memoryBytesSpilled - def diskBytesSpilled: Long = _diskBytesSpilled /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala new file mode 100644 index 0000000000..d7dccd4af8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import org.apache.spark.Logging +import org.apache.spark.SparkEnv + +/** + * Spills contents of an in-memory collection to disk when the memory threshold + * has been exceeded. + */ +private[spark] trait Spillable[C] { + + this: Logging => + + /** + * Spills the current in-memory collection to disk, and releases the memory. + * + * @param collection collection to spill to disk + */ + protected def spill(collection: C): Unit + + // Number of elements read from input since last spill + protected var elementsRead: Long + + // Memory manager that can be used to acquire/release memory + private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager + + // What threshold of elementsRead we start estimating collection size at + private[this] val trackMemoryThreshold = 1000 + + // How much of the shared memory pool this collection has claimed + private[this] var myMemoryThreshold = 0L + + // Number of bytes spilled in total + private[this] var _memoryBytesSpilled = 0L + + // Number of spills + private[this] var _spillCount = 0 + + /** + * Spills the current in-memory collection to disk if needed. Attempts to acquire more + * memory before spilling. + * + * @param collection collection to spill to disk + * @param currentMemory estimated size of the collection in bytes + * @return true if `collection` was spilled to disk; false otherwise + */ + protected def maybeSpill(collection: C, currentMemory: Long): Boolean = { + if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 && + currentMemory >= myMemoryThreshold) { + // Claim up to double our current memory from the shuffle memory pool + val amountToRequest = 2 * currentMemory - myMemoryThreshold + val granted = shuffleMemoryManager.tryToAcquire(amountToRequest) + myMemoryThreshold += granted + if (myMemoryThreshold <= currentMemory) { + // We were granted too little memory to grow further (either tryToAcquire returned 0, + // or we already had more memory than myMemoryThreshold); spill the current collection + _spillCount += 1 + logSpillage(currentMemory) + + spill(collection) + + // Keep track of spills, and release memory + _memoryBytesSpilled += currentMemory + releaseMemoryForThisThread() + return true + } + } + false + } + + /** + * @return number of bytes spilled in total + */ + def memoryBytesSpilled: Long = _memoryBytesSpilled + + /** + * Release our memory back to the shuffle pool so that other threads can grab it. + */ + private def releaseMemoryForThisThread(): Unit = { + shuffleMemoryManager.release(myMemoryThreshold) + myMemoryThreshold = 0L + } + + /** + * Prints a standard log message detailing spillage. + * + * @param size number of bytes spilled + */ + @inline private def logSpillage(size: Long) { + val threadId = Thread.currentThread().getId + logInfo("Thread %d spilling in-memory map of %d MB to disk (%d time%s so far)" + .format(threadId, size / (1024 * 1024), _spillCount, if (_spillCount > 1) "s" else "")) + } +} -- cgit v1.2.3