aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/scala/async/AsyncAnalysis.scala
blob: 9184960bc251485796dd60fdbb4cf50439c01eda (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
/*
 * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
 */

package scala.async

import scala.reflect.macros.Context
import scala.collection.mutable

/**
 * Static analyses performed by the `async` macro over the body of an `async` block.
 *
 * @param c         the macro context; all analysed trees belong to `c.universe`
 * @param asyncBase the `async` configuration. Only `fallbackEnabled` is consulted here: it decides
 *                  whether unsupported `await`s are reported through `c.error`.
 */
private[async] final case class AsyncAnalysis[C <: Context](c: C, asyncBase: AsyncBase) {
  import c.universe._

  // Tree helpers (`isAwait`, `name.bindSuffix`, `AsyncTraverser`, ...) defined in TransformUtils.
  val utils = TransformUtils[c.type](c)

  import utils._

  /**
   * Analyze the contents of an `async` block in order to:
   * - Report unsupported `await` calls under nested templates, functions, by-name arguments.
   *
   * Must be called on the original tree, not on the ANF transformed tree.
   *
   * Independently of `await`, local case classes and local objects are reported with `c.error`.
   * `return` and `lazy val` abort the expansion with `c.abort`.
   *
   * @return true if at least one unsupported `await` was found. The corresponding errors are only
   *         emitted when `asyncBase.fallbackEnabled` is false.
   */
  def reportUnsupportedAwaits(tree: Tree): Boolean = {
    val analyzer = new UnsupportedAwaitAnalyzer
    analyzer.traverse(tree)
    analyzer.hasUnsupportedAwaits
  }

  /**
   * Analyze the contents of an `async` block in order to:
   * - Find which local `ValDef`-s need to be lifted to fields of the state machine, based
   * on whether or not they are accessed only from a single state.
   *
   * Must be called on the ANF transformed tree.
   *
   * @return the vals to lift followed by the nested methods to lift, each group in the order in
   *         which the traversal first marked them (deterministic across compiler runs).
   */
  def defTreesUsedInSubsequentStates(tree: Tree): List[DefTree] = {
    val analyzer = new AsyncDefinitionUseAnalyzer
    analyzer.traverse(tree)
    // The two sets hold disjoint kinds of trees and are duplicate-free themselves,
    // so a plain concatenation needs no `distinct`.
    val valDefs: List[DefTree] = analyzer.valDefsToLift.toList
    val methods: List[DefTree] = analyzer.nestedMethodsToLift.toList
    valDefs ::: methods
  }

  /**
   * Traverser that reports `await` calls in positions the state machine transformation
   * does not support. It also rejects constructs that are illegal anywhere in an `async` block.
   * The `nested*` / `byNameArgument` / `function` / `patMatFunction` hooks are inherited
   * from `AsyncTraverser` (defined elsewhere).
   */
  private class UnsupportedAwaitAnalyzer extends AsyncTraverser {
    /** Set by `reportError` as soon as one unsupported `await` has been found. */
    var hasUnsupportedAwaits = false

    override def nestedClass(classDef: ClassDef) {
      val classSym = classDef.symbol.asClass
      val kind = if (classSym.isTrait) "trait" else "class"
      if (!reportUnsupportedAwait(classDef, s"nested $kind")) {
        // do not allow local class definitions, because of SI-5467 (specific to case classes, though)
        if (classSym.isCaseClass)
          c.error(classDef.pos, s"Local case class ${classDef.name.decoded} illegal within `async` block")
      }
    }

    override def nestedModule(module: ModuleDef) {
      if (!reportUnsupportedAwait(module, "nested object")) {
        // local object definitions lead to spurious type errors (because of resetAllAttrs?)
        c.error(module.pos, s"Local object ${module.name.decoded} illegal within `async` block")
      }
    }

    override def nestedMethod(defDef: DefDef) {
      reportUnsupportedAwait(defDef, "nested method")
    }

    override def byNameArgument(arg: Tree) {
      reportUnsupportedAwait(arg, "by-name argument")
    }

    override def function(function: Function) {
      reportUnsupportedAwait(function, "nested function")
    }

    override def patMatFunction(tree: Match) {
      reportUnsupportedAwait(tree, "nested function")
    }

    override def traverse(tree: Tree) {
      def containsAwait = tree exists isAwait
      tree match {
        case Try(_, _, _) if containsAwait                    =>
          reportUnsupportedAwait(tree, "try/catch")
          super.traverse(tree)
        case Return(_)                                        =>
          c.abort(tree.pos, "return is illegal within a async block")
        case ValDef(mods, _, _, _) if mods.hasFlag(Flag.LAZY) =>
          c.abort(tree.pos, "lazy vals are illegal within an async block")
        case _                                                =>
          super.traverse(tree)
      }
    }

    /**
     * Reports every `await` reference found anywhere inside `tree`.
     *
     * @param whyUnsupported description of the enclosing construct, used in the error message
     * @return true, if the tree contained an unsupported await.
     */
    private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String): Boolean = {
      val badAwaits: List[RefTree] = tree collect {
        case rt: RefTree if isAwait(rt) => rt
      }
      badAwaits foreach { awaitRef =>
        reportError(awaitRef.pos, s"await must not be used under a $whyUnsupported.")
      }
      badAwaits.nonEmpty
    }

    /** Records the failure; only emits a compiler error when fallback is disabled. */
    private def reportError(pos: Position, msg: String) {
      hasUnsupportedAwaits = true
      if (!asyncBase.fallbackEnabled)
        c.error(pos, msg)
    }
  }

  /**
   * Traverser that decides which local definitions must be lifted into fields.
   *
   * The tree is split into numbered "chunks". The counter advances after every `await` call,
   * and after each selector/branch of an `If`, `Match` or `LabelDef` that contains an `await`.
   *
   * A val is lifted when any of the following holds:
   * - it is referenced from a chunk other than the one it was defined in;
   * - its right-hand side is an `await`, or it is later assigned an `await`;
   * - its name marks it as a pattern binder;
   * - it is referenced from a nested method, function, or pattern-matching function.
   *
   * All nested methods are lifted.
   */
  private class AsyncDefinitionUseAnalyzer extends AsyncTraverser {
    // Identifier of the chunk currently being traversed.
    private var chunkId = 0

    private def nextChunk() {
      chunkId += 1
    }

    // For each local val seen so far: its definition and the chunk it was defined in.
    private var valDefChunkId = Map[Symbol, (ValDef, Int)]()

    // Insertion-ordered: trees hash by identity, so a plain hash set would make the order of
    // lifted definitions (and thus the generated state machine) vary between compiler runs.
    val valDefsToLift      : mutable.Set[ValDef] = mutable.LinkedHashSet[ValDef]()
    val nestedMethodsToLift: mutable.Set[DefDef] = mutable.LinkedHashSet[DefDef]()

    override def nestedMethod(defDef: DefDef) {
      nestedMethodsToLift += defDef
      markReferencedVals(defDef)
    }

    override def function(function: Function) {
      markReferencedVals(function)
    }

    override def patMatFunction(tree: Match) {
      markReferencedVals(tree)
    }

    /** Marks for lifting every already-seen local val that `tree` refers to. */
    private def markReferencedVals(tree: Tree) {
      tree foreach {
        case rt: RefTree =>
          valDefChunkId.get(rt.symbol) match {
            case Some((vd, _)) =>
              valDefsToLift += vd // lift all vals referred to by nested functions.
            case None          =>
          }
        case _           =>
      }
    }

    override def traverse(tree: Tree) {
      tree match {
        case If(cond, thenp, elsep) if tree exists isAwait     =>
          traverseChunks(List(cond, thenp, elsep))
        case Match(selector, cases) if tree exists isAwait     =>
          traverseChunks(selector :: cases)
        case LabelDef(_, _, rhs) if rhs exists isAwait         =>
          traverseChunks(rhs :: Nil)
        case Apply(fun, _) if isAwait(fun)                     =>
          super.traverse(tree)
          nextChunk()
        case vd: ValDef                                        =>
          super.traverse(tree)
          // Registered after traversing the rhs, so the val's own rhs is not counted as a use.
          valDefChunkId = valDefChunkId.updated(vd.symbol, (vd, chunkId))
          val isPatternBinder = vd.name.toString.contains(name.bindSuffix)
          if (isAwait(vd.rhs) || isPatternBinder) valDefsToLift += vd
        case as: Assign                                        =>
          if (isAwait(as.rhs)) {
            assert(as.lhs.symbol != null, "internal error: null symbol for Assign tree:" + as + " " + as.lhs.symbol)

            // TODO test the orElse case, try to remove the restriction.
            val (vd, _) = valDefChunkId.getOrElse(as.lhs.symbol, c.abort(as.pos, s"await may only be assigned to a var/val defined in the async block. ${as.lhs} ${as.lhs.symbol}"))
            valDefsToLift += vd
          }
          super.traverse(tree)
        case rt: RefTree                                       =>
          // A use in a different chunk than the definition crosses a state boundary.
          valDefChunkId.get(rt.symbol) match {
            case Some((vd, defChunkId)) if defChunkId != chunkId =>
              valDefsToLift += vd
            case _                                               =>
          }
          super.traverse(tree)
        case _                                                 => super.traverse(tree)
      }
    }

    /** Traverses each tree in its own chunk, advancing the chunk counter after each one. */
    private def traverseChunks(trees: List[Tree]) {
      trees.foreach { t =>
        traverse(t)
        nextChunk()
      }
    }
  }

}