aboutsummaryrefslogtreecommitdiff
path: root/src/dotty/tools/dotc/core/Phases.scala
blob: 7ae7b6ad3bcadbf6986eca19f98869a14dc90641 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
package dotty.tools.dotc
package core

import Periods._
import Contexts._
import util.DotClass
import DenotTransformers._
import Denotations._
import config.Printers._
import scala.collection.mutable.{ListBuffer, ArrayBuffer}
import dotty.tools.dotc.transform.TreeTransforms.{TreeTransformer, TreeTransform}
import dotty.tools.dotc.transform.PostTyperTransformers.PostTyperTransformer
import dotty.tools.dotc.transform.TreeTransforms
import TreeTransforms.Separator

trait Phases {
  self: Context =>

  import Phases._

  /** The phase this context is in, selected by the phase id of its period. */
  def phase: Phase = base.phases(period.phaseId)

  /** This context's phase followed by the phases of its enclosing contexts.
   *  Enclosing contexts that are in the same phase as this one are skipped.
   *  Empty for `NoContext`, or when the current phase does not exist.
   */
  def phasesStack: List[Phase] =
    if (this eq NoContext) Nil
    else {
      val current = phase
      if (!current.exists) Nil
      else {
        // The nearest enclosing context that is in a different phase.
        val outer = outersIterator.dropWhile(_.phase == current).next
        current :: outer.phasesStack
      }
    }

  /** Execute `op` at given phase */
  def atPhase[T](phase: Phase)(op: Context => T): T = {
    val targetId = phase.id
    atPhase(targetId)(op)
  }

  /** Execute `op` at the phase that follows the current one. */
  def atNextPhase[T](op: Context => T): T = {
    val successor = phase.next
    atPhase(successor)(op)
  }

  /** Execute `op` in this context, unless `limit` exists and the current
   *  phase is not `<= limit`; in that case execute it at `limit`.
   *  A non-existent `limit` imposes no bound.
   */
  def atPhaseNotLaterThan[T](limit: Phase)(op: Context => T): T =
    if (limit.exists && !(phase <= limit)) atPhase(limit)(op)
    else op(this)

  /** `atPhaseNotLaterThan` with the typer phase as the limit. */
  def atPhaseNotLaterThanTyper[T](op: Context => T): T = {
    val typer = base.typerPhase
    atPhaseNotLaterThan(typer)(op)
  }
}

object Phases {

  /** Phase bookkeeping mixed into `ContextBase`. It installs the phase plan
   *  (`usePhases`), builds the denotation-transformer lookup tables, and
   *  caches lookups of well-known phases by name.
   */
  trait PhasesBase {
    this: ContextBase =>

    /** `squashedPhases` without the leading `NoPhase`.
     *  NOTE(review): the trailing terminal sentinel is still included.
     */
    def allPhases = squashedPhases.tail

    /** Placeholder stored at index 0 of the phase arrays; it does not exist. */
    object NoPhase extends Phase {
      override def exists = false
      def name = "<no phase>"
      def run(implicit ctx: Context): Unit = unsupported("run")
      def transform(ref: SingleDenotation)(implicit ctx: Context): SingleDenotation = unsupported("transform")
    }

    /** A named placeholder phase; running it is unsupported. */
    object SomePhase extends Phase {
      def name = "<some phase>"
      def run(implicit ctx: Context): Unit = unsupported("run")
    }

    /** A sentinel transformer appended after the last real phase. */
    class TerminalPhase extends DenotTransformer {
      def name = "terminal"
      def run(implicit ctx: Context): Unit = unsupported("run")
      def transform(ref: SingleDenotation)(implicit ctx: Context): SingleDenotation =
        unsupported("transform")
      override def lastPhaseId(implicit ctx: Context) = id
    }

    /** Use the following phases in the order they are given.
     *  The list should never contain NoPhase.
     *  If squashing is enabled, phases in the same subgroup are squashed into a single phase.
     *
     *  @param phasess  the phase plan as a list of subgroups. The flattened list
     *                  fixes the phase ids: `NoPhase` is id 0, the first real
     *                  phase is id 1, and a terminal sentinel comes last.
     *  @param squash   whether to fuse every subgroup with more than one phase
     *                  into a single tree transformer in `squashedPhases`.
     *                  Such a subgroup must contain only `TreeTransform`s.
     *                  Empty subgroups contribute no phase in either mode.
     */
    def usePhases(phasess: List[List[Phase]], squash: Boolean = false) = {
      // One terminal sentinel, shared by `phases` and `squashedPhases`, so that
      // both arrays end in an initialised phase (valid id and base).
      val terminal = new TerminalPhase
      phases = (NoPhase :: phasess.flatten ::: terminal :: Nil).toArray
      nextDenotTransformerId = new Array[Int](phases.length)
      denotTransformers = new Array[DenotTransformer](phases.length)

      // Assign ids in array order. `init` reads `prev`, so each phase's
      // predecessor must already be initialised when it is initialised.
      for (id <- phases.indices) phases(id).init(this, id)

      // Walk backwards, recording for every index the id of the nearest
      // DenotTransformer at or after it. The terminal sentinel is a
      // DenotTransformer, so one always exists.
      var lastTransformerId = phases.length
      var i = phases.length
      while (i > 0) {
        i -= 1
        phases(i) match {
          case transformer: DenotTransformer =>
            lastTransformerId = i
            denotTransformers(i) = transformer
          case _ =>
        }
        nextDenotTransformerId(i) = lastTransformerId
      }

      if (squash) {
        val squashed = ListBuffer[Phase]()
        var postTyperEmitted = false
        for (group <- phasess if group.nonEmpty) {
          if (group.length > 1) {
            assert(group.forall(_.isInstanceOf[TreeTransform]), "Only tree transforms can be squashed")

            val transforms = group.asInstanceOf[List[TreeTransform]]
            // The first fused group becomes a PostTyperTransformer; later
            // ones are plain TreeTransformers.
            val block =
              if (!postTyperEmitted) {
                postTyperEmitted = true
                new PostTyperTransformer {
                  override def name: String = transformations.map(_.name).mkString("TreeTransform:{", ", ", "}")
                  override protected def transformations: Array[TreeTransform] = transforms.toArray
                }
              } else new TreeTransformer {
                override def name: String = transformations.map(_.name).mkString("TreeTransform:{", ", ", "}")
                override protected def transformations: Array[TreeTransform] = transforms.toArray
              }
            squashed += block
            // The fused block takes over the id of the first phase in its group.
            block.init(this, group.head.id)
          } else squashed += group.head
        }
        this.squashedPhases = (NoPhase :: squashed.toList ::: terminal :: Nil).toArray
      } else {
        this.squashedPhases = this.phases
      }

      config.println(s"Phases = ${phases.deep}")
      config.println(s"squashedPhases = ${squashedPhases.deep}")
      config.println(s"nextDenotTransformerId = ${nextDenotTransformerId.deep}")
    }

