aboutsummaryrefslogtreecommitdiff
path: root/tests/bench/transactional/ReaderMonadic.scala
blob: ce69c35ad591b455c0d63907112b53064993c512 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
package transactional

/** A computation that produces an `A` once it is given an environment of type `R`.
  *
  * Nothing is evaluated until `run` is applied to an environment.
  *
  * @param run the wrapped function from environment to result
  * @tparam R the environment type
  * @tparam A the result type
  */
case class Reader[R,A](run: R => A) {

  /** Post-processes the result of this reader with `f`, leaving the environment untouched. */
  def map[B](f: A => B): Reader[R, B] =
    Reader { env =>
      val result = run(env)
      f(result)
    }

  /** Chains a dependent reader: this reader is run first, `f` picks the next
    * reader from its result, and that reader is run against the same environment.
    */
  def flatMap[B](f: A => Reader[R, B]): Reader[R, B] =
    Reader { env =>
      val next = f(run(env))
      next.run(env)
    }
}

/** Companion holding the primitive reader operations. */
object Reader {

  /** The reader that simply hands back the environment it is run with. */
  def ask[R]: Reader[R,R] =
    Reader[R, R](identity)
}

/** Benchmark body that threads a `Transaction` through a chain of steps using
  * the `Reader` monad instead of passing the transaction explicitly.
  */
object ReaderBench extends Benchmark {

  /** A computation that needs a `Transaction` in order to produce a `T`. */
  type Transactional[T] = Reader[Transaction, T]

  /** Runs `op` against a newly created transaction, then commits that transaction.
    *
    * @param op the transactional computation to execute
    * @return the value produced by `op`
    */
  def transaction[T](op: Transactional[T]): T = {
    implicit val trans: Transaction = new Transaction
    val res = op.run(trans)
    trans.commit()
    res
  }

  /** The transactional computation whose result is the current transaction itself. */
  def thisTransaction: Transactional[Transaction] = Reader.ask

  /** One step of the benchmark: maps an `Int` to a transactional `Int`. */
  abstract class Op {
    def f(x: Int): Transactional[Int]
  }

  /** Calls `println("0th step")` on the transaction and yields `x` unchanged. */
  class Op0 extends Op {
    def f(x: Int): Transactional[Int] =
      for (trans <- thisTransaction)
      yield { trans.println("0th step"); x }
  }

  /** Calls `println("first step")` on the transaction and yields `x + 1`. */
  class Op1 extends Op {
    def f(x: Int): Transactional[Int] =
      for (trans <- thisTransaction)
      yield { trans.println("first step"); x + 1 }
  }

  /** Calls `println("second step")` on the transaction and yields `x + 2`. */
  class Op2 extends Op {
    def f(x: Int): Transactional[Int] =
      for (trans <- thisTransaction)
      yield { trans.println("second step"); x + 2 }
  }

  /** Calls `println("third step")` on the transaction and yields `x + 3`. */
  class Op3 extends Op {
    def f(x: Int): Transactional[Int] =
      for (trans <- thisTransaction)
      yield { trans.println("third step"); x + 3 }
  }

  /** Step table; `f` selects an entry with `n % 4`. */
  val op = Array[Op](new Op0, new Op1, new Op2, new Op3)

  /** Applies `n` steps to `x`, choosing `op(n % 4)` at each level, and yields the
    * final value. When the recursion bottoms out (`n <= 0`) on an odd value, the
    * transaction is aborted as part of the returned computation.
    *
    * @param x the current value
    * @param n the number of steps still to apply
    * @return a transactional computation yielding the value after all steps
    */
  def f(x: Int, n: Int): Transactional[Int] = {
    def rest(trans: Transaction): Transactional[Int] = {
      trans.println("fourth step")
      if (n > 0) {
        for {
          y <- op(n % 4).f(x)
          z <- f(y, n - 1)
        }
        yield z
      }
      else if (x % 2 != 0) {
        // The abort has to be part of the Reader that is returned. Previously a
        // separate Reader was built here and dropped in statement position, so
        // it was never run and `abort()` could never be reached.
        for (current <- thisTransaction)
        yield { current.abort(); x }
      }
      else Reader(_ => x)
    }
    thisTransaction.flatMap(rest)
  }

  /** Executes the benchmark workload once inside a transaction.
    *
    * `f(7, 10)` applies the increments 2, 1, 0, 3, 2, 1, 0, 3, 2, 1 (sum 15), so
    * the final value is 22; it is even, so the abort branch of `f` is not taken.
    *
    * @return the computed value, asserted to be 22
    */
  def run(): Int = {
    transaction {
      for {
        res <- f(7, 10)
        // Bind the transaction in the same chain so the abort check really runs;
        // a nested `for` whose Reader is discarded would never evaluate it.
        trans <- thisTransaction
      }
      yield {
        assert(!trans.isAborted)
        assert(res == 22)
        res
      }
    }
  }
}

// Registers ReaderBench with Runner under the label "reader monadic".
// NOTE(review): the third argument, 22, equals the value asserted in
// ReaderBench.run, so it is presumably the expected result; confirm against
// Runner's constructor, which is not defined in this file.
object ReaderMonadic extends Runner("reader monadic", ReaderBench, 22)