aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/scala/async/internal/LiveVariables.scala
blob: db1501500a0894356193af0c2c8689aaf288b4c5 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
package scala.async.internal

trait LiveVariables {
  self: AsyncMacro =>
  import c.universe._
  import Flag._

  /**
   *  Returns for a given state a list of fields (as trees) that should be nulled out
   *  upon resuming that state (at the beginning of `resume`).
   *
   *  Within one state, a field is listed at most once (entries are de-duplicated by symbol).
   *
   *  @param   asyncStates the states of an `async` block
   *  @param   liftables   the lifted fields
   *  @return  a map mapping a state to the fields that should be nulled out
   *           upon resuming that state; states with nothing to null out are absent
   */
  def fieldsToNullOut(asyncStates: List[AsyncState], liftables: List[Tree]): Map[Int, List[Tree]] = {
    // live variables analysis:
    // the result map indicates in which states a given field should be nulled out
    val liveVarsMap: Map[Tree, Set[Int]] = liveVars(asyncStates, liftables)

    var assignsOf = Map[Int, List[Tree]]()

    for ((fld, where) <- liveVarsMap; state <- where)
      assignsOf get state match {
        case None =>
          assignsOf += (state -> List(fld))
        case Some(trees) if !trees.exists(_.symbol == fld.symbol) =>
          assignsOf += (state -> (fld +: trees))
        case _ =>
          /* do nothing: a tree for the same symbol is already scheduled for this state */
      }

    assignsOf
  }

