aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/scala/async/AsyncAnalysis.scala
blob: c160776b841c56a213044adbb390c6f6c1b53416 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
package scala.async

import scala.reflect.macros.Context
import collection.mutable

private[async] final case class AsyncAnalysis[C <: Context](override val c: C) extends TransformUtils(c) {
  import c.universe._

  /**
   * Checks an `async` block for `await` calls in unsupported positions and reports
   * each one as a compile error:
   * - under nested templates, functions and by-name arguments;
   * - inside `try` expressions.
   *
   * Must be called on the original tree, not on the ANF transformed tree.
   */
  def reportUnsupportedAwaits(tree: Tree): Unit = {
    val analyzer = new UnsupportedAwaitAnalyzer
    analyzer.traverse(tree)
  }

  /**
   * Determines which local `ValDef`-s of an `async` block need to be lifted to fields
   * of the state machine, based on whether or not they are accessed only from a
   * single state.
   *
   * Must be called on the ANF transformed tree.
   *
   * @param tree the ANF transformed body of the `async` block
   * @return the `ValDef`-s collected by the definition/use analysis
   */
  def valDefsUsedInSubsequentStates(tree: Tree): List[ValDef] = {
    val useAnalyzer = new AsyncDefinitionUseAnalyzer
    useAnalyzer traverse tree
    val lifted = useAnalyzer.valDefsToLift
    lifted.toList
  }

  /**
   * Traverser that emits a compile error for every `await` found in a position this
   * analysis treats as unsupported: nested classes/traits, nested objects, by-name
   * arguments, nested functions and `try` expressions.
   *
   * Local class and object definitions that contain no `await` are rejected as well.
   */
  private class UnsupportedAwaitAnalyzer extends super.AsyncTraverser {
    override def nestedClass(classDef: ClassDef): Unit = {
      val kind = classDef.symbol.asClass.isTrait match {
        case true  => "trait"
        case false => "class"
      }
      val containedAwait = reportUnsupportedAwait(classDef, s"nested $kind")
      // do not allow local class definitions, because of SI-5467 (specific to case classes, though)
      if (!containedAwait)
        c.error(classDef.pos, s"Local class ${classDef.name.decoded} illegal within `async` block")
    }

    override def nestedModule(module: ModuleDef): Unit = {
      val containedAwait = reportUnsupportedAwait(module, "nested object")
      // local object definitions lead to spurious type errors (because of resetAllAttrs?)
      if (!containedAwait)
        c.error(module.pos, s"Local object ${module.name.decoded} illegal within `async` block")
    }

    override def byNameArgument(arg: Tree): Unit = {
      reportUnsupportedAwait(arg, "by-name argument")
    }

    override def function(function: Function): Unit = {
      reportUnsupportedAwait(function, "nested function")
    }

    // A `try` that contains an `await` is reported as a whole and not descended into;
    // every other tree is handed to the inherited traversal.
    override def traverse(tree: Tree): Unit = tree match {
      case Try(_, _, _) if tree exists isAwait =>
        reportUnsupportedAwait(tree, "try/catch")
      case _ =>
        super.traverse(tree)
    }

    /**
     * Reports an error at the position of every `await` reference found inside `tree`.
     *
     * @param tree           the sub-tree in which `await` is not supported
     * @param whyUnsupported description of the enclosing construct, used in the error message
     * @return true, if the tree contained an unsupported await.
     */
    private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String): Boolean = {
      val offendingAwaits: List[RefTree] = tree.collect {
        case ref: RefTree if isAwait(ref) => ref
      }
      for (offender <- offendingAwaits)
        c.error(offender.pos, s"await must not be used under a $whyUnsupported.")
      offendingAwaits.nonEmpty
    }
  }

  /**
   * Traverser that decides which local `ValDef`s must be lifted to fields of the state machine.
   *
   * The traversal is divided into "chunks": the chunk counter is advanced after every
   * application for which `isAwait` holds, and after each sub-tree of an `If`/`Match` that
   * contains such an application. A `ValDef` ends up in `valDefsToLift` when
   * - its right-hand side is an `await`,
   * - the result of an `await` is assigned to it, or
   * - it is referenced from a chunk other than the one in which it was recorded.
   */
  private class AsyncDefinitionUseAnalyzer extends super.AsyncTraverser {
    // Identifier of the chunk currently being traversed.
    private var chunkId = 0

    private def nextChunk(): Unit = chunkId += 1

    // Symbol of every `ValDef` seen so far -> (the definition, the chunk it was recorded in).
    private var valDefChunkId = Map[Symbol, (ValDef, Int)]()

    val valDefsToLift: mutable.Set[ValDef] = mutable.Set[ValDef]()

    override def traverse(tree: Tree): Unit = {
      tree match {
        case If(cond, thenp, elsep) if tree exists isAwait =>
          traverseChunks(List(cond, thenp, elsep))
        case Match(selector, cases) if tree exists isAwait =>
          traverseChunks(selector :: cases)
        case Apply(fun, args) if isAwait(fun) =>
          super.traverse(tree)
          nextChunk()
        case vd: ValDef =>
          // Traverse the right-hand side first, so that the definition is recorded with the
          // chunk id that is current after any `await` in its initializer.
          super.traverse(tree)
          valDefChunkId += (vd.symbol -> ((vd, chunkId)))
          if (isAwait(vd.rhs)) valDefsToLift += vd
        case as: Assign =>
          if (isAwait(as.rhs)) {
            // An `Assign` node carries no symbol of its own; the symbol of the assigned
            // variable is found on its left-hand side.
            val assignee = as.lhs.symbol
            assert(assignee != null, "internal error: null symbol for Assign tree:" + as)

            // TODO test the orElse case, try to remove the restriction.
            val (vd, _) = valDefChunkId.getOrElse(assignee, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block. " + assignee))
            valDefsToLift += vd
          }
          super.traverse(tree)
        case rt: RefTree =>
          valDefChunkId.get(rt.symbol) match {
            case Some((vd, defChunkId)) if defChunkId != chunkId =>
              // Referenced from a different chunk than the one it was recorded in.
              valDefsToLift += vd
            case _ =>
          }
          super.traverse(tree)
        case _ => super.traverse(tree)
      }
    }

    /** Traverses each tree in order, starting a new chunk after each one. */
    private def traverseChunks(trees: List[Tree]): Unit = {
      trees.foreach {
        t => traverse(t); nextChunk()
      }
    }
  }
}