author     Matei Zaharia <matei@databricks.com>    2014-08-04 23:41:03 -0700
committer  Matei Zaharia <matei@databricks.com>    2014-08-04 23:41:03 -0700
commit     4fde28c2063f673ec7f51d514ba62a73321960a1 (patch)
tree       1125522786941e5a71ea2526df435bc9f6706405
parent     066765d60d21b6b9943862b788e4a4bd07396e6c (diff)
SPARK-2711. Create a ShuffleMemoryManager to track memory for all spilling collections
This tracks memory properly if there are multiple spilling collections in the same task
(which was a problem before), and also implements an algorithm that lets each thread grow
up to 1 / 2N of the memory pool (where N is the number of threads) before spilling. This
avoids an inefficiency we had with small spills: some threads would spill many times at
0-1 MB because the rest of the pool was allocated elsewhere.

Author: Matei Zaharia <matei@databricks.com>

Closes #1707 from mateiz/spark-2711 and squashes the following commits:

debf75b [Matei Zaharia] Review comments
24f28f3 [Matei Zaharia] Small rename
c8f3a8b [Matei Zaharia] Update ShuffleMemoryManager to be able to partially grant requests
315e3a5 [Matei Zaharia] Some review comments
b810120 [Matei Zaharia] Create central manager to track memory for all spilling collections
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkEnv.scala                              |  10
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/Executor.scala                     |   5
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala          | 125
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala |  48
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala        |  49
-rw-r--r--  core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala     | 294
6 files changed, 450 insertions(+), 81 deletions(-)
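
For orientation before the diff: a minimal sketch of the caller protocol the new manager
expects. The class and method names (tryToAcquire, release, releaseMemoryForThisThread)
come from the patch below; the pool size and byte counts are made up, and because the
class is private[spark] this only compiles from inside the org.apache.spark package.

    import org.apache.spark.shuffle.ShuffleMemoryManager

    val manager = new ShuffleMemoryManager(1000L)  // 1000-byte pool, illustrative only
    val granted = manager.tryToAcquire(400L)       // may return less than requested
    try {
      // ... grow an in-memory collection by up to `granted` bytes ...
    } finally {
      manager.release(granted)                     // hand the bytes back to the pool
    }
    // When the task ends, the Executor drops any claim the thread still holds:
    manager.releaseMemoryForThisThread()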
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 0bce531aab..dd8e4ac66d 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -35,7 +35,7 @@ import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.network.ConnectionManager
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle.ShuffleManager
+import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager}
import org.apache.spark.storage._
import org.apache.spark.util.{AkkaUtils, Utils}
@@ -66,12 +66,9 @@ class SparkEnv (
val httpFileServer: HttpFileServer,
val sparkFilesDir: String,
val metricsSystem: MetricsSystem,
+ val shuffleMemoryManager: ShuffleMemoryManager,
val conf: SparkConf) extends Logging {
- // A mapping of thread ID to amount of memory, in bytes, used for shuffle aggregations
- // All accesses should be manually synchronized
- val shuffleMemoryMap = mutable.HashMap[Long, Long]()
-
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
// A general, soft-reference map for metadata needed during HadoopRDD split computation
@@ -252,6 +249,8 @@ object SparkEnv extends Logging {
val shuffleManager = instantiateClass[ShuffleManager](
"spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager")
+ val shuffleMemoryManager = new ShuffleMemoryManager(conf)
+
// Warn about deprecated spark.cache.class property
if (conf.contains("spark.cache.class")) {
logWarning("The spark.cache.class property is no longer being used! Specify storage " +
@@ -273,6 +272,7 @@ object SparkEnv extends Logging {
httpFileServer,
sparkFilesDir,
metricsSystem,
+ shuffleMemoryManager,
conf)
}
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 1bb1b4aae9..c2b9c660dd 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -276,10 +276,7 @@ private[spark] class Executor(
}
} finally {
// Release memory used by this thread for shuffles
- val shuffleMemoryMap = env.shuffleMemoryMap
- shuffleMemoryMap.synchronized {
- shuffleMemoryMap.remove(Thread.currentThread().getId)
- }
+ env.shuffleMemoryManager.releaseMemoryForThisThread()
// Release memory used by this thread for unrolling blocks
env.blockManager.memoryStore.releaseUnrollMemoryForThisThread()
runningTasks.remove(taskId)
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
new file mode 100644
index 0000000000..ee91a368b7
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import scala.collection.mutable
+
+import org.apache.spark.{Logging, SparkException, SparkConf}
+
+/**
+ * Allocates a pool of memory to task threads for use in shuffle operations. Each disk-spilling
+ * collection (ExternalAppendOnlyMap or ExternalSorter) used by these tasks can acquire memory
+ * from this pool and release it as it spills data out. When a task ends, all its memory will be
+ * released by the Executor.
+ *
+ * This class tries to ensure that each thread gets a reasonable share of memory, instead of some
+ * thread ramping up to a large amount first and then causing others to spill to disk repeatedly.
+ * If there are N threads, it ensures that each thread can acquire at least 1 / 2N of the memory
+ * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the
+ * set of active threads and redo the calculations of 1 / 2N and 1 / N in waiting threads whenever
+ * this set changes. This is all done by synchronizing access on "this" to mutate state and using
+ * wait() and notifyAll() to signal changes.
+ */
+private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
+ private val threadMemory = new mutable.HashMap[Long, Long]() // threadId -> memory bytes
+
+ def this(conf: SparkConf) = this(ShuffleMemoryManager.getMaxMemory(conf))
+
+ /**
+ * Try to acquire up to numBytes memory for the current thread, and return the number of bytes
+ * obtained, or 0 if none can be allocated. In some situations this call may block until there is
+ * enough free memory, to make sure each thread has a chance to ramp up to at least 1 / 2N of the
+ * total memory pool (where N is the # of active threads) before it is forced to spill. This can
+ * happen if the number of threads increases but an older thread had a lot of memory already.
+ */
+ def tryToAcquire(numBytes: Long): Long = synchronized {
+ val threadId = Thread.currentThread().getId
+ assert(numBytes > 0, "invalid number of bytes requested: " + numBytes)
+
+ // Add this thread to the threadMemory map just so we can keep an accurate count of the number
+ // of active threads, to let other threads ramp down their memory in calls to tryToAcquire
+ if (!threadMemory.contains(threadId)) {
+ threadMemory(threadId) = 0L
+ notifyAll() // Will later cause waiting threads to wake up and check numThreads again
+ }
+
+ // Keep looping until we're either sure that we don't want to grant this request (because this
+ // thread would have more than 1 / numActiveThreads of the memory) or we have enough free
+ // memory to give it (we always let each thread get at least 1 / (2 * numActiveThreads)).
+ while (true) {
+ val numActiveThreads = threadMemory.keys.size
+ val curMem = threadMemory(threadId)
+ val freeMemory = maxMemory - threadMemory.values.sum
+
+ // How much we can grant this thread; don't let it grow to more than 1 / numActiveThreads
+ val maxToGrant = math.min(numBytes, (maxMemory / numActiveThreads) - curMem)
+
+ if (curMem < maxMemory / (2 * numActiveThreads)) {
+ // We want to let each thread get at least 1 / (2 * numActiveThreads) before blocking;
+ // if we can't give it this much now, wait for other threads to free up memory
+ // (this happens if older threads allocated lots of memory before N grew)
+ if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveThreads) - curMem)) {
+ val toGrant = math.min(maxToGrant, freeMemory)
+ threadMemory(threadId) += toGrant
+ return toGrant
+ } else {
+ logInfo(s"Thread $threadId waiting for at least 1/2N of shuffle memory pool to be free")
+ wait()
+ }
+ } else {
+ // Only give it as much memory as is free, which might be none if it reached 1 / numActiveThreads
+ val toGrant = math.min(maxToGrant, freeMemory)
+ threadMemory(threadId) += toGrant
+ return toGrant
+ }
+ }
+ 0L // Never reached
+ }
+
+ /** Release numBytes bytes for the current thread. */
+ def release(numBytes: Long): Unit = synchronized {
+ val threadId = Thread.currentThread().getId
+ val curMem = threadMemory.getOrElse(threadId, 0L)
+ if (curMem < numBytes) {
+ throw new SparkException(
+ s"Internal error: release called on ${numBytes} bytes but thread only has ${curMem}")
+ }
+ threadMemory(threadId) -= numBytes
+ notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed
+ }
+
+ /** Release all memory for the current thread and mark it as inactive (e.g. when a task ends). */
+ def releaseMemoryForThisThread(): Unit = synchronized {
+ val threadId = Thread.currentThread().getId
+ threadMemory.remove(threadId)
+ notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed
+ }
+}
+
+private object ShuffleMemoryManager {
+ /**
+ * Figure out the shuffle memory limit from a SparkConf. We currently have both a fraction
+ * of the memory pool and a safety factor since collections can sometimes grow bigger than
+ * the size we target before we estimate their sizes again.
+ */
+ def getMaxMemory(conf: SparkConf): Long = {
+ val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2)
+ val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8)
+ (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong
+ }
+}
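
The grant policy above reduces to two bounds: a thread's total claim is capped at 1 / N of
the pool, and a thread still below 1 / 2N blocks rather than accept a partial grant it
cannot use. A hypothetical, non-blocking model of the arithmetic (modelGrant is not part
of the patch; it reproduces the maxToGrant computation for a thread already past the
1 / 2N floor):

    // Hypothetical helper: what tryToAcquire grants once a thread is past 1 / 2N.
    def modelGrant(maxMemory: Long, numActiveThreads: Int,
                   curMem: Long, freeMemory: Long, numBytes: Long): Long = {
      val cap = maxMemory / numActiveThreads            // upper bound: 1 / N of the pool
      val maxToGrant = math.min(numBytes, cap - curMem) // never grow past the cap
      math.min(maxToGrant, freeMemory)                  // and never exceed free bytes
    }

    // With a 1000-byte pool and two active threads, a thread holding nothing that
    // asks for 800 bytes gets 500 (the 1 / N cap); once at 500 it gets nothing more.
    assert(modelGrant(1000L, 2, 0L, 1000L, 800L) == 500L)
    assert(modelGrant(1000L, 2, 500L, 500L, 100L) == 0L)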
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 1f7d2dc838..cc0423856c 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -71,13 +71,7 @@ class ExternalAppendOnlyMap[K, V, C](
private val spilledMaps = new ArrayBuffer[DiskMapIterator]
private val sparkConf = SparkEnv.get.conf
private val diskBlockManager = blockManager.diskBlockManager
-
- // Collective memory threshold shared across all running tasks
- private val maxMemoryThreshold = {
- val memoryFraction = sparkConf.getDouble("spark.shuffle.memoryFraction", 0.2)
- val safetyFraction = sparkConf.getDouble("spark.shuffle.safetyFraction", 0.8)
- (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong
- }
+ private val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager
// Number of pairs inserted since last spill; note that we count them even if a value is merged
// with a previous key in case we're doing something like groupBy where the result grows
@@ -140,28 +134,15 @@ class ExternalAppendOnlyMap[K, V, C](
if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 &&
currentMap.estimateSize() >= myMemoryThreshold)
{
- val currentSize = currentMap.estimateSize()
- var shouldSpill = false
- val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
-
- // Atomically check whether there is sufficient memory in the global pool for
- // this map to grow and, if possible, allocate the required amount
- shuffleMemoryMap.synchronized {
- val threadId = Thread.currentThread().getId
- val previouslyOccupiedMemory = shuffleMemoryMap.get(threadId)
- val availableMemory = maxMemoryThreshold -
- (shuffleMemoryMap.values.sum - previouslyOccupiedMemory.getOrElse(0L))
-
- // Try to allocate at least 2x more memory, otherwise spill
- shouldSpill = availableMemory < currentSize * 2
- if (!shouldSpill) {
- shuffleMemoryMap(threadId) = currentSize * 2
- myMemoryThreshold = currentSize * 2
- }
- }
- // Do not synchronize spills
- if (shouldSpill) {
- spill(currentSize)
+ // Claim up to double our current memory from the shuffle memory pool
+ val currentMemory = currentMap.estimateSize()
+ val amountToRequest = 2 * currentMemory - myMemoryThreshold
+ val granted = shuffleMemoryManager.tryToAcquire(amountToRequest)
+ myMemoryThreshold += granted
+ if (myMemoryThreshold <= currentMemory) {
+ // We were granted too little memory to grow further (either tryToAcquire returned 0,
+ // or we already had more memory than myMemoryThreshold); spill the current collection
+ spill(currentMemory) // Will also release memory back to ShuffleMemoryManager
}
}
currentMap.changeValue(curEntry._1, update)
@@ -245,12 +226,9 @@ class ExternalAppendOnlyMap[K, V, C](
currentMap = new SizeTrackingAppendOnlyMap[K, C]
spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes))
- // Reset the amount of shuffle memory used by this map in the global pool
- val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
- shuffleMemoryMap.synchronized {
- shuffleMemoryMap(Thread.currentThread().getId) = 0
- }
- myMemoryThreshold = 0
+ // Release our memory back to the shuffle pool so that other threads can grab it
+ shuffleMemoryManager.release(myMemoryThreshold)
+ myMemoryThreshold = 0L
elementsRead = 0
_memoryBytesSpilled += mapSize
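
The request pattern above, shared with ExternalSorter below, always asks for enough to
double the collection's current estimated size, net of what the thread already holds. A
short illustrative trace under the assumption that every request is granted in full
(grantedFully stands in for shuffleMemoryManager.tryToAcquire):

    var myMemoryThreshold = 0L
    def grantedFully(amount: Long): Long = amount   // pretend the pool never runs dry

    for (currentMemory <- Seq(100L, 250L, 600L)) {  // estimated sizes over time
      val amountToRequest = 2 * currentMemory - myMemoryThreshold
      myMemoryThreshold += grantedFully(amountToRequest)
      // thresholds become 200, 500, 1200: always 2x the latest estimate, so
      // currentMemory <= myMemoryThreshold holds and no spill is triggered
    }

If a request is only partially granted, myMemoryThreshold can fall to or below
currentMemory, which is exactly the condition that triggers a spill.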
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index b04c50bd3e..101c83b264 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -78,6 +78,7 @@ private[spark] class ExternalSorter[K, V, C](
private val blockManager = SparkEnv.get.blockManager
private val diskBlockManager = blockManager.diskBlockManager
+ private val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager
private val ser = Serializer.getSerializer(serializer)
private val serInstance = ser.newInstance()
@@ -116,13 +117,6 @@ private[spark] class ExternalSorter[K, V, C](
private var _memoryBytesSpilled = 0L
private var _diskBytesSpilled = 0L
- // Collective memory threshold shared across all running tasks
- private val maxMemoryThreshold = {
- val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2)
- val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8)
- (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong
- }
-
// How much of the shared memory pool this collection has claimed
private var myMemoryThreshold = 0L
@@ -218,31 +212,15 @@ private[spark] class ExternalSorter[K, V, C](
if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 &&
collection.estimateSize() >= myMemoryThreshold)
{
- // TODO: This logic doesn't work if there are two external collections being used in the same
- // task (e.g. to read shuffle output and write it out into another shuffle) [SPARK-2711]
-
- val currentSize = collection.estimateSize()
- var shouldSpill = false
- val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
-
- // Atomically check whether there is sufficient memory in the global pool for
- // us to double our threshold
- shuffleMemoryMap.synchronized {
- val threadId = Thread.currentThread().getId
- val previouslyClaimedMemory = shuffleMemoryMap.get(threadId)
- val availableMemory = maxMemoryThreshold -
- (shuffleMemoryMap.values.sum - previouslyClaimedMemory.getOrElse(0L))
-
- // Try to allocate at least 2x more memory, otherwise spill
- shouldSpill = availableMemory < currentSize * 2
- if (!shouldSpill) {
- shuffleMemoryMap(threadId) = currentSize * 2
- myMemoryThreshold = currentSize * 2
- }
- }
- // Do not hold lock during spills
- if (shouldSpill) {
- spill(currentSize, usingMap)
+ // Claim up to double our current memory from the shuffle memory pool
+ val currentMemory = collection.estimateSize()
+ val amountToRequest = 2 * currentMemory - myMemoryThreshold
+ val granted = shuffleMemoryManager.tryToAcquire(amountToRequest)
+ myMemoryThreshold += granted
+ if (myMemoryThreshold <= currentMemory) {
+ // We were granted too little memory to grow further (either tryToAcquire returned 0,
+ // or we already had more memory than myMemoryThreshold); spill the current collection
+ spill(currentMemory, usingMap) // Will also release memory back to ShuffleMemoryManager
}
}
}
@@ -327,11 +305,8 @@ private[spark] class ExternalSorter[K, V, C](
buffer = new SizeTrackingPairBuffer[(Int, K), C]
}
- // Reset the amount of shuffle memory used by this map in the global pool
- val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
- shuffleMemoryMap.synchronized {
- shuffleMemoryMap(Thread.currentThread().getId) = 0
- }
+ // Release our memory back to the shuffle pool so that other threads can grab it
+ shuffleMemoryManager.release(myMemoryThreshold)
myMemoryThreshold = 0
spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition))
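
With both collections now delegating to ShuffleMemoryManager.getMaxMemory, pool sizing is
defined in one place. A worked example with the defaults shown in the patch
(spark.shuffle.memoryFraction = 0.2, spark.shuffle.safetyFraction = 0.8); the 1 GB heap
is an assumption for illustration:

    val maxHeap = 1024L * 1024 * 1024        // stand-in for Runtime.getRuntime.maxMemory
    val pool = (maxHeap * 0.2 * 0.8).toLong  // memoryFraction * safetyFraction
    println(pool)                            // 171798691 bytes, ~164 MB: 16% of the heap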
diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala
new file mode 100644
index 0000000000..d31bc22ee7
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala
@@ -0,0 +1,294 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import org.scalatest.FunSuite
+import org.scalatest.concurrent.Timeouts
+import org.scalatest.time.SpanSugar._
+import java.util.concurrent.atomic.AtomicBoolean
+import java.util.concurrent.CountDownLatch
+
+class ShuffleMemoryManagerSuite extends FunSuite with Timeouts {
+ /** Launch a thread with the given body block and return it. */
+ private def startThread(name: String)(body: => Unit): Thread = {
+ val thread = new Thread("ShuffleMemorySuite " + name) {
+ override def run() {
+ body
+ }
+ }
+ thread.start()
+ thread
+ }
+
+ test("single thread requesting memory") {
+ val manager = new ShuffleMemoryManager(1000L)
+
+ assert(manager.tryToAcquire(100L) === 100L)
+ assert(manager.tryToAcquire(400L) === 400L)
+ assert(manager.tryToAcquire(400L) === 400L)
+ assert(manager.tryToAcquire(200L) === 100L)
+ assert(manager.tryToAcquire(100L) === 0L)
+ assert(manager.tryToAcquire(100L) === 0L)
+
+ manager.release(500L)
+ assert(manager.tryToAcquire(300L) === 300L)
+ assert(manager.tryToAcquire(300L) === 200L)
+
+ manager.releaseMemoryForThisThread()
+ assert(manager.tryToAcquire(1000L) === 1000L)
+ assert(manager.tryToAcquire(100L) === 0L)
+ }
+
+ test("two threads requesting full memory") {
+ // Two threads request 500 bytes first, wait for each other to get it, and then request
+ // 500 more; we should immediately return 0 as both are now at 1 / N
+
+ val manager = new ShuffleMemoryManager(1000L)
+
+ class State {
+ var t1Result1 = -1L
+ var t2Result1 = -1L
+ var t1Result2 = -1L
+ var t2Result2 = -1L
+ }
+ val state = new State
+
+ val t1 = startThread("t1") {
+ val r1 = manager.tryToAcquire(500L)
+ state.synchronized {
+ state.t1Result1 = r1
+ state.notifyAll()
+ while (state.t2Result1 === -1L) {
+ state.wait()
+ }
+ }
+ val r2 = manager.tryToAcquire(500L)
+ state.synchronized { state.t1Result2 = r2 }
+ }
+
+ val t2 = startThread("t2") {
+ val r1 = manager.tryToAcquire(500L)
+ state.synchronized {
+ state.t2Result1 = r1
+ state.notifyAll()
+ while (state.t1Result1 === -1L) {
+ state.wait()
+ }
+ }
+ val r2 = manager.tryToAcquire(500L)
+ state.synchronized { state.t2Result2 = r2 }
+ }
+
+ failAfter(20 seconds) {
+ t1.join()
+ t2.join()
+ }
+
+ assert(state.t1Result1 === 500L)
+ assert(state.t2Result1 === 500L)
+ assert(state.t1Result2 === 0L)
+ assert(state.t2Result2 === 0L)
+ }
+
+
+ test("threads cannot grow past 1 / N") {
+ // Two threads request 250 bytes first, wait for each other to get it, and then request
+ // 500 more; we should only grant 250 bytes to each of them on this second request
+
+ val manager = new ShuffleMemoryManager(1000L)
+
+ class State {
+ var t1Result1 = -1L
+ var t2Result1 = -1L
+ var t1Result2 = -1L
+ var t2Result2 = -1L
+ }
+ val state = new State
+
+ val t1 = startThread("t1") {
+ val r1 = manager.tryToAcquire(250L)
+ state.synchronized {
+ state.t1Result1 = r1
+ state.notifyAll()
+ while (state.t2Result1 === -1L) {
+ state.wait()
+ }
+ }
+ val r2 = manager.tryToAcquire(500L)
+ state.synchronized { state.t1Result2 = r2 }
+ }
+
+ val t2 = startThread("t2") {
+ val r1 = manager.tryToAcquire(250L)
+ state.synchronized {
+ state.t2Result1 = r1
+ state.notifyAll()
+ while (state.t1Result1 === -1L) {
+ state.wait()
+ }
+ }
+ val r2 = manager.tryToAcquire(500L)
+ state.synchronized { state.t2Result2 = r2 }
+ }
+
+ failAfter(20 seconds) {
+ t1.join()
+ t2.join()
+ }
+
+ assert(state.t1Result1 === 250L)
+ assert(state.t2Result1 === 250L)
+ assert(state.t1Result2 === 250L)
+ assert(state.t2Result2 === 250L)
+ }
+
+ test("threads can block to get at least 1 / 2N memory") {
+ // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps
+ // for a bit and releases 250 bytes, which should then be granted to t2. Further requests
+ // by t2 will return 0 right away because it now has 1 / 2N of the memory.
+
+ val manager = new ShuffleMemoryManager(1000L)
+
+ class State {
+ var t1Requested = false
+ var t2Requested = false
+ var t1Result = -1L
+ var t2Result = -1L
+ var t2Result2 = -1L
+ var t2WaitTime = 0L
+ }
+ val state = new State
+
+ val t1 = startThread("t1") {
+ state.synchronized {
+ state.t1Result = manager.tryToAcquire(1000L)
+ state.t1Requested = true
+ state.notifyAll()
+ while (!state.t2Requested) {
+ state.wait()
+ }
+ }
+ // Sleep a bit before releasing our memory; this is hacky but it would be difficult to make
+ // sure the other thread blocks for some time otherwise
+ Thread.sleep(300)
+ manager.release(250L)
+ }
+
+ val t2 = startThread("t2") {
+ state.synchronized {
+ while (!state.t1Requested) {
+ state.wait()
+ }
+ state.t2Requested = true
+ state.notifyAll()
+ }
+ val startTime = System.currentTimeMillis()
+ val result = manager.tryToAcquire(250L)
+ val endTime = System.currentTimeMillis()
+ state.synchronized {
+ state.t2Result = result
+ // A second call should return 0 because we're now already at 1 / 2N
+ state.t2Result2 = manager.tryToAcquire(100L)
+ state.t2WaitTime = endTime - startTime
+ }
+ }
+
+ failAfter(20 seconds) {
+ t1.join()
+ t2.join()
+ }
+
+ // Both threads should've been able to acquire their memory; the second one will have waited
+ // until the first one acquired 1000 bytes and then released 250
+ state.synchronized {
+ assert(state.t1Result === 1000L, "t1 could not allocate memory")
+ assert(state.t2Result === 250L, "t2 could not allocate memory")
+ assert(state.t2WaitTime > 200, s"t2 waited less than 200 ms (${state.t2WaitTime})")
+ assert(state.t2Result2 === 0L, "t2 got extra memory the second time")
+ }
+ }
+
+ test("releaseMemoryForThisThread") {
+ // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps
+ // for a bit and releases all its memory. t2 should now be able to grab all the memory.
+
+ val manager = new ShuffleMemoryManager(1000L)
+
+ class State {
+ var t1Requested = false
+ var t2Requested = false
+ var t1Result = -1L
+ var t2Result1 = -1L
+ var t2Result2 = -1L
+ var t2Result3 = -1L
+ var t2WaitTime = 0L
+ }
+ val state = new State
+
+ val t1 = startThread("t1") {
+ state.synchronized {
+ state.t1Result = manager.tryToAcquire(1000L)
+ state.t1Requested = true
+ state.notifyAll()
+ while (!state.t2Requested) {
+ state.wait()
+ }
+ }
+ // Sleep a bit before releasing our memory; this is hacky but it would be difficult to make
+ // sure the other thread blocks for some time otherwise
+ Thread.sleep(300)
+ manager.releaseMemoryForThisThread()
+ }
+
+ val t2 = startThread("t2") {
+ state.synchronized {
+ while (!state.t1Requested) {
+ state.wait()
+ }
+ state.t2Requested = true
+ state.notifyAll()
+ }
+ val startTime = System.currentTimeMillis()
+ val r1 = manager.tryToAcquire(500L)
+ val endTime = System.currentTimeMillis()
+ val r2 = manager.tryToAcquire(500L)
+ val r3 = manager.tryToAcquire(500L)
+ state.synchronized {
+ state.t2Result1 = r1
+ state.t2Result2 = r2
+ state.t2Result3 = r3
+ state.t2WaitTime = endTime - startTime
+ }
+ }
+
+ failAfter(20 seconds) {
+ t1.join()
+ t2.join()
+ }
+
+ // Both threads should've been able to acquire their memory; the second one will have waited
+ // until the first one acquired 1000 bytes and then released all of it
+ state.synchronized {
+ assert(state.t1Result === 1000L, "t1 could not allocate memory")
+ assert(state.t2Result1 === 500L, "t2 didn't get 500 bytes the first time")
+ assert(state.t2Result2 === 500L, "t2 didn't get 500 bytes the second time")
+ assert(state.t2Result3 === 0L, s"t2 got more bytes a third time (${state.t2Result3})")
+ assert(state.t2WaitTime > 200, s"t2 waited less than 200 ms (${state.t2WaitTime})")
+ }
+ }
+}