    /** The first phase in `phases` with the given name, or `NoPhase` if there is none. */
    def phaseNamed(name: String) = phases.find(_.name == name).getOrElse(NoPhase)

    /** A cache to compute the phase with given name, which
     *  stores the phase as soon as phaseNamed returns something
     *  different from NoPhase.
     */
    private class PhaseCache(name: String) {
      private var myPhase: Phase = NoPhase
      def phase = {
        if (myPhase eq NoPhase) myPhase = phaseNamed(name)
        myPhase
      }
    }

    private val typerCache = new PhaseCache(typerName)
    private val refChecksCache = new PhaseCache(refChecksName)
    private val erasureCache = new PhaseCache(erasureName)
    private val flattenCache = new PhaseCache(flattenName)

    def typerPhase = typerCache.phase
    def refchecksPhase = refChecksCache.phase
    def erasurePhase = erasureCache.phase
    def flattenPhase = flattenCache.phase
  }

  final val typerName = "typer"
  final val refChecksName = "refchecks"
  final val erasureName = "erasure"
  final val flattenName = "flatten"

  /** Base class of compiler phases. `init` binds a phase to a `ContextBase`
   *  and to its position in that base's `phases` array.
   */
  abstract class Phase extends DotClass {

    /** The phase name; `phaseNamed` looks phases up by it. */
    def name: String

    /** Runs this phase in the given context. */
    def run(implicit ctx: Context): Unit

    /** Runs this phase once per unit, each time in a fresh context set to this phase and that unit. */
    def runOn(units: List[CompilationUnit])(implicit ctx: Context): Unit =
      for (unit <- units) run(ctx.fresh.withPhase(this).withCompilationUnit(unit))

    /** A description of the phase; defaults to its name. */
    def description: String = name

    def checkable: Boolean = true

    /** Whether this is a real phase; `NoPhase` overrides this to `false`. */
    def exists: Boolean = true

    private var myId: PhaseId = -1
    private var myBase: ContextBase = null
    private var myErasedTypes = false
    private var myFlatClasses = false
    private var myRefChecked = false

    /** The sequence position of this phase in the given context where 0
     * is reserved for NoPhase and the first real phase is at position 1.
     * -1 if the phase is not installed in the context.
     */
    def id = myId

    /** True for every phase after erasure. */
    final def erasedTypes = myErasedTypes
    /** True for every phase after flatten. */
    final def flatClasses = myFlatClasses
    /** True for every phase after refchecks. */
    final def refChecked = myRefChecked

    /** Installs this phase in `base` at position `id`. The after-erasure,
     *  after-flatten and after-refchecks flags are derived from `prev`, so
     *  the predecessor must already be initialised. A phase installed at a
     *  real position (`id >= FirstPhaseId`) may only be installed once.
     */
    protected[Phases] def init(base: ContextBase, id: Int): Unit = {
      if (id >= FirstPhaseId)
        assert(myId == -1, s"phase $this has already been used once; cannot be reused")
      myBase = base
      myId = id
      myErasedTypes = prev.name == erasureName   || prev.erasedTypes
      myFlatClasses = prev.name == flattenName   || prev.flatClasses
      myRefChecked  = prev.name == refChecksName || prev.refChecked
    }

    /** True if this phase exists and its id is not greater than that of `that`. */
    final def <=(that: Phase)(implicit ctx: Context) =
      exists && id <= that.id

    /** The preceding phase, or `NoPhase` for the first real phase and below. */
    final def prev: Phase =
      if (id > FirstPhaseId) myBase.phases(id - 1) else myBase.NoPhase

    /** The following phase, or `NoPhase` if there is none. */
    final def next: Phase =
      if (hasNext) myBase.phases(id + 1) else myBase.NoPhase

    /** Whether this is an installed real phase with a successor in `phases`. */
    final def hasNext = id >= FirstPhaseId && id + 1 < myBase.phases.length

    /** This phase and its successors, stopping before the first phase without
     *  a successor (so the last entry of `phases` is never yielded).
     */
    final def iterator =
      Iterator.iterate(this)(_.next) takeWhile (_.hasNext)

    override def toString = name
  }
}