  /**
   *  Live variables data-flow analysis.
   *
   *  The goal is to find, for each lifted field, the last state where the field is used.
   *  In all direct successor states which are not (indirect) predecessors of that last state
   *  (possible through loops), the corresponding field should be nulled out (at the beginning of
   *  `resume`).
   *
   *  Only lifted `var`s are analysed. Fields of primitive or `Nothing` type, fields referenced
   *  from a lifted definition that is not itself a lifted var, and fields captured by a closure,
   *  nested def/class/module, by-name argument or pattern-matching function are never nulled out.
   *
   *  @param   asyncStates the states of an `async` block
   *  @param   liftables   the lifted fields
   *  @return              a map which indicates for a given field (the key) the states in which it should be nulled out
   */
  def liveVars(asyncStates: List[AsyncState], liftables: List[Tree]): Map[Tree, Set[Int]] = {
    val liftedSyms: Set[Symbol] = // include only vars
      liftables.filter {
        case ValDef(mods, _, _, _) => mods.hasFlag(MUTABLE)
        case _ => false
      }.map(_.symbol).toSet

    // determine which fields should be live also at the end (will not be nulled out):
    // primitive/Nothing-typed fields, and fields referenced from some other lifted tree
    // that is not itself a lifted var
    val noNull: Set[Symbol] = liftedSyms.filter { sym =>
      val typeSym = tpe(sym).typeSymbol
      (typeSym.isClass && (typeSym.asClass.isPrimitive || typeSym == definitions.NothingClass)) || liftables.exists { tree =>
        !liftedSyms.contains(tree.symbol) && tree.exists(_.symbol == sym)
      }
    }
    AsyncUtils.vprintln(s"fields never zero-ed out: ${noNull.mkString(", ")}")

    /**
     *  Traverse statements of an `AsyncState`, collect `Ident`-s referring to lifted fields.
     *
     *  References occurring inside a nested class, module, method, function literal,
     *  pattern-matching function or by-name argument are reported as "captured";
     *  all other references are reported as "used".
     *
     *  @param  as  a state of an `async` expression
     *  @return     the lifted fields that are used, respectively captured, within state `as`
     */
    def fieldsUsedIn(as: AsyncState): ReferencedFields = {
      class FindUseTraverser extends AsyncTraverser {
        var usedFields = Set[Symbol]()
        var capturedFields = Set[Symbol]()
        // evaluates `body` with the `capturing` flag set, restoring the previous value afterwards
        private def capturing[A](body: => A): A = {
          val saved = capturing
          try {
            capturing = true
            body
          } finally capturing = saved
        }
        private def capturingCheck(tree: Tree) = capturing(tree foreach check)
        private var capturing: Boolean = false
        private def check(tree: Tree) {
          tree match {
            case Ident(_) if liftedSyms(tree.symbol) =>
              if (capturing)
                capturedFields += tree.symbol
              else
                usedFields += tree.symbol
            case _ =>
          }
        }
        override def traverse(tree: Tree) = {
          check(tree)
          super.traverse(tree)
        }

        override def nestedClass(classDef: ClassDef): Unit = capturingCheck(classDef)

        override def nestedModule(module: ModuleDef): Unit = capturingCheck(module)

        override def nestedMethod(defdef: DefDef): Unit = capturingCheck(defdef)

        override def byNameArgument(arg: Tree): Unit = capturingCheck(arg)

        override def function(function: Function): Unit = capturingCheck(function)

        override def patMatFunction(tree: Match): Unit = capturingCheck(tree)
      }

      val findUses = new FindUseTraverser
      findUses.traverse(Block(as.stats: _*))
      ReferencedFields(findUses.usedFields, findUses.capturedFields)
    }
    case class ReferencedFields(used: Set[Symbol], captured: Set[Symbol]) {
      override def toString = s"used: ${used.mkString(",")}\ncaptured: ${captured.mkString(",")}"
    }

    // Traverse every state's statements exactly once; the fixed-point iteration below
    // consults this cache instead of re-traversing a state each time it is revisited.
    val referencedIn: Map[Int, ReferencedFields] =
      asyncStates.map(as => as.state -> fieldsUsedIn(as)).toMap

    // Fields captured in ANY state. `lastUsagesOf` never schedules these for nulling out.
    // Computed over all states so the result does not depend on which states the
    // worklist happens to visit.
    val captured: Set[Symbol] = referencedIn.values.flatMap(_.captured).toSet

    /* Build the control-flow graph.
     *
     * A state `i` is contained in the list that is the value to which
     * key `j` maps iff control can flow from state `j` to state `i`.
     */
    val cfg: Map[Int, List[Int]] = asyncStates.map(as => (as.state -> as.nextStates)).toMap

    /** Tests if `state1` is a (direct or indirect) predecessor of `state2`.
     *  A state is not considered its own predecessor by the initial check.
     */
    def isPred(state1: Int, state2: Int): Boolean = {
      val seen = scala.collection.mutable.HashSet[Int]()

      def isPred0(state1: Int, state2: Int): Boolean =
        if (state1 == state2) false
        else if (seen(state1)) false  // breaks cycles in the CFG
        else cfg get state1 match {
          case Some(nextStates) =>
            seen += state1
            nextStates.contains(state2) || nextStates.exists(isPred0(_, state2))
          case None =>
            false
        }

      isPred0(state1, state2)
    }

    // The final state is the one that is not a predecessor of any other state.
    val finalState = asyncStates
      .find(as => !asyncStates.exists(other => isPred(as.state, other.state)))
      .getOrElse(throw new NoSuchElementException("live variables analysis: no final async state found"))

    for (as <- asyncStates)
      AsyncUtils.vprintln(s"fields used in state #${as.state}: ${referencedIn(as.state)}")

    /* Backwards data-flow analysis. Computes live variables information at entry and exit
     * of each async state.
     *
     * Compute using a simple fixed point (worklist) iteration:
     *
     * 1. currStates = all async states (LVexit of the final state is seeded with noNull)
     * 2. for each cs \in currStates, compute LVentry(cs) = LVexit(cs) \cup used(cs)
     * 3. record for which cs LVentry(cs) has changed
     * 4. obtain the predecessors pred of each cs whose LVentry changed
     * 5. for each p \in pred, compute LVexit(p) as union of the LVentry of its successors
     * 6. currStates = the p \in pred whose LVexit changed
     * 7. repeat while currStates is non-empty
     *
     * All sets are only ever grown by unions, so the iteration terminates.
     *
     * The worklist is seeded with every state, not only the final one. If LVentry of the
     * final state is empty, a worklist starting only at the final state would stop immediately.
     * Liveness would then never be computed for the other states, and no field would be nulled out.
     */

    var LVentry = Map[Int, Set[Symbol]]() withDefaultValue Set[Symbol]()
    var LVexit  = Map[Int, Set[Symbol]]() withDefaultValue Set[Symbol]()

    // All fields are declared to be dead at the exit of the final async state, except for the ones
    // that cannot be nulled out at all (those in noNull).
    LVexit = LVexit + (finalState.state -> noNull)

    var currStates: List[AsyncState] = asyncStates

    while (currStates.nonEmpty) {
      var entryChanged: List[AsyncState] = Nil

      for (cs <- currStates) {
        val LVentryOld = LVentry(cs.state)
        val LVentryNew = LVexit(cs.state) ++ referencedIn(cs.state).used
        // Set equality, not `sameElements`: the latter depends on iteration order and
        // can report a spurious change for two sets holding the same elements.
        if (LVentryNew != LVentryOld) {
          LVentry = LVentry + (cs.state -> LVentryNew)
          entryChanged ::= cs
        }
      }

      val pred = entryChanged.flatMap(cs => asyncStates.filter(_.nextStates.contains(cs.state)))
      var exitChanged: List[AsyncState] = Nil

      for (p <- pred) {
        val LVexitOld = LVexit(p.state)
        val LVexitNew = p.nextStates.flatMap(succ => LVentry(succ)).toSet
        if (LVexitNew != LVexitOld) {
          LVexit = LVexit + (p.state -> LVexitNew)
          exitChanged ::= p
        }
      }

      currStates = exitChanged
    }

    for (as <- asyncStates) {
      AsyncUtils.vprintln(s"LVentry at state #${as.state}: ${LVentry(as.state).mkString(", ")}")
      AsyncUtils.vprintln(s"LVexit  at state #${as.state}: ${LVexit(as.state).mkString(", ")}")
    }

    /* Walks the CFG backwards from `at` and returns the states closest to `at` at whose entry
     * `field` is live. Captured fields yield the empty set (they are never nulled out).
     */
    def lastUsagesOf(field: Tree, at: AsyncState): Set[Int] = {
      val avoid = scala.collection.mutable.HashSet[AsyncState]()

      def lastUsagesOf0(field: Tree, at: AsyncState): Set[Int] = {
        if (avoid(at)) Set()
        else if (captured(field.symbol)) {
          Set()
        }
        else LVentry get at.state match {
          case Some(fields) if fields.contains(field.symbol) =>
            Set(at.state)
          case _ =>
            avoid += at
            val preds = asyncStates.filter(_.nextStates.contains(at.state)).toSet
            preds.flatMap(p => lastUsagesOf0(field, p))
        }
      }

      lastUsagesOf0(field, at)
    }

    val lastUsages: Map[Tree, Set[Int]] =
      liftables.map(fld => (fld -> lastUsagesOf(fld, finalState))).toMap

    for ((fld, lastStates) <- lastUsages)
      AsyncUtils.vprintln(s"field ${fld.symbol.name} is last used in states ${lastStates.mkString(", ")}")

    val nullOutAt: Map[Tree, Set[Int]] =
      for ((fld, lastStates) <- lastUsages) yield {
        val killAt = lastStates.flatMap { s =>
          if (s == finalState.state) Set()
          else {
            val lastAsyncState = asyncStates.find(_.state == s).get
            val succNums       = lastAsyncState.nextStates
            // all successor states that are not indirect predecessors
            // filter out successor states where the field is live at the entry
            succNums.filter(num => !isPred(num, s)).filterNot(num => LVentry(num).contains(fld.symbol))
          }
        }
        (fld, killAt)
      }

    for ((fld, killAt) <- nullOutAt)
      AsyncUtils.vprintln(s"field ${fld.symbol.name} should be nulled out in states ${killAt.mkString(", ")}")

    nullOutAt
  }
}