aboutsummaryrefslogtreecommitdiff
path: root/compiler/src/dotty/tools/dotc/transform/MacroTransform.scala
blob: 9634decaa166255cac9d34288e6e348bc7ddeee3 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
package dotty.tools.dotc
package transform

import core._
import typer._
import Phases._
import ast.Trees._
import Contexts._
import Symbols._
import Flags.PackageVal
import Decorators._

/** A base class for transforms.
 *  A transform contains a compiler phase which applies a tree transformer:
 *  `run` creates a fresh [[Transformer]] via `newTransformer` and applies it to
 *  the typed tree of the current compilation unit, replacing that tree with the
 *  result.
 */
abstract class MacroTransform extends Phase {

  import ast.tpd._

  /** Replaces `ctx.compilationUnit.tpdTree` with its transformed version.
   *  The traversal runs under a context whose phase is `transformPhase`.
   */
  override def run(implicit ctx: Context): Unit = {
    val unit = ctx.compilationUnit
    unit.tpdTree = newTransformer.transform(unit.tpdTree)(ctx.withPhase(transformPhase))
  }

  /** Creates the tree transformer that `run` applies; a new one is requested on every `run`. */
  protected def newTransformer(implicit ctx: Context): Transformer

  /** The phase in which the transformation should be run.
   *  By default this is the phase of this macro transform itself,
   *  but it could be overridden to be the phase following that one.
   */
  protected def transformPhase(implicit ctx: Context): Phase = this

  /** A `TreeMap` that adjusts the context as it descends.
   *  - Definitions are transformed in a context owned by their own symbol.
   *  - Template statements that are neither imports nor definitions are
   *    transformed in an expression context owned by the template's symbol.
   */
  class Transformer extends TreeMap {

    /** A fresh context for transforming the children of the definition `tree`.
     *  The context's tree is `tree`. Its owner is `tree.symbol`, except when that
     *  symbol is a package val, in which case the owner is the package's module class.
     */
    protected def localCtx(tree: Tree)(implicit ctx: Context) = {
      val sym = tree.symbol
      val owner = if (sym is PackageVal) sym.moduleClass else sym
      ctx.fresh.setTree(tree).setOwner(owner)
    }

    /** Transforms a list of statements.
     *  - Imports and definitions are transformed in the current context.
     *  - Thickets are transformed element by element.
     *  - Any other statement is transformed in `ctx.exprContext(stat, exprOwner)`.
     *
     *  The resulting list is passed through `flatten`, which presumably splices
     *  thicket contents into the list (TODO confirm against `tpd.flatten`).
     *
     *  @param trees     the statements to transform
     *  @param exprOwner the owner symbol used when building expression contexts
     *  @return the transformed statements
     */
    def transformStats(trees: List[Tree], exprOwner: Symbol)(implicit ctx: Context): List[Tree] = {
      def transformStat(stat: Tree): Tree = stat match {
        case _: Import | _: DefTree => transform(stat)
        // Use the same `mapconserve` (from `Decorators`) as the top-level call below,
        // so nested and top-level statement lists are conserved in the same way.
        case Thicket(stats) => cpy.Thicket(stat)(stats.mapconserve(transformStat(_)))
        case _ => transform(stat)(ctx.exprContext(stat, exprOwner))
      }
      flatten(trees.mapconserve(transformStat(_)))
    }

    /** Transforms `tree`, setting up the context for the constructs handled here.
     *  - `EmptyValDef` is returned unchanged.
     *  - Package and member definitions are transformed in `localCtx(tree)`.
     *  - For a template:
     *    - the constructor is transformed via `transformSub`;
     *    - the parents are transformed in `ctx.superCallContext`;
     *    - the self definition is transformed via `transformSelf`;
     *    - the body is transformed via `transformStats`, with the template's
     *      symbol as expression owner.
     *  - Everything else is delegated to `TreeMap.transform`.
     */
    override def transform(tree: Tree)(implicit ctx: Context): Tree = {
      tree match {
        case EmptyValDef =>
          tree
        case _: PackageDef | _: MemberDef =>
          super.transform(tree)(localCtx(tree))
        case impl @ Template(constr, parents, self, _) =>
          cpy.Template(tree)(
            transformSub(constr),
            transform(parents)(ctx.superCallContext),
            transformSelf(self),
            transformStats(impl.body, tree.symbol))
        case _ =>
          super.transform(tree)
      }
    }

    /** Transforms a template's self definition; only its type tree (`tpt`) is transformed. */
    def transformSelf(vd: ValDef)(implicit ctx: Context): ValDef =
      cpy.ValDef(vd)(tpt = transform(vd.tpt))
  }
}