aboutsummaryrefslogtreecommitdiff
path: root/compiler/src/dotty/tools/dotc/transform/CapturedVars.scala
blob: 368250cdfc5e2e3f6e8e7330573f13554dbf9d6d (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
package dotty.tools.dotc
package transform

import TreeTransforms._
import core.DenotTransformers._
import core.Symbols._
import core.Contexts._
import core.Types._
import core.Flags._
import core.Decorators._
import core.SymDenotations._
import core.StdNames.nme
import core.Names._
import core.NameOps._
import core.NameExtractors.TempResultName
import ast.Trees._
import SymUtils._
import collection.{ mutable, immutable }
import collection.mutable.{ LinkedHashMap, LinkedHashSet, TreeSet }

/** A mini-phase that boxes captured local variables.
 *
 *  A local `var` (a mutable, non-method term whose owner is a term) that is referenced
 *  from a method other than its own enclosing method is considered "captured". For each
 *  compilation unit the captured variables are collected up front; then:
 *
 *   - the variable's definition is re-typed to the matching `scala.runtime.{Volatile|}XRef`
 *     class, loses its `Mutable` flag (and `@volatile` annotation), and is initialized
 *     with `XRef.create(rhs)`, or `XRef.zero()` when it has no right-hand side;
 *   - every reference `x` becomes `x.elem`, cast back to the variable's original type if needed;
 *   - every assignment to the `elem` field of such a Ref goes through a temporary `val`.
 */
class CapturedVars extends MiniPhase with IdentityDenotTransformer { thisTransform =>
  import ast.tpd._

  /** The following two members implement abstract members inherited through `MiniPhase`. */
  val phaseName: String = "capturedVars"
  // Template instance with nothing captured; real work is done by the per-unit
  // instance returned from `prepareForUnit`.
  val treeTransform = new Transform(Set())

  /** Symbol tables for the `scala.runtime` Ref classes, resolved in the given context. */
  private class RefInfo(implicit ctx: Context) {
    /** The classes for which a Ref type exists. */
    val refClassKeys: collection.Set[Symbol] =
      defn.ScalaNumericValueClasses() + defn.BooleanClass + defn.ObjectClass

    /** Maps each key class `C` to `scala.runtime.CRef`. */
    val refClass: Map[Symbol, Symbol] =
      refClassKeys.map(rc => rc -> ctx.requiredClass(s"scala.runtime.${rc.name}Ref")).toMap

    /** Maps each key class `C` to `scala.runtime.VolatileCRef`. */
    val volatileRefClass: Map[Symbol, Symbol] =
      refClassKeys.map(rc => rc -> ctx.requiredClass(s"scala.runtime.Volatile${rc.name}Ref")).toMap

    /** All Ref classes, volatile and non-volatile. */
    val boxedRefClasses: collection.Set[Symbol] =
      refClassKeys.flatMap(k => Set(refClass(k), volatileRefClass(k)))
  }

  /** Tree transform for one compilation unit.
   *
   *  @param captured  the local variables that must be boxed; empty for the phase's
   *                   template instance, which only creates per-unit instances.
   */
  class Transform(captured: collection.Set[Symbol]) extends TreeTransform {
    def phase = thisTransform

    // Created on first use, because resolving the Ref classes needs a Context.
    private var myRefInfo: RefInfo = null
    private def refInfo(implicit ctx: Context) = {
      if (myRefInfo == null) myRefInfo = new RefInfo()
      myRefInfo
    }

    /** Collects the local `var`s that are referenced from a method other than the
     *  method that encloses their definition.
     */
    private class CollectCaptured(implicit ctx: Context) extends EnclosingMethodTraverser {
      private val captured = mutable.HashSet[Symbol]()
      def traverse(enclMeth: Symbol, tree: Tree)(implicit ctx: Context) = tree match {
        case id: Ident =>
          val sym = id.symbol
          if (sym.is(Mutable, butNot = Method) && sym.owner.isTerm && sym.enclosingMethod != enclMeth) {
            ctx.log(i"capturing $sym in ${sym.enclosingMethod}, referenced from $enclMeth")
            captured += sym
          }
        case _ =>
          foldOver(enclMeth, tree)
      }
      /** Traverses `tree` and returns every captured variable found in it. */
      def runOver(tree: Tree): collection.Set[Symbol] = {
        apply(NoSymbol, tree)
        captured
      }
    }

    /** Scans the whole unit for captured variables (at this phase) and returns a
     *  fresh transform specialized to them.
     */
    override def prepareForUnit(tree: Tree)(implicit ctx: Context) = {
      val captured = (new CollectCaptured)(ctx.withPhase(thisTransform))
        .runOver(ctx.compilationUnit.tpdTree)
      new Transform(captured)
    }

    /** The {Volatile|}{Int|Double|...|Object}Ref class corresponding to the class `cls`,
     *  depending on whether the reference should be @volatile.
     *  If `cls` is not a class, or has no dedicated Ref class, the ObjectRef variant is used.
     */
    def refClass(cls: Symbol, isVolatile: Boolean)(implicit ctx: Context): Symbol = {
      val refMap = if (isVolatile) refInfo.volatileRefClass else refInfo.refClass
      val objectRef = refMap(defn.ObjectClass)
      if (cls.isClass) refMap.getOrElse(cls, objectRef) else objectRef
    }

    /** For a captured variable, installs (after this phase) a denotation whose type is
     *  the matching Ref class, without the `Mutable` flag and without `@volatile`.
     */
    override def prepareForValDef(vdef: ValDef)(implicit ctx: Context) = {
      val sym = vdef.symbol
      if (captured contains sym) {
        val newd = sym.denot(ctx.withPhase(thisTransform)).copySymDenotation(
          info = refClass(sym.info.classSymbol, sym.hasAnnotation(defn.VolatileAnnot)).typeRef,
          initFlags = sym.flags &~ Mutable)
        newd.removeAnnotation(defn.VolatileAnnot)
        newd.installAfter(thisTransform)
      }
      this
    }

    /** Rewrites the definition of a captured variable so its right-hand side creates the
     *  Ref box: `XRef.zero()` when there is no initializer, `XRef.create(rhs)` otherwise.
     *  NOTE(review): relies on `vble.info` already being the Ref type installed by
     *  `prepareForValDef`, so that its companion provides `zero`/`create`.
     */
    override def transformValDef(vdef: ValDef)(implicit ctx: Context, info: TransformerInfo): Tree = {
      val vble = vdef.symbol
      if (captured contains vble) {
        def boxMethod(name: TermName): Tree =
          ref(vble.info.classSymbol.companionModule.info.member(name).symbol)
        cpy.ValDef(vdef)(
          rhs = vdef.rhs match {
            case EmptyTree => boxMethod(nme.zero).appliedToNone.withPos(vdef.pos)
            case arg => boxMethod(nme.create).appliedTo(arg)
          },
          tpt = TypeTree(vble.info).withPos(vdef.tpt.pos))
      } else vdef
    }

    /** Rewrites a reference `x` to a captured variable into `x.elem`. The result is cast,
     *  if necessary, to the variable's type as seen during this phase, i.e. before its
     *  denotation was replaced by the Ref type.
     */
    override def transformIdent(id: Ident)(implicit ctx: Context, info: TransformerInfo): Tree = {
      val vble = id.symbol
      if (captured contains vble)
        (id select nme.elem).ensureConforms(vble.denot(ctx.withPhase(thisTransform)).info)
      else id
    }

    /** If assignment is to a boxed ref type, e.g.
     *
     *      intRef.elem = expr
     *
     *  rewrite using a temporary var to
     *
     *      val ev$n = expr
     *      intRef.elem = ev$n
     *
     *  That way, we avoid the problem that `expr` might contain a `try` that would
     *  run on a non-empty stack (which is illegal under JVM rules). Note that LiftTry
     *  has already run before, so such `try`s would not be eliminated.
     *
     *  Also: If the ref type lhs is followed by a cast (can be an artifact of nested translation),
     *  drop the cast. A cast over anything other than an `elem` selection is left untouched
     *  rather than failing with a MatchError.
     */
    override def transformAssign(tree: Assign)(implicit ctx: Context, info: TransformerInfo): Tree = {
      def recur(lhs: Tree): Tree = lhs match {
        case TypeApply(Select(qual @ Select(_, nme.elem), nme.asInstanceOf_), _) =>
          // Strip the cast and retry on the underlying `elem` selection.
          recur(qual)
        case Select(_, nme.elem) if refInfo.boxedRefClasses.contains(lhs.symbol.maybeOwner) =>
          val tempDef = transformFollowing(SyntheticValDef(TempResultName.fresh(), tree.rhs))
          transformFollowing(Block(tempDef :: Nil, cpy.Assign(tree)(lhs, ref(tempDef.symbol))))
        case _ =>
          // Not an assignment to a boxed Ref's `elem`: keep the original tree.
          tree
      }
      recur(tree.lhs)
    }
  }
}