aboutsummaryrefslogtreecommitdiff
path: root/test/test/DeSugarTest.scala
blob: 09d97872b944bf2291339a80e56ca93d78edff19 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
package test

import scala.reflect.io._
import dotty.tools.dotc.util._
import dotty.tools.dotc.core._
import dotty.tools.dotc.parsing._
import Tokens._, Parsers._
import org.junit.Test
import dotty.tools.dotc._
import ast.Trees._
import ast.desugar
import ast.desugar._
import typer.Mode

import scala.collection.mutable.ListBuffer

/** Exercises the desugarer on parsed sources.
 *
 *  Every tree in `parsedTrees` is run through [[DeSugar]]. (`parsedTrees` is
 *  presumably supplied by `ParserTest`, which is not visible here; confirm.)
 *  [[DeSugar]] is a transformer that applies `desugar` to each node and then
 *  recurses into the node's children in the appropriate [[Mode]].
 */
class DeSugarTest extends ParserTest {

  import dotty.tools.dotc.ast.untpd._

  import Mode._

  /** Tree transformer that applies `desugar` to every node it visits.
   *
   *  `desugar` takes a mode argument. The transformer therefore tracks the
   *  mode of the position being visited in `curMode`, and switches it when
   *  descending into a child that needs a specific mode:
   *  - function parts and qualifiers use `Expr`;
   *  - type trees use `Type`;
   *  - case patterns use `Pattern`.
   *
   *  Children given no explicit mode inherit the current one.
   */
  object DeSugar extends TreeTransformer {

    /** Mode of the position currently being transformed; initially `Expr`. */
    var curMode: Mode = Mode.Expr

    /** Evaluates `op` with `curMode` set to `mode`.
     *
     *  The previous mode is restored afterwards, even if `op` throws.
     */
    def withMode[T](mode: Mode)(op: => T): T = {
      val saved = curMode
      curMode = mode
      try op
      finally curMode = saved
    }

    /** Transforms `tree` with `curMode` temporarily set to `mode`. */
    def transform(tree: Tree, mode: Mode): Tree = withMode(mode) { transform(tree) }

    /** Transforms every tree in `trees` with `curMode` temporarily set to `mode`. */
    def transform(trees: List[Tree], mode: Mode): List[Tree] = withMode(mode) { transform(trees) }

    /** Desugars `tree` in the current mode, then transforms the children of the result. */
    override def transform(tree: Tree): Tree = {
      val tree1 = desugar(tree, curMode)
      tree1 match {
        case TypedSplice(t) =>
          // Spliced trees are returned as-is; their contents are not visited.
          tree1
        case PostfixOp(od, op) =>
          PostfixOp(transform(od), op)
        case Select(qual, name) =>
          tree1.derivedSelect(transform(qual, Expr), name)
        case Apply(fn, args) =>
          // Arguments keep the current mode (e.g. they stay patterns inside a pattern).
          tree1.derivedApply(transform(fn, Expr), transform(args))
        case TypeApply(fn, args) =>
          tree1.derivedTypeApply(transform(fn, Expr), transform(args, Type))
        case New(tpt) =>
          tree1.derivedNew(transform(tpt, Type))
        case Typed(expr, tpt) =>
          tree1.derivedTyped(transform(expr), transform(tpt, Type))
        case CaseDef(pat, guard, body) =>
          tree1.derivedCaseDef(transform(pat, Pattern), transform(guard), transform(body))
        case SeqLiteral(elempt, elems) =>
          tree1.derivedSeqLiteral(transform(elempt, Type), transform(elems))
        case UnApply(fun, args) =>
          tree1.derivedUnApply(transform(fun, Expr), transform(args))
        case ValDef(mods, name, tpt, rhs) =>
          tree1.derivedValDef(mods, name, transform(tpt, Type), transform(rhs))
        case DefDef(mods, name, tparams, vparamss, tpt, rhs) =>
          tree1.derivedDefDef(mods, name, transformSub(tparams), vparamss mapConserve (transformSub(_)), transform(tpt, Type), transform(rhs))
        case tdef @ TypeDef(mods, name, rhs) =>
          tdef.derivedTypeDef(mods, name, transform(rhs, Type), transformSub(tdef.tparams))
        case Template(constr, parents, self, body) =>
          // NOTE(review): constr, parents and self inherit the current mode.
          // That mode is `Type` when this template is reached as a TypeDef's rhs.
          // Only the body is explicitly put back into `Expr`.
          tree1.derivedTemplate(transformSub(constr), transform(parents), transformSub(self), transform(body, Expr))
        case Thicket(trees) =>
          // Each element goes to the inherited `TreeTransformer.transform`,
          // bypassing this override, so `desugar` is not applied to the element
          // itself. NOTE(review): presumably this avoids re-desugaring trees
          // that `desugar` has just produced. Confirm.
          Thicket(flatten(trees mapConserve super.transform))
        case other =>
          super.transform(other)
      }
    }
  }

  /** Name of the first `TypeDef` or `ModuleDef` in `stats`.
   *
   *  If a `PackageDef` is met first, the result comes from inside that
   *  package instead. Returns `"<empty>"` when the list is exhausted.
   */
  def firstClass(stats: List[Tree]): String = stats match {
    case Nil => "<empty>"
    case TypeDef(_, name, _) :: _ => name.toString
    case ModuleDef(_, name, _) :: _ => name.toString
    case (pdef: PackageDef) :: _ => firstClass(pdef)
    case _ :: rest => firstClass(rest)
  }

  /** For a `PackageDef`, returns its qualified id followed by `.` and the first class name in its statements.
   *
   *  Any other tree yields a `"??? "` marker followed by the tree's runtime class.
   */
  def firstClass(tree: Tree): String = tree match {
    case PackageDef(pid, stats) =>
      pid.show + "." + firstClass(stats)
    case _ => "??? " + tree.getClass
  }

  /** Desugars `tree` with [[DeSugar]], starting from `DeSugar.curMode` (initially `Expr`). */
  def desugarTree(tree: Tree): Tree = {
    //println("***** desugaring "+firstClass(tree))
    DeSugar.transform(tree)
  }

  /** Desugars every tree in `parsedTrees` and renders each result with `show`.
   *
   *  The rendered strings are discarded. The call therefore only checks that
   *  desugaring and printing complete without throwing.
   */
  def desugarAll(): Unit = parsedTrees foreach (desugarTree(_).show)
}