aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/scala/async/internal/Lifter.scala
blob: ff9057688f8d4659aded7eb21d74d722b2b969c3 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
package scala.async.internal

trait Lifter {
  self: AsyncMacro =>
  import c.universe._
  import Flag._
  import c.internal._
  import decorators._

  /**
   * Identify which DefTrees are used (including transitively) which are declared
   * in some state but used (including transitively) in another state.
   *
   * These will need to be lifted to class members of the state machine.
   *
   * The result lists the lifted definitions in the order they are first found
   * while walking `asyncStates`. `Tree` and `Symbol` use identity-based hashing,
   * so iterating a hash `Set`/`Map` of them would give an arbitrary order (and,
   * with it, arbitrary fresh-name numbering); walking an ordered list instead
   * keeps the generated members stable across compilations.
   *
   * @param asyncStates the states whose statements are scanned for definitions
   * @return copies of the definitions that must be lifted, with their symbols
   *         renamed to fresh names; lifted vals/vars are additionally flagged
   *         `MUTABLE | STABLE | PRIVATE | LOCAL` and lose their RHS (unless lazy),
   *         lifted defs are flagged `PRIVATE | LOCAL`
   */
  def liftables(asyncStates: List[AsyncState]): List[Tree] = {
    // Tracks local class/object companion pairs so both halves of a pair are
    // renamed consistently when lifted.
    object companionship {
      // Symmetric: each recorded companion maps to the other one.
      private val companions = collection.mutable.Map[Symbol, Symbol]()

      private def record(sym1: Symbol, sym2: Symbol): Unit = {
        companions(sym1) = sym2
        companions(sym2) = sym1
      }

      /** Record every `class X` / `object X` pair found among `defs`. */
      def record(defs: List[Tree]): Unit = {
        for {
          cd@ClassDef(_, _, _, _) <- defs
          md@ModuleDef(_, _, _) <- defs
          if cd.name.toTermName == md.name
        } record(cd.symbol, md.symbol)
      }

      /** The recorded companion of `sym`, or `NoSymbol` if none was recorded. */
      def companionOf(sym: Symbol): Symbol = companions.getOrElse(sym, NoSymbol)
    }

    // Every DefTree directly enclosed in some state, paired with that state's id,
    // in discovery order. This list (not a hash collection) fixes the output order.
    val orderedDefs: List[(Tree, Int)] = {
      /** Collect the DefTrees directly enclosed within `t` that have the same owner */
      def collectDirectlyEnclosedDefs(t: Tree): List[DefTree] = t match {
        case _: LabelDef => Nil
        case dt: DefTree => dt :: Nil
        case _: Function => Nil
        case t           =>
          val childDefs = t.children.flatMap(collectDirectlyEnclosedDefs(_))
          companionship.record(childDefs)
          childDefs
      }
      asyncStates.flatMap { asyncState =>
        collectDirectlyEnclosedDefs(Block(asyncState.allStats: _*)).map((_, asyncState.state))
      }
    }

    val defs: Map[Tree, Int] = orderedDefs.toMap

    // In which block are these symbols defined?
    val symToDefiningState: Map[Symbol, Int] = orderedDefs.map {
      case (tree, state) => (tree.symbol, state)
    }.toMap

    // The definitions trees
    val symToTree: Map[Symbol, Tree] = orderedDefs.map {
      case (tree, _) => (tree.symbol, tree)
    }.toMap

    // The direct references of each definition tree
    val defSymToReferenced: Map[Symbol, List[Symbol]] = defs.keys.map { tree =>
      (tree.symbol, tree.collect {
        case rt: RefTree if symToDefiningState.contains(rt.symbol) => rt.symbol
      })
    }.toMap

    // The direct references of each block, excluding references of `DefTree`-s which
    // are already accounted for.
    val stateIdToDirectlyReferenced: Map[Int, List[Symbol]] = {
      val refs: List[(Int, Symbol)] = asyncStates.flatMap(
        asyncState => asyncState.stats.filterNot(t => t.isDef && !isLabel(t.symbol)).flatMap(_.collect {
          case rt: RefTree
            if symToDefiningState.contains(rt.symbol) => (asyncState.state, rt.symbol)
        })
      )
      toMultiMap(refs)
    }

    val liftableSyms: Set[Symbol] = {
      val liftableMutableSet = collection.mutable.Set[Symbol]()
      def markForLift(sym: Symbol): Unit = {
        if (!liftableMutableSet(sym)) {
          liftableMutableSet += sym

          // Only mark transitive references of defs, modules and classes. The RHS of lifted vals/vars
          // stays in its original location, so things that it refers to need not be lifted.
          if (!(sym.isTerm && !sym.asTerm.isLazy && (sym.asTerm.isVal || sym.asTerm.isVar)))
            defSymToReferenced(sym).foreach(sym2 => markForLift(sym2))
        }
      }
      // Start things with DefTrees directly referenced from statements from other states...
      val liftableStatementRefs: List[Symbol] = stateIdToDirectlyReferenced.toList.flatMap {
        case (i, syms) => syms.filter(sym => symToDefiningState(sym) != i)
      }
      // .. and likewise for DefTrees directly referenced by other DefTrees from other states
      val liftableRefsOfDefTrees = defSymToReferenced.toList.flatMap {
        case (referee, referents) => referents.filter(sym => symToDefiningState(sym) != symToDefiningState(referee))
      }
      // Mark these for lifting, which will follow transitive references.
      (liftableStatementRefs ++ liftableRefsOfDefTrees).foreach(markForLift)
      liftableMutableSet.toSet
    }

    // Walk definitions in discovery order rather than `liftableSyms`' hash order, so that
    // both the member order and the fresh names handed out below are deterministic.
    val treesToLift: List[Tree] =
      orderedDefs.map(_._1.symbol).distinct.filter(liftableSyms.contains).map(symToTree)

    treesToLift.map { t =>
      val sym = t.symbol
      val treeLifted = t match {
        case vd@ValDef(_, _, _, rhs)                      =>
          sym.setFlag(MUTABLE | STABLE | PRIVATE | LOCAL)
          sym.setName(name.fresh(sym.name.toTermName))
          sym.setInfo(deconst(sym.info))
          // Non-lazy vals keep their RHS at the original site; only the field is lifted.
          val rhs1 = if (sym.asTerm.isLazy) rhs else EmptyTree
          treeCopy.ValDef(vd, Modifiers(sym.flags), sym.name, TypeTree(tpe(sym)).setPos(t.pos), rhs1)
        case dd@DefDef(_, _, tparams, vparamss, tpt, rhs) =>
          sym.setName(name.fresh(sym.name.toTermName))
          sym.setFlag(PRIVATE | LOCAL)
          // Was `DefDef(sym, rhs)`, but this ran afoul of `ToughTypeSpec.nestedMethodWithInconsistencyTreeAndInfoParamSymbols`
          // due to the handling of type parameter skolems in `thisMethodType` in `Namers`
          treeCopy.DefDef(dd, Modifiers(sym.flags), sym.name, tparams, vparamss, tpt, rhs)
        case cd@ClassDef(_, _, tparams, impl)             =>
          sym.setName(newTypeName(name.fresh(sym.name.toString).toString))
          // A companion object takes the class's new name so the pair stays linked.
          companionship.companionOf(sym) match {
            case NoSymbol     =>
            case moduleSymbol =>
              moduleSymbol.setName(sym.name.toTermName)
              moduleSymbol.asModule.moduleClass.setName(moduleSymbol.name.toTypeName)
          }
          treeCopy.ClassDef(cd, Modifiers(sym.flags), sym.name, tparams, impl)
        case md@ModuleDef(_, _, impl)                     =>
          companionship.companionOf(sym) match {
            case NoSymbol =>
              sym.setName(name.fresh(sym.name.toTermName))
              sym.asModule.moduleClass.setName(sym.name.toTypeName)
            case _        => // will be renamed by `case ClassDef` above.
          }
          treeCopy.ModuleDef(md, Modifiers(sym.flags), sym.name, impl)
        case td@TypeDef(_, _, tparams, rhs)               =>
          sym.setName(newTypeName(name.fresh(sym.name.toString).toString))
          treeCopy.TypeDef(td, Modifiers(sym.flags), sym.name, tparams, rhs)
      }
      atPos(t.pos)(treeLifted)
    }
  }
}