summaryrefslogtreecommitdiff
path: root/test/junit/scala/concurrent/impl/DefaultPromiseTest.scala
blob: f3a75e24d00ef12c44899e7eb350cf2e36ca424d (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
package scala.concurrent.impl

import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit
import org.junit.Assert._
import org.junit.{ After, Before, Test }
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import scala.annotation.tailrec
import scala.concurrent.ExecutionContext
import scala.concurrent.impl.Promise.DefaultPromise
import scala.util.{ Failure, Success, Try }
import scala.util.control.NonFatal

/** Tests for the private class DefaultPromise */
@RunWith(classOf[JUnit4])
class DefaultPromiseTest {

  // Many tests in this class use a helper class, Tester, to track the state of
  // promises and to ensure they behave correctly, particularly the complex behaviour
  // of linking.

  type Result = Int
  type PromiseId = Int
  type HandlerId = Int
  type ChainId = Int

  /** The state of a set of linked promises.
   *
   *  @param promises the promises that belong to this chain
   *  @param state    `Left` holding the handlers attached so far while the chain is
   *                  incomplete, or `Right` holding the result shared by every promise
   *                  in the chain once it has been completed
   */
  case class Chain(
    promises: Set[PromiseId],
    state: Either[Set[HandlerId], Try[Result]]
  )

  /** A helper class that provides methods for creating, linking, completing and
   *  adding handlers to promises. With each operation it verifies that handlers
   *  are called, any expected exceptions are thrown, and that all promises have
   *  the expected value.
   *
   *  The links between promises are not tracked precisely. Instead, linked promises
   *  are placed in the same Chain object. Every promise in the same chain shares
   *  the same value.
   */
  class Tester {
    var promises = Map.empty[PromiseId, DefaultPromise[Result]]
    var chains = Map.empty[ChainId, Chain]

    private var counter = 0

    /** Returns a new unique id. One counter serves promise, chain and handler ids
     *  as well as result values, so every id/value handed out is distinct.
     */
    private def freshId(): Int = {
      val id = counter
      counter += 1
      id
    }

    /** Handlers report their activity on this queue */
    private val handlerQueue = new ConcurrentLinkedQueue[(Try[Result], HandlerId)]()

    /** Runs each task immediately on the calling thread, so handlers fire inside the
     *  operation that triggers them and `checkEffect` can observe them straight away.
     */
    private val callingThreadEc: ExecutionContext = new ExecutionContext {
      def execute(r: Runnable): Unit = r.run()
      def reportFailure(t: Throwable): Unit = t.printStackTrace()
    }

    /** Get the chain for a given promise, or `None` if it belongs to no chain.
     *
     *  @throws IllegalStateException if the promise appears in more than one chain,
     *                                which would mean the model itself is corrupt
     */
    private def promiseChain(p: PromiseId): Option[(ChainId, Chain)] =
      chains.filter { case (_, chain) => chain.promises.contains(p) }.toList match {
        case Nil => None
        case x :: Nil => Some(x)
        case _ => throw new IllegalStateException(s"Promise $p found in more than one chain")
      }

    /** Passed to `checkEffect` to indicate the expected effect of an operation */
    sealed trait Effect
    case object NoEffect extends Effect
    case class HandlersFired(result: Try[Result], handlers: Set[HandlerId]) extends Effect
    case object MaybeIllegalThrown extends Effect
    case object IllegalThrown extends Effect

    /** Runs an operation while verifying that the operation has the expected effect.
     *
     *  @param expected the effect `f` must have: which handlers fire (each exactly
     *                  once, with which result) and whether it throws
     *                  IllegalStateException
     *  @param f        the operation under test, evaluated exactly once
     */
    private def checkEffect(expected: Effect)(f: => Any): Unit = {
      assert(handlerQueue.isEmpty()) // Should have been cleared by last usage
      val result = Try(f)

      // Drain the queue, counting how often each (result, handler) pair fired during `f`
      var fireCounts = Map.empty[(Try[Result], HandlerId), Int]
      var fired = handlerQueue.poll()
      while (fired != null) {
        fireCounts = fireCounts.updated(fired, fireCounts.getOrElse(fired, 0) + 1)
        fired = handlerQueue.poll()
      }
      val noFirings = Map.empty[(Try[Result], HandlerId), Int]

      def assertIllegalResult(): Unit = result match {
        case Failure(_: IllegalStateException) => ()
        case _ => fail(s"Expected IllegalStateException: $result")
      }

      expected match {
        case NoEffect =>
          assertTrue(s"Shouldn't throw exception: $result", result.isSuccess)
          assertEquals(noFirings, fireCounts)
        case HandlersFired(firingResult, handlers) =>
          assertTrue(s"Shouldn't throw exception: $result", result.isSuccess)
          // Every expected handler must fire exactly once with `firingResult`, and no other handler may fire
          val expectedCounts = handlers.map(hid => (firingResult, hid) -> 1).toMap
          assertEquals(expectedCounts, fireCounts)
        case MaybeIllegalThrown =>
          if (result.isFailure) assertIllegalResult()
          assertEquals(noFirings, fireCounts)
        case IllegalThrown =>
          assertIllegalResult()
          assertEquals(noFirings, fireCounts)
      }
    }

    /** Check each promise has the expected value: the chain's result once the chain
     *  is complete, and no value at all while it is still incomplete.
     */
    private def assertPromiseValues(): Unit =
      for (chain <- chains.values) {
        val expectedValue: Option[Try[Result]] = chain.state match {
          case Right(result) => Some(result)
          case Left(_) => None
        }
        chain.promises.foreach(p => assertEquals(expectedValue, promises(p).value))
      }

    /** Create a promise, returning a handle. */
    def newPromise(): PromiseId = {
      val pid = freshId()
      val cid = freshId()
      promises = promises.updated(pid, new DefaultPromise[Result]())
      chains = chains.updated(cid, Chain(Set(pid), Left(Set.empty)))
      assertPromiseValues()
      pid
    }

    /** Complete a promise with a fresh, unique value. Completing a promise whose
     *  chain is already complete must throw IllegalStateException.
     */
    def complete(p: PromiseId): Unit = {
      val r = Success(freshId())
      val (cid, chain) = promiseChain(p).get
      val (completionEffect, newState) = chain.state match {
        case Left(handlers) => (HandlersFired(r, handlers), Right(r))
        case Right(_) => (IllegalThrown, chain.state)
      }
      checkEffect(completionEffect) { promises(p).complete(r) }
      chains = chains.updated(cid, chain.copy(state = newState))
      assertPromiseValues()
    }

    /** Attempt to link two promises together.
     *
     *  @return the ids of the chains now holding `a` and `b` respectively; they are
     *          equal whenever the link merged the two chains
     */
    def link(a: PromiseId, b: PromiseId): (ChainId, ChainId) = {
      val promiseA = promises(a)
      val promiseB = promises(b)
      val (cidA, chainA) = promiseChain(a).get
      val (cidB, chainB) = promiseChain(b).get

      // Examine the state of each promise's chain to work out
      // the effect of linking the promises, and to work out
      // if the two chains should be merged.

      sealed trait MergeOp
      case object NoMerge extends MergeOp
      case class Merge(state: Either[Set[HandlerId], Try[Result]]) extends MergeOp

      val (linkEffect, mergeOp) = (chainA.state, chainB.state) match {
        case (Left(handlers1), Left(handlers2)) =>
          (NoEffect, Merge(Left(handlers1 ++ handlers2)))
        case (Left(handlers), Right(result)) =>
          (HandlersFired(result, handlers), Merge(Right(result)))
        case (Right(result), Left(handlers)) =>
          (HandlersFired(result, handlers), Merge(Right(result)))
        case (Right(_), Right(_)) if (cidA == cidB) =>
          (MaybeIllegalThrown, NoMerge) // Won't be thrown if happen to link a promise to itself
        case (Right(_), Right(_)) =>
          (IllegalThrown, NoMerge)
      }

      // Perform the linking and merge the chains, if appropriate

      checkEffect(linkEffect) { promiseA.linkRootOf(promiseB) }

      val (newCidA, newCidB) = mergeOp match {
        case NoMerge => (cidA, cidB)
        case Merge(newState) =>
          chains = chains - cidA - cidB
          val newCid = freshId()
          chains = chains.updated(newCid, Chain(chainA.promises ++ chainB.promises, newState))
          (newCid, newCid)
      }
      assertPromiseValues()
      (newCidA, newCidB)
    }

    /** Attach an onComplete handler. When called, the handler will
     *  place an entry into `handlerQueue` with the handler's identity.
     *  This allows verification of handler calling semantics.
     */
    def attachHandler(p: PromiseId): HandlerId = {
      val hid = freshId()
      val promise = promises(p)
      val (cid, chain) = promiseChain(p).get
      val (attachEffect, newState) = chain.state match {
        case Left(handlers) =>
          (NoEffect, Left(handlers + hid))
        case Right(result) =>
          (HandlersFired(result, Set(hid)), Right(result))
      }

      checkEffect(attachEffect) {
        promise.onComplete(result => handlerQueue.add((result, hid)))(callingThreadEc)
      }
      chains = chains.updated(cid, chain.copy(state = newState))
      assertPromiseValues()
      hid
    }
  }

  // Some methods and objects that build a list of promise
  // actions to test and then execute them

  type PromiseKey = Int

  sealed trait Action
  case class Complete(p: PromiseKey) extends Action
  case class Link(a: PromiseKey, b: PromiseKey) extends Action
  case class AttachHandler(p: PromiseKey) extends Action

  /** Tests a sequence of actions on a fresh Tester. A promise is created the first
   *  time its key is mentioned.
   */
  private def testActions(actions: Seq[Action]): Unit = {
    val t = new Tester()
    var pMap = Map.empty[PromiseKey, PromiseId]
    def byKey(key: PromiseKey): PromiseId = {
      if (!pMap.contains(key)) {
        pMap = pMap.updated(key, t.newPromise())
      }
      pMap(key)
    }

    actions.foreach {
      case Complete(p) => t.complete(byKey(p))
      case Link(a, b) => t.link(byKey(a), byKey(b))
      case AttachHandler(p) => t.attachHandler(byKey(p))
    }
  }

  /** Tests all permutations of actions for `count` promises: one completion and one
   *  handler per promise, plus a link for every ordered pair (including self-links).
   */
  private def testPermutations(count: Int): Unit = {
    val ps = (0 until count).toList
    val pPairs = for (a <- ps; b <- ps) yield (a, b)

    val allActions: List[Action] =
      ps.map(Complete(_)) ++ pPairs.map { case (a, b) => Link(a, b) } ++ ps.map(AttachHandler(_))
    for (permutation <- allActions.permutations) {
      testActions(permutation)
    }
  }

  /** Test all permutations of actions with a single promise */
  @Test
  def testPermutations1(): Unit = {
    testPermutations(1)
  }

  /** Test all permutations of actions with two promises - about 40 thousand */
  @Test
  def testPermutations2(): Unit = {
    testPermutations(2)
  }

  /** Link promises in different orders, using the same link structure as is
   *  used in Future.flatMap */
  @Test
  def simulateFlatMapLinking(): Unit = {
    val random = new scala.util.Random(1)
    for (_ <- 0 until 10) {
      val t = new Tester()
      val flatMapCount = 100

      sealed trait FlatMapEvent
      case class Link(a: PromiseId, b: PromiseId) extends FlatMapEvent
      case class Complete(p: PromiseId) extends FlatMapEvent

      @tailrec
      def flatMapEvents(count: Int, p1: PromiseId, acc: List[FlatMapEvent]): List[FlatMapEvent] = {
        if (count == 0) {
          Complete(p1) :: acc
        } else {
          val p2 = t.newPromise()
          flatMapEvents(count - 1, p2, Link(p2, p1) :: acc)
        }
      }

      val events = flatMapEvents(flatMapCount, t.newPromise(), Nil)
      assertEquals(flatMapCount + 1, t.chains.size) // All promises are unlinked
      val shuffled = random.shuffle(events)
      shuffled.foreach {
        case Link(a, b) => t.link(a, b)
        case Complete(p) => t.complete(p)
      }
      // All promises should be linked together, no matter the order of their linking
      assertEquals(1, t.chains.size)
    }
  }

  /** Link promises together on more than one thread, using the same link
   *  structure as is used in Future.flatMap */
  @Test
  def testFlatMapLinking(): Unit = {
    for (_ <- 0 until 100) {
      val flatMapCount = 100
      val startLatch = new CountDownLatch(1)
      val doneLatch = new CountDownLatch(flatMapCount + 1)
      // Exceptions on worker threads would otherwise only be printed; collect them so the test fails
      val failures = new ConcurrentLinkedQueue[Throwable]()
      def execute(f: => Unit): Unit = {
        val ec = ExecutionContext.global
        ec.execute(new Runnable {
          def run(): Unit = {
            try {
              startLatch.await()
              f
            } catch {
              case NonFatal(e) =>
                failures.add(e)
                ec.reportFailure(e)
            } finally {
              // Always count down, so a failing task cannot leave the main thread waiting forever
              doneLatch.countDown()
            }
          }
        })
      }
      @tailrec
      def flatMapTimes(count: Int, p1: DefaultPromise[Int]): Unit = {
        if (count == 0) {
          execute { p1.success(1) }
        } else {
          val p2 = new DefaultPromise[Int]()
          execute { p2.linkRootOf(p1) }
          flatMapTimes(count - 1, p2)
        }
      }

      val p = new DefaultPromise[Int]()
      flatMapTimes(flatMapCount, p)
      startLatch.countDown()
      assertTrue("Timed out waiting for linking tasks to finish", doneLatch.await(60, TimeUnit.SECONDS))
      assertTrue(s"Linking tasks failed: $failures", failures.isEmpty)
      assertEquals(Some(Success(1)), p.value)
    }
  }

}