From 4fde28c2063f673ec7f51d514ba62a73321960a1 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 4 Aug 2014 23:41:03 -0700 Subject: SPARK-2711. Create a ShuffleMemoryManager to track memory for all spilling collections This tracks memory properly if there are multiple spilling collections in the same task (which was a problem before), and also implements an algorithm that lets each thread grow up to 1 / 2N of the memory pool (where N is the number of threads) before spilling, which avoids an inefficiency with small spills we had before (some threads would spill many times at 0-1 MB because the pool was allocated elsewhere). Author: Matei Zaharia Closes #1707 from mateiz/spark-2711 and squashes the following commits: debf75b [Matei Zaharia] Review comments 24f28f3 [Matei Zaharia] Small rename c8f3a8b [Matei Zaharia] Update ShuffleMemoryManager to be able to partially grant requests 315e3a5 [Matei Zaharia] Some review comments b810120 [Matei Zaharia] Create central manager to track memory for all spilling collections --- .../src/main/scala/org/apache/spark/SparkEnv.scala | 10 +- .../scala/org/apache/spark/executor/Executor.scala | 5 +- .../spark/shuffle/ShuffleMemoryManager.scala | 125 +++++++++ .../util/collection/ExternalAppendOnlyMap.scala | 48 +--- .../spark/util/collection/ExternalSorter.scala | 49 +--- .../spark/shuffle/ShuffleMemoryManagerSuite.scala | 294 +++++++++++++++++++++ 6 files changed, 450 insertions(+), 81 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala create mode 100644 core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 0bce531aab..dd8e4ac66d 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -35,7 +35,7 @@ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.ConnectionManager import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.ShuffleManager +import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ import org.apache.spark.util.{AkkaUtils, Utils} @@ -66,12 +66,9 @@ class SparkEnv ( val httpFileServer: HttpFileServer, val sparkFilesDir: String, val metricsSystem: MetricsSystem, + val shuffleMemoryManager: ShuffleMemoryManager, val conf: SparkConf) extends Logging { - // A mapping of thread ID to amount of memory, in bytes, used for shuffle aggregations - // All accesses should be manually synchronized - val shuffleMemoryMap = mutable.HashMap[Long, Long]() - private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() // A general, soft-reference map for metadata needed during HadoopRDD split computation @@ -252,6 +249,8 @@ object SparkEnv extends Logging { val shuffleManager = instantiateClass[ShuffleManager]( "spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager") + val shuffleMemoryManager = new ShuffleMemoryManager(conf) + // Warn about deprecated spark.cache.class property if (conf.contains("spark.cache.class")) { logWarning("The spark.cache.class property is no longer being used! Specify storage " + @@ -273,6 +272,7 @@ object SparkEnv extends Logging { httpFileServer, sparkFilesDir, metricsSystem, + shuffleMemoryManager, conf) } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 1bb1b4aae9..c2b9c660dd 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -276,10 +276,7 @@ private[spark] class Executor( } } finally { // Release memory used by this thread for shuffles - val shuffleMemoryMap = env.shuffleMemoryMap - shuffleMemoryMap.synchronized { - shuffleMemoryMap.remove(Thread.currentThread().getId) - } + env.shuffleMemoryManager.releaseMemoryForThisThread() // Release memory used by this thread for unrolling blocks env.blockManager.memoryStore.releaseUnrollMemoryForThisThread() runningTasks.remove(taskId) diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala new file mode 100644 index 0000000000..ee91a368b7 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import scala.collection.mutable + +import org.apache.spark.{Logging, SparkException, SparkConf} + +/** + * Allocates a pool of memory to task threads for use in shuffle operations. Each disk-spilling + * collection (ExternalAppendOnlyMap or ExternalSorter) used by these tasks can acquire memory + * from this pool and release it as it spills data out. When a task ends, all its memory will be + * released by the Executor. + * + * This class tries to ensure that each thread gets a reasonable share of memory, instead of some + * thread ramping up to a large amount first and then causing others to spill to disk repeatedly. + * If there are N threads, it ensures that each thread can acquire at least 1 / 2N of the memory + * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the + * set of active threads and redo the calculations of 1 / 2N and 1 / N in waiting threads whenever + * this set changes. This is all done by synchronizing access on "this" to mutate state and using + * wait() and notifyAll() to signal changes. + */ +private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging { + private val threadMemory = new mutable.HashMap[Long, Long]() // threadId -> memory bytes + + def this(conf: SparkConf) = this(ShuffleMemoryManager.getMaxMemory(conf)) + + /** + * Try to acquire up to numBytes memory for the current thread, and return the number of bytes + * obtained, or 0 if none can be allocated. This call may block until there is enough free memory + * in some situations, to make sure each thread has a chance to ramp up to at least 1 / 2N of the + * total memory pool (where N is the # of active threads) before it is forced to spill. This can + * happen if the number of threads increases but an older thread had a lot of memory already. + */ + def tryToAcquire(numBytes: Long): Long = synchronized { + val threadId = Thread.currentThread().getId + assert(numBytes > 0, "invalid number of bytes requested: " + numBytes) + + // Add this thread to the threadMemory map just so we can keep an accurate count of the number + // of active threads, to let other threads ramp down their memory in calls to tryToAcquire + if (!threadMemory.contains(threadId)) { + threadMemory(threadId) = 0L + notifyAll() // Will later cause waiting threads to wake up and check numThreads again + } + + // Keep looping until we're either sure that we don't want to grant this request (because this + // thread would have more than 1 / numActiveThreads of the memory) or we have enough free + // memory to give it (we always let each thread get at least 1 / (2 * numActiveThreads)). + while (true) { + val numActiveThreads = threadMemory.keys.size + val curMem = threadMemory(threadId) + val freeMemory = maxMemory - threadMemory.values.sum + + // How much we can grant this thread; don't let it grow to more than 1 / numActiveThreads + val maxToGrant = math.min(numBytes, (maxMemory / numActiveThreads) - curMem) + + if (curMem < maxMemory / (2 * numActiveThreads)) { + // We want to let each thread get at least 1 / (2 * numActiveThreads) before blocking; + // if we can't give it this much now, wait for other threads to free up memory + // (this happens if older threads allocated lots of memory before N grew) + if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveThreads) - curMem)) { + val toGrant = math.min(maxToGrant, freeMemory) + threadMemory(threadId) += toGrant + return toGrant + } else { + logInfo(s"Thread $threadId waiting for at least 1/2N of shuffle memory pool to be free") + wait() + } + } else { + // Only give it as much memory as is free, which might be none if it reached 1 / numThreads + val toGrant = math.min(maxToGrant, freeMemory) + threadMemory(threadId) += toGrant + return toGrant + } + } + 0L // Never reached + } + + /** Release numBytes bytes for the current thread. */ + def release(numBytes: Long): Unit = synchronized { + val threadId = Thread.currentThread().getId + val curMem = threadMemory.getOrElse(threadId, 0L) + if (curMem < numBytes) { + throw new SparkException( + s"Internal error: release called on ${numBytes} bytes but thread only has ${curMem}") + } + threadMemory(threadId) -= numBytes + notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed + } + + /** Release all memory for the current thread and mark it as inactive (e.g. when a task ends). */ + def releaseMemoryForThisThread(): Unit = synchronized { + val threadId = Thread.currentThread().getId + threadMemory.remove(threadId) + notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed + } +} + +private object ShuffleMemoryManager { + /** + * Figure out the shuffle memory limit from a SparkConf. We currently have both a fraction + * of the memory pool and a safety factor since collections can sometimes grow bigger than + * the size we target before we estimate their sizes again. + */ + def getMaxMemory(conf: SparkConf): Long = { + val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2) + val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8) + (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong + } +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 1f7d2dc838..cc0423856c 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -71,13 +71,7 @@ class ExternalAppendOnlyMap[K, V, C]( private val spilledMaps = new ArrayBuffer[DiskMapIterator] private val sparkConf = SparkEnv.get.conf private val diskBlockManager = blockManager.diskBlockManager - - // Collective memory threshold shared across all running tasks - private val maxMemoryThreshold = { - val memoryFraction = sparkConf.getDouble("spark.shuffle.memoryFraction", 0.2) - val safetyFraction = sparkConf.getDouble("spark.shuffle.safetyFraction", 0.8) - (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong - } + private val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager // Number of pairs inserted since last spill; note that we count them even if a value is merged // with a previous key in case we're doing something like groupBy where the result grows @@ -140,28 +134,15 @@ class ExternalAppendOnlyMap[K, V, C]( if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 && currentMap.estimateSize() >= myMemoryThreshold) { - val currentSize = currentMap.estimateSize() - var shouldSpill = false - val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap - - // Atomically check whether there is sufficient memory in the global pool for - // this map to grow and, if possible, allocate the required amount - shuffleMemoryMap.synchronized { - val threadId = Thread.currentThread().getId - val previouslyOccupiedMemory = shuffleMemoryMap.get(threadId) - val availableMemory = maxMemoryThreshold - - (shuffleMemoryMap.values.sum - previouslyOccupiedMemory.getOrElse(0L)) - - // Try to allocate at least 2x more memory, otherwise spill - shouldSpill = availableMemory < currentSize * 2 - if (!shouldSpill) { - shuffleMemoryMap(threadId) = currentSize * 2 - myMemoryThreshold = currentSize * 2 - } - } - // Do not synchronize spills - if (shouldSpill) { - spill(currentSize) + // Claim up to double our current memory from the shuffle memory pool + val currentMemory = currentMap.estimateSize() + val amountToRequest = 2 * currentMemory - myMemoryThreshold + val granted = shuffleMemoryManager.tryToAcquire(amountToRequest) + myMemoryThreshold += granted + if (myMemoryThreshold <= currentMemory) { + // We were granted too little memory to grow further (either tryToAcquire returned 0, + // or we already had more memory than myMemoryThreshold); spill the current collection + spill(currentMemory) // Will also release memory back to ShuffleMemoryManager } } currentMap.changeValue(curEntry._1, update) @@ -245,12 +226,9 @@ class ExternalAppendOnlyMap[K, V, C]( currentMap = new SizeTrackingAppendOnlyMap[K, C] spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes)) - // Reset the amount of shuffle memory used by this map in the global pool - val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap - shuffleMemoryMap.synchronized { - shuffleMemoryMap(Thread.currentThread().getId) = 0 - } - myMemoryThreshold = 0 + // Release our memory back to the shuffle pool so that other threads can grab it + shuffleMemoryManager.release(myMemoryThreshold) + myMemoryThreshold = 0L elementsRead = 0 _memoryBytesSpilled += mapSize diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index b04c50bd3e..101c83b264 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -78,6 +78,7 @@ private[spark] class ExternalSorter[K, V, C]( private val blockManager = SparkEnv.get.blockManager private val diskBlockManager = blockManager.diskBlockManager + private val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager private val ser = Serializer.getSerializer(serializer) private val serInstance = ser.newInstance() @@ -116,13 +117,6 @@ private[spark] class ExternalSorter[K, V, C]( private var _memoryBytesSpilled = 0L private var _diskBytesSpilled = 0L - // Collective memory threshold shared across all running tasks - private val maxMemoryThreshold = { - val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2) - val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8) - (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong - } - // How much of the shared memory pool this collection has claimed private var myMemoryThreshold = 0L @@ -218,31 +212,15 @@ private[spark] class ExternalSorter[K, V, C]( if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 && collection.estimateSize() >= myMemoryThreshold) { - // TODO: This logic doesn't work if there are two external collections being used in the same - // task (e.g. to read shuffle output and write it out into another shuffle) [SPARK-2711] - - val currentSize = collection.estimateSize() - var shouldSpill = false - val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap - - // Atomically check whether there is sufficient memory in the global pool for - // us to double our threshold - shuffleMemoryMap.synchronized { - val threadId = Thread.currentThread().getId - val previouslyClaimedMemory = shuffleMemoryMap.get(threadId) - val availableMemory = maxMemoryThreshold - - (shuffleMemoryMap.values.sum - previouslyClaimedMemory.getOrElse(0L)) - - // Try to allocate at least 2x more memory, otherwise spill - shouldSpill = availableMemory < currentSize * 2 - if (!shouldSpill) { - shuffleMemoryMap(threadId) = currentSize * 2 - myMemoryThreshold = currentSize * 2 - } - } - // Do not hold lock during spills - if (shouldSpill) { - spill(currentSize, usingMap) + // Claim up to double our current memory from the shuffle memory pool + val currentMemory = collection.estimateSize() + val amountToRequest = 2 * currentMemory - myMemoryThreshold + val granted = shuffleMemoryManager.tryToAcquire(amountToRequest) + myMemoryThreshold += granted + if (myMemoryThreshold <= currentMemory) { + // We were granted too little memory to grow further (either tryToAcquire returned 0, + // or we already had more memory than myMemoryThreshold); spill the current collection + spill(currentMemory, usingMap) // Will also release memory back to ShuffleMemoryManager } } } @@ -327,11 +305,8 @@ private[spark] class ExternalSorter[K, V, C]( buffer = new SizeTrackingPairBuffer[(Int, K), C] } - // Reset the amount of shuffle memory used by this map in the global pool - val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap - shuffleMemoryMap.synchronized { - shuffleMemoryMap(Thread.currentThread().getId) = 0 - } + // Release our memory back to the shuffle pool so that other threads can grab it + shuffleMemoryManager.release(myMemoryThreshold) myMemoryThreshold = 0 spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition)) diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala new file mode 100644 index 0000000000..d31bc22ee7 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import org.scalatest.FunSuite +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.SpanSugar._ +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.CountDownLatch + +class ShuffleMemoryManagerSuite extends FunSuite with Timeouts { + /** Launch a thread with the given body block and return it. */ + private def startThread(name: String)(body: => Unit): Thread = { + val thread = new Thread("ShuffleMemorySuite " + name) { + override def run() { + body + } + } + thread.start() + thread + } + + test("single thread requesting memory") { + val manager = new ShuffleMemoryManager(1000L) + + assert(manager.tryToAcquire(100L) === 100L) + assert(manager.tryToAcquire(400L) === 400L) + assert(manager.tryToAcquire(400L) === 400L) + assert(manager.tryToAcquire(200L) === 100L) + assert(manager.tryToAcquire(100L) === 0L) + assert(manager.tryToAcquire(100L) === 0L) + + manager.release(500L) + assert(manager.tryToAcquire(300L) === 300L) + assert(manager.tryToAcquire(300L) === 200L) + + manager.releaseMemoryForThisThread() + assert(manager.tryToAcquire(1000L) === 1000L) + assert(manager.tryToAcquire(100L) === 0L) + } + + test("two threads requesting full memory") { + // Two threads request 500 bytes first, wait for each other to get it, and then request + // 500 more; we should immediately return 0 as both are now at 1 / N + + val manager = new ShuffleMemoryManager(1000L) + + class State { + var t1Result1 = -1L + var t2Result1 = -1L + var t1Result2 = -1L + var t2Result2 = -1L + } + val state = new State + + val t1 = startThread("t1") { + val r1 = manager.tryToAcquire(500L) + state.synchronized { + state.t1Result1 = r1 + state.notifyAll() + while (state.t2Result1 === -1L) { + state.wait() + } + } + val r2 = manager.tryToAcquire(500L) + state.synchronized { state.t1Result2 = r2 } + } + + val t2 = startThread("t2") { + val r1 = manager.tryToAcquire(500L) + state.synchronized { + state.t2Result1 = r1 + state.notifyAll() + while (state.t1Result1 === -1L) { + state.wait() + } + } + val r2 = manager.tryToAcquire(500L) + state.synchronized { state.t2Result2 = r2 } + } + + failAfter(20 seconds) { + t1.join() + t2.join() + } + + assert(state.t1Result1 === 500L) + assert(state.t2Result1 === 500L) + assert(state.t1Result2 === 0L) + assert(state.t2Result2 === 0L) + } + + + test("threads cannot grow past 1 / N") { + // Two threads request 250 bytes first, wait for each other to get it, and then request + // 500 more; we should only grant 250 bytes to each of them on this second request + + val manager = new ShuffleMemoryManager(1000L) + + class State { + var t1Result1 = -1L + var t2Result1 = -1L + var t1Result2 = -1L + var t2Result2 = -1L + } + val state = new State + + val t1 = startThread("t1") { + val r1 = manager.tryToAcquire(250L) + state.synchronized { + state.t1Result1 = r1 + state.notifyAll() + while (state.t2Result1 === -1L) { + state.wait() + } + } + val r2 = manager.tryToAcquire(500L) + state.synchronized { state.t1Result2 = r2 } + } + + val t2 = startThread("t2") { + val r1 = manager.tryToAcquire(250L) + state.synchronized { + state.t2Result1 = r1 + state.notifyAll() + while (state.t1Result1 === -1L) { + state.wait() + } + } + val r2 = manager.tryToAcquire(500L) + state.synchronized { state.t2Result2 = r2 } + } + + failAfter(20 seconds) { + t1.join() + t2.join() + } + + assert(state.t1Result1 === 250L) + assert(state.t2Result1 === 250L) + assert(state.t1Result2 === 250L) + assert(state.t2Result2 === 250L) + } + + test("threads can block to get at least 1 / 2N memory") { + // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps + // for a bit and releases 250 bytes, which should then be greanted to t2. Further requests + // by t2 will return false right away because it now has 1 / 2N of the memory. + + val manager = new ShuffleMemoryManager(1000L) + + class State { + var t1Requested = false + var t2Requested = false + var t1Result = -1L + var t2Result = -1L + var t2Result2 = -1L + var t2WaitTime = 0L + } + val state = new State + + val t1 = startThread("t1") { + state.synchronized { + state.t1Result = manager.tryToAcquire(1000L) + state.t1Requested = true + state.notifyAll() + while (!state.t2Requested) { + state.wait() + } + } + // Sleep a bit before releasing our memory; this is hacky but it would be difficult to make + // sure the other thread blocks for some time otherwise + Thread.sleep(300) + manager.release(250L) + } + + val t2 = startThread("t2") { + state.synchronized { + while (!state.t1Requested) { + state.wait() + } + state.t2Requested = true + state.notifyAll() + } + val startTime = System.currentTimeMillis() + val result = manager.tryToAcquire(250L) + val endTime = System.currentTimeMillis() + state.synchronized { + state.t2Result = result + // A second call should return 0 because we're now already at 1 / 2N + state.t2Result2 = manager.tryToAcquire(100L) + state.t2WaitTime = endTime - startTime + } + } + + failAfter(20 seconds) { + t1.join() + t2.join() + } + + // Both threads should've been able to acquire their memory; the second one will have waited + // until the first one acquired 1000 bytes and then released 250 + state.synchronized { + assert(state.t1Result === 1000L, "t1 could not allocate memory") + assert(state.t2Result === 250L, "t2 could not allocate memory") + assert(state.t2WaitTime > 200, s"t2 waited less than 200 ms (${state.t2WaitTime})") + assert(state.t2Result2 === 0L, "t1 got extra memory the second time") + } + } + + test("releaseMemoryForThisThread") { + // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps + // for a bit and releases all its memory. t2 should now be able to grab all the memory. + + val manager = new ShuffleMemoryManager(1000L) + + class State { + var t1Requested = false + var t2Requested = false + var t1Result = -1L + var t2Result1 = -1L + var t2Result2 = -1L + var t2Result3 = -1L + var t2WaitTime = 0L + } + val state = new State + + val t1 = startThread("t1") { + state.synchronized { + state.t1Result = manager.tryToAcquire(1000L) + state.t1Requested = true + state.notifyAll() + while (!state.t2Requested) { + state.wait() + } + } + // Sleep a bit before releasing our memory; this is hacky but it would be difficult to make + // sure the other thread blocks for some time otherwise + Thread.sleep(300) + manager.releaseMemoryForThisThread() + } + + val t2 = startThread("t2") { + state.synchronized { + while (!state.t1Requested) { + state.wait() + } + state.t2Requested = true + state.notifyAll() + } + val startTime = System.currentTimeMillis() + val r1 = manager.tryToAcquire(500L) + val endTime = System.currentTimeMillis() + val r2 = manager.tryToAcquire(500L) + val r3 = manager.tryToAcquire(500L) + state.synchronized { + state.t2Result1 = r1 + state.t2Result2 = r2 + state.t2Result3 = r3 + state.t2WaitTime = endTime - startTime + } + } + + failAfter(20 seconds) { + t1.join() + t2.join() + } + + // Both threads should've been able to acquire their memory; the second one will have waited + // until the first one acquired 1000 bytes and then released all of it + state.synchronized { + assert(state.t1Result === 1000L, "t1 could not allocate memory") + assert(state.t2Result1 === 500L, "t2 didn't get 500 bytes the first time") + assert(state.t2Result2 === 500L, "t2 didn't get 500 bytes the second time") + assert(state.t2Result3 === 0L, s"t2 got more bytes a third time (${state.t2Result3})") + assert(state.t2WaitTime > 200, s"t2 waited less than 200 ms (${state.t2WaitTime})") + } + } +} -- cgit v1.2.3