1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
|
package scala.concurrent.impl
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.CountDownLatch
import org.junit.Assert._
import org.junit.{ After, Before, Test }
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import scala.annotation.tailrec
import scala.concurrent.ExecutionContext
import scala.concurrent.impl.Promise.DefaultPromise
import scala.util.{ Failure, Success, Try }
import scala.util.control.NonFatal
/** Tests for the private class DefaultPromise */
@RunWith(classOf[JUnit4])
class DefaultPromiseTest {
// Many tests in this class use a helper class, Tester, to track the state of
// promises and to ensure they behave correctly, particularly the complex behaviour
// of linking.
/** The value type carried by every promise created in these tests. */
type Result = Int
/** Handle under which a Tester stores a promise. */
type PromiseId = Int
/** Handle identifying an onComplete handler attached through a Tester. */
type HandlerId = Int
/** Handle under which a Tester stores a Chain. */
type ChainId = Int
/** The state of a set of linked promises.
 *
 *  @param promises the handles of all promises that belong to this chain
 *  @param state    `Left` holds the handlers still waiting to fire while the chain is
 *                  incomplete; `Right` holds the result once the chain is complete
 */
case class Chain(
promises: Set[PromiseId],
state: Either[Set[HandlerId],Try[Result]]
)
/** A helper class that provides methods for creating, linking, completing and
 * adding handlers to promises. With each operation it verifies that handlers
 * are called, any expected exceptions are thrown, and that all promises have
 * the expected value.
 *
 * The links between promises are not tracked precisely. Instead, linked promises
 * are placed in the same Chain object. Every promise in the same chain shares
 * the same value.
 */
class Tester {
/** Every promise created through this tester, keyed by its handle. */
var promises = Map.empty[PromiseId, DefaultPromise[Result]]
/** The current grouping of promises into chains, keyed by chain handle. */
var chains = Map.empty[ChainId, Chain]
/** Next value to be handed out by `freshId`. */
private var counter = 0
/** Returns the current counter value and advances the counter, so no value is issued twice. */
private def freshId(): Int = {
  counter += 1
  counter - 1
}
/** Handlers report each invocation on this queue as a (result, handler id) pair */
private val handlerQueue = new ConcurrentLinkedQueue[(Try[Result], HandlerId)]()
/** Look up the chain that currently contains promise `p`, paired with the chain's id.
 *
 *  Returns `None` when no chain contains the promise and throws an
 *  IllegalStateException when more than one chain does.
 */
private def promiseChain(p: PromiseId): Option[(ChainId, Chain)] = {
  val owners = chains.toList.filter { case (_, chain) => chain.promises.contains(p) }
  if (owners.lengthCompare(1) > 0)
    throw new IllegalStateException(s"Promise $p found in more than one chain")
  owners.headOption
}
/** Passed to `checkEffect` to indicate the expected effect of an operation */
sealed trait Effect
/** The operation must succeed without firing any handler. */
case object NoEffect extends Effect
/** The operation must succeed and fire each of `handlers` exactly once with `result`. */
case class HandlersFired(result: Try[Result], handlers: Set[HandlerId]) extends Effect
/** The operation may fail, but only with an IllegalStateException; no handler may fire. */
case object MaybeIllegalThrown extends Effect
/** The operation must fail with an IllegalStateException; no handler may fire. */
case object IllegalThrown extends Effect
/** Runs an operation while verifying that the operation has the expected effect.
 *
 *  The operation is evaluated inside a `Try`. Afterwards the handler queue is drained
 *  and the number of times each (result, handler) pair fired is compared with what
 *  `expected` prescribes.
 *
 *  @param expected the effect the operation is required to have
 *  @param f        the operation to run; only its success or failure is inspected
 */
private def checkEffect(expected: Effect)(f: => Any): Unit = {
  // Should have been cleared by last usage. A JUnit assertion is used rather than
  // Predef.assert so the check cannot be elided and reports a message.
  assertTrue("Handler queue should be empty before an operation runs", handlerQueue.isEmpty())
  val result = Try(f)
  val noneFired = Map.empty[(Try[Result], HandlerId), Int]
  // Drain the queue (poll() returns null once it is empty), counting each report
  val fireCounts = Iterator.continually(handlerQueue.poll()).takeWhile(_ != null)
    .foldLeft(noneFired) { (counts, key) =>
      counts.updated(key, counts.getOrElse(key, 0) + 1)
    }
  def assertIllegalResult(): Unit = result match {
    case Failure(_: IllegalStateException) => ()
    case _ => fail(s"Expected IllegalStateException: $result")
  }
  expected match {
    case NoEffect =>
      assertTrue(s"Shouldn't throw exception: $result", result.isSuccess)
      assertEquals(noneFired, fireCounts)
    case HandlersFired(firingResult, handlers) =>
      assertTrue(s"Shouldn't throw exception: $result", result.isSuccess)
      // Every expected handler must have fired exactly once with the expected result
      val expectedCounts = handlers.foldLeft(noneFired) {
        case (map, hid) => map.updated((firingResult, hid), 1)
      }
      assertEquals(expectedCounts, fireCounts)
    case MaybeIllegalThrown =>
      if (result.isFailure) assertIllegalResult()
      assertEquals(noneFired, fireCounts)
    case IllegalThrown =>
      assertIllegalResult()
      assertEquals(noneFired, fireCounts)
  }
}
/** Check each promise has the expected value.
 *
 *  Every promise in a completed chain must report that chain's result.
 *  Promises in chains that are still incomplete are not checked.
 */
private def assertPromiseValues(): Unit = {
  for {
    chain <- chains.values
    p <- chain.promises
  } chain.state match {
    case Right(result) => assertEquals(Some(result), promises(p).value)
    case Left(_) => ()
  }
}
/** Register a new promise in a fresh, incomplete chain of its own and return the promise's handle. */
def newPromise(): PromiseId = {
  // Ids are drawn in this order: first the promise, then its chain
  val promiseId = freshId()
  val chainId = freshId()
  promises = promises + (promiseId -> new DefaultPromise[Result]())
  chains = chains + (chainId -> Chain(Set(promiseId), Left(Set.empty)))
  assertPromiseValues()
  promiseId
}
/** Complete a promise with a fresh, successful result.
 *
 *  If the promise's chain is still incomplete, every handler registered on the chain
 *  is expected to fire with the new result and the chain becomes complete. If the
 *  chain is already complete, the attempt is expected to throw an
 *  IllegalStateException and the chain's state is left unchanged.
 */
def complete(p: PromiseId): Unit = {
  val r = Success(freshId())
  val (cid, chain) = promiseChain(p).get
  val (completionEffect, newState) = chain.state match {
    case Left(handlers) => (HandlersFired(r, handlers), Right(r))
    case Right(_) => (IllegalThrown, chain.state)
  }
  checkEffect(completionEffect) { promises(p).complete(r) }
  chains = chains.updated(cid, chain.copy(state = newState))
  assertPromiseValues()
}
/** Attempt to link two promises together by calling `linkRootOf` on `a` with `b`.
 *
 *  The expected effect is derived from the states of both chains before linking.
 *  Unless both chains are already complete, the two chains are replaced by a
 *  single new chain containing the promises of both.
 *
 *  @return the ids of the chains holding `a` and `b` after the operation
 */
def link(a: PromiseId, b: PromiseId): (ChainId, ChainId) = {
  val promiseA = promises(a)
  val promiseB = promises(b)
  val (cidA, chainA) = promiseChain(a).get
  val (cidB, chainB) = promiseChain(b).get
  type ChainState = Either[Set[HandlerId], Try[Result]]
  // Work out the effect of linking from the two chain states.
  // `Some(state)` means the chains merge into one chain with that state,
  // `None` means both chains stay as they are.
  val (linkEffect, merged): (Effect, Option[ChainState]) =
    (chainA.state, chainB.state) match {
      case (Left(waitingA), Left(waitingB)) =>
        (NoEffect, Some(Left(waitingA ++ waitingB)))
      case (Left(waiting), Right(result)) =>
        (HandlersFired(result, waiting), Some(Right(result)))
      case (Right(result), Left(waiting)) =>
        (HandlersFired(result, waiting), Some(Right(result)))
      case (Right(_), Right(_)) =>
        // Won't be thrown if happen to link a promise to itself
        (if (cidA == cidB) MaybeIllegalThrown else IllegalThrown, None)
    }
  // Perform the linking, then merge the chains if appropriate
  checkEffect(linkEffect) { promiseA.linkRootOf(promiseB) }
  val resultingChains = merged match {
    case None =>
      (cidA, cidB)
    case Some(newState) =>
      val newCid = freshId()
      val members = chainA.promises ++ chainB.promises
      chains = (chains - cidA - cidB).updated(newCid, Chain(members, newState))
      (newCid, newCid)
  }
  assertPromiseValues()
  resultingChains
}
/** Attach an onComplete handler. When called, the handler will
 *  place an entry into `handlerQueue` with the handler's identity.
 *  This allows verification of handler calling semantics.
 *
 *  If the promise's chain is already complete, the new handler is expected to fire
 *  during the attach operation; otherwise it is recorded on the chain and nothing
 *  is expected to fire yet.
 *
 *  @return the id of the newly attached handler
 */
def attachHandler(p: PromiseId): HandlerId = {
  val hid = freshId()
  val promise = promises(p)
  val (cid, chain) = promiseChain(p).get
  val (attachEffect, newState) = chain.state match {
    case Left(handlers) =>
      (NoEffect, Left(handlers + hid))
    case Right(result) =>
      (HandlersFired(result, Set(hid)), Right(result))
  }
  // Runs each submitted task directly on the calling thread
  implicit val ec: ExecutionContext = new ExecutionContext {
    def execute(r: Runnable): Unit = r.run()
    def reportFailure(t: Throwable): Unit = t.printStackTrace()
  }
  checkEffect(attachEffect) { promise.onComplete(result => handlerQueue.add((result, hid))) }
  chains = chains.updated(cid, chain.copy(state = newState))
  assertPromiseValues()
  hid
}
}
// Some methods and objects that build a list of promise
// actions to test and then execute them
/** Key naming a promise within a list of actions; `testActions` creates the
 *  actual promise the first time a key is used. */
type PromiseKey = Int
/** An operation that `testActions` can perform on a Tester. */
sealed trait Action
/** Complete the promise named by `p`. */
case class Complete(p: PromiseKey) extends Action
/** Link the promise named by `a` with the promise named by `b`. */
case class Link(a: PromiseKey, b: PromiseKey) extends Action
/** Attach an onComplete handler to the promise named by `p`. */
case class AttachHandler(p: PromiseKey) extends Action
/** Tests a sequence of actions on a Tester. Creates promises as needed.
 *
 *  A new Tester is used for every call. The Tester itself verifies the outcome
 *  of each action as it is performed.
 *
 *  @param actions the actions to perform, in order
 */
private def testActions(actions: Seq[Action]): Unit = {
  val t = new Tester()
  var pMap = Map.empty[PromiseKey, PromiseId]
  // Resolve a key to its promise, creating and remembering the promise on first use
  def byKey(key: PromiseKey): PromiseId =
    pMap.getOrElse(key, {
      val pid = t.newPromise()
      pMap = pMap.updated(key, pid)
      pid
    })
  actions foreach {
    case Complete(p) => t.complete(byKey(p))
    case Link(a, b) => t.link(byKey(a), byKey(b))
    case AttachHandler(p) => t.attachHandler(byKey(p))
  }
}
/** Tests all permutations of actions for `count` promises.
 *
 *  The action list holds one Complete and one AttachHandler per promise, plus a
 *  Link for every ordered pair of promises, including a promise paired with itself.
 *
 *  @param count the number of distinct promises the actions refer to
 */
private def testPermutations(count: Int): Unit = {
  val ps = (0 until count).toList
  val pPairs = for (a <- ps; b <- ps) yield (a, b)
  val allActions = ps.map(Complete(_)) ++ pPairs.map { case (a, b) => Link(a, b) } ++ ps.map(AttachHandler(_))
  for (permutation <- allActions.permutations) {
    testActions(permutation)
  }
}
/** Test all permutations of actions with a single promise */
@Test
def testPermutations1(): Unit = {
  testPermutations(1)
}
/** Test all permutations of actions with two promises - about 40 thousand */
@Test
def testPermutations2(): Unit = {
  testPermutations(2)
}
/** Link promises in different orders, using the same link structure as is
 *  used in Future.flatMap.
 *
 *  Each round creates a sequence of promises in which every newer promise is linked
 *  to the one created before it, plus a completion of the newest promise. The events
 *  are replayed in a shuffled order; the random generator has a fixed seed, so the
 *  orders are the same on every run.
 */
@Test
def simulateFlatMapLinking(): Unit = {
  val random = new scala.util.Random(1)
  for (_ <- 0 until 10) {
    val t = new Tester()
    val flatMapCount = 100
    // Local events working on real promise handles; they shadow the
    // class-level Link and Complete actions inside this method.
    sealed trait FlatMapEvent
    case class Link(a: PromiseId, b: PromiseId) extends FlatMapEvent
    case class Complete(p: PromiseId) extends FlatMapEvent
    @tailrec
    def flatMapEvents(count: Int, p1: PromiseId, acc: List[FlatMapEvent]): List[FlatMapEvent] = {
      if (count == 0) {
        Complete(p1)::acc
      } else {
        val p2 = t.newPromise()
        flatMapEvents(count - 1, p2, Link(p2, p1)::acc)
      }
    }
    val events = flatMapEvents(flatMapCount, t.newPromise(), Nil)
    assertEquals(flatMapCount + 1, t.chains.size) // All promises are unlinked
    val shuffled = random.shuffle(events)
    shuffled foreach {
      case Link(a, b) => t.link(a, b)
      case Complete(p) => t.complete(p)
    }
    // All promises should be linked together, no matter the order of their linking
    assertEquals(1, t.chains.size)
  }
}
/** Link promises together on more than one thread, using the same link
 *  structure as is used in Future.flatMap.
 *
 *  Every link operation and the final completion are submitted as separate tasks
 *  to the global execution context. All tasks wait on a start latch so they are
 *  released together. The test then waits for every task to finish and checks
 *  that the first promise received the value.
 */
@Test
def testFlatMapLinking(): Unit = {
  for (_ <- 0 until 100) {
    val flatMapCount = 100
    val startLatch = new CountDownLatch(1)
    val doneLatch = new CountDownLatch(flatMapCount + 1)
    // Exceptions thrown by tasks, collected so the test can fail with a clear message
    val failures = new ConcurrentLinkedQueue[Throwable]()
    def execute(f: => Unit): Unit = {
      val ec = ExecutionContext.global
      ec.execute(new Runnable {
        def run(): Unit = {
          try {
            startLatch.await()
            f
          } catch {
            case NonFatal(e) =>
              failures.add(e)
              ec.reportFailure(e)
          } finally {
            // Always count down: otherwise a throwing task would leave
            // doneLatch.await() below blocked forever.
            doneLatch.countDown()
          }
        }
      })
    }
    @tailrec
    def flatMapTimes(count: Int, p1: DefaultPromise[Int]): Unit = {
      if (count == 0) {
        execute { p1.success(1) }
      } else {
        val p2 = new DefaultPromise[Int]()
        execute { p2.linkRootOf(p1) }
        flatMapTimes(count - 1, p2)
      }
    }
    val p = new DefaultPromise[Int]()
    flatMapTimes(flatMapCount, p)
    startLatch.countDown()
    doneLatch.await()
    assertTrue(s"Linking tasks failed: $failures", failures.isEmpty)
    assertEquals(Some(Success(1)), p.value)
  }
}
}
|