summaryrefslogtreecommitdiff
path: root/src/compiler/scala/tools/nsc/ast/TreeDSL.scala
blob: fb65f827404f279d330061499df350d81d628463 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
/* NSC -- new Scala compiler
 * Copyright 2005-2010 LAMP/EPFL
 *
 * @author  Paul Phillips
 */

package scala.tools.nsc
package ast

/** A DSL for generating Scala compiler trees.  The goal is that the
 *  code-generating code should look a lot like the code it generates.
 *
 *  The builders and the implicit conversions that enable the infix syntax
 *  all live in `CODE`, so they are brought into scope with `import CODE._`.
 */

trait TreeDSL {
  val global: Global

  import global._
  import definitions._
  import gen.{ scalaDot }
  import PartialFunction._

  object CODE {
    /** Wraps a `Tree => Tree` function with a null check.
     *
     *  The returned function builds `IF (tree MEMBER_== NULL) THEN ifNull ELSE f(tree)`;
     *  the same `tree` instance is used in the condition and passed to `f`.
     *  The type parameter `T` is unused; it is kept so call sites that supply it
     *  explicitly continue to compile.
     */
    def nullSafe[T](f: Tree => Tree, ifNull: Tree): Tree => Tree =
      tree => IF (tree MEMBER_== NULL) THEN ifNull ELSE f(tree)

    /** Applies `f` to `x` for its side effect, then returns `x` itself. */
    def returning[T](f: T => Unit)(x: T): T = { f(x) ; x }

    /** Strips any number of nested `Bind` wrappers and returns the tree beneath. */
    final def unbind(x: Tree): Tree = x match {
      case Bind(_, y) => unbind(y)
      case y          => y
    }

    /** Constant literals: `LIT(x)` builds `Literal(Constant(x))`; used as a pattern,
     *  `LIT(v)` matches a `Literal(Constant(v))` and extracts its value.
     */
    object LIT extends (Any => Literal) {
      def apply(x: Any)   = Literal(Constant(x))
      def unapply(x: Any) = condOpt(x) { case Literal(Constant(value)) => value }
    }

    // You might think these could all be vals, but empirically I have found that
    // at least in the case of UNIT the compiler breaks if you re-use trees.
    // However we need stable identifiers to have attractive pattern matching.
    // So it's inconsistent until I devise a better way.
    // NOTE(review): TRUE, FALSE and ZERO are therefore single shared tree instances;
    // anything that mutates one of them affects every place it is used.
    val TRUE          = LIT(true)
    val FALSE         = LIT(false)
    val ZERO          = LIT(0)
    def NULL          = LIT(null)
    def UNIT          = LIT(())

    /** The wildcard identifier `_`.
     *
     *  `WILD()` builds `Ident(nme.WILDCARD)`; `WILD(tpe)` additionally sets its type
     *  to `tpe`.  In a pattern, `WILD()` matches any `Ident(nme.WILDCARD)`.
     */
    object WILD {
      def apply(tpe: Type = null) =
        if (tpe == null) Ident(nme.WILDCARD)
        else Ident(nme.WILDCARD) setType tpe

      def unapply(other: Any) =
        cond(other) { case Ident(nme.WILDCARD)  => true }
    }

    /** `Apply(Select(lhs, op), args)`, i.e. `lhs.op(args...)`, with `op` given by name. */
    def fn(lhs: Tree, op:   Name, args: Tree*)  = Apply(Select(lhs, op), args.toList)
    /** `Apply(Select(lhs, op), args)`, i.e. `lhs.op(args...)`, with `op` given as a symbol. */
    def fn(lhs: Tree, op: Symbol, args: Tree*)  = Apply(Select(lhs, op), args.toList)

    /** Infix operators on a `Tree`.  The implicit conversions at the end of `CODE`
     *  also make them available on symbols, names, strings and `SelectStart`s.
     */
    class TreeMethods(target: Tree) {
      // NOTE(review): this calls `setType` on the caller's tree rather than on a copy;
      // if `setType` updates the receiver in place, passing a shared tree (e.g. TRUE,
      // FALSE, ZERO) to ANY_NE/ANY_EQ would retype it everywhere -- confirm.
      private def toAnyRef(x: Tree) = x setType AnyRefClass.tpe

      /** logical/comparison ops **/

      /** `target || other`; an `EmptyTree` on either side is treated as absent. */
      def OR(other: Tree) =
        if (target == EmptyTree) other
        else if (other == EmptyTree) target
        else gen.mkOr(target, other)

      /** `target && other`; an `EmptyTree` on either side is treated as absent. */
      def AND(other: Tree) =
        if (target == EmptyTree) other
        else if (other == EmptyTree) target
        else gen.mkAnd(target, other)

      /** `target == other`, dispatched through the `==` member of `target`'s type.
       *
       *  Note - calling ANY_== in the matcher caused primitives to get boxed
       *  for the comparison, whereas looking up nme.EQ does not.
       *  If `target` has no type yet, this falls back to `ANY_==`.
       */
      def MEMBER_== (other: Tree)   = {
        if (target.tpe == null) ANY_==(other)
        else fn(target, target.tpe member nme.EQ, other)
      }
      /** `target ne other`, with `other` given type AnyRef. */
      def ANY_NE  (other: Tree)     = fn(target, nme.ne, toAnyRef(other))
      /** `target eq other`, with `other` given type AnyRef. */
      def ANY_EQ  (other: Tree)     = fn(target, nme.eq, toAnyRef(other))
      /** `target == other` through the `Any_==` symbol. */
      def ANY_==  (other: Tree)     = fn(target, Any_==, other)
      /** `target >= other`, selected by name. */
      def ANY_>=  (other: Tree)     = fn(target, nme.GE, other)
      /** `target <= other`, selected by name. */
      def ANY_<=  (other: Tree)     = fn(target, nme.LE, other)
      /** Applies the `Object_ne` symbol to `target` and `other`. */
      def OBJ_!=  (other: Tree)     = fn(target, Object_ne, other)

      /** Int operators, resolved against the members of `IntClass`. */
      def INT_|   (other: Tree)     = fn(target, getMember(IntClass, nme.OR), other)
      def INT_&   (other: Tree)     = fn(target, getMember(IntClass, nme.AND), other)
      def INT_==  (other: Tree)     = fn(target, getMember(IntClass, nme.EQ), other)
      def INT_!=  (other: Tree)     = fn(target, getMember(IntClass, nme.NE), other)

      /** Boolean `&&` / `||`, resolved against the members of `BooleanClass`. */
      def BOOL_&& (other: Tree)     = fn(target, getMember(BooleanClass, nme.ZAND), other)
      def BOOL_|| (other: Tree)     = fn(target, getMember(BooleanClass, nme.ZOR), other)

      /** Apply, Select, Match **/
      def APPLY(params: Tree*)      = Apply(target, params.toList)
      def APPLY(params: List[Tree]) = Apply(target, params)
      def MATCH(cases: CaseDef*)    = Match(target, cases.toList)

      /** `target.member`; see `SelectStart` for why this is not a plain `Select`. */
      def DOT(member: Name)         = SelectStart(Select(target, member))
      def DOT(sym: Symbol)          = SelectStart(Select(target, sym))

      /** Assignment: `target = rhs`. */
      def ===(rhs: Tree)            = Assign(target, rhs)

      /** Methods for sequences **/

      /** `target.drop(count)`, or `target` itself when `count` is 0. */
      def DROP(count: Int): Tree =
        if (count == 0) target
        else (target DOT nme.drop)(LIT(count))

      /** Casting & type tests -- working our way toward understanding exactly
       *  what differs between the different forms of IS and AS.
       *
       *  See ticket #2168 for one illustration of AS vs. AS_ANY.
       */
      def AS(tpe: Type)       = TypeApply(Select(target, Any_asInstanceOf), List(TypeTree(tpe)))
      def AS_ANY(tpe: Type)   = gen.mkAsInstanceOf(target, tpe)
      def AS_ATTR(tpe: Type)  = gen.mkAttributedCast(target, tpe)

      // IS and IS_OBJ differ only in the boolean flag passed to gen.mkIsInstanceOf.
      def IS(tpe: Type)       = gen.mkIsInstanceOf(target, tpe, true)
      def IS_OBJ(tpe: Type)   = gen.mkIsInstanceOf(target, tpe, false)

      // XXX having some difficulty expressing nullSafe in a way that doesn't freak out value types
      // def TOSTRING()          = nullSafe(fn(_: Tree, nme.toString_), LIT("null"))(target)
      /** `target.toString()` -- no null check, see the XXX note above. */
      def TOSTRING()          = fn(target, nme.toString_)
      /** `target.getClass()` through the `Object_getClass` symbol. */
      def GETCLASS()          = fn(target, Object_getClass)
    }

    /** The result of `x DOT m`: either applied directly (`(x DOT m)(args)`) or
     *  converted back to its `Select` by `mkTreeFromSelectStart`.
     */
    case class SelectStart(tree: Select) {
      def apply(args: Tree*) = Apply(tree, args.toList)
    }

    /** Builder for `CASE (pat) IF (guard) ==> body`. */
    class CaseStart(pat: Tree, guard: Tree) {
      def IF(g: Tree): CaseStart    = new CaseStart(pat, g)
      def ==>(body: Tree): CaseDef  = CaseDef(pat, guard, body)
    }

    /** Builders for `VAL(sym) === rhs` and `DEF(sym) === rhs`. */
    abstract class ValOrDefStart(sym: Symbol) {
      def ===(body: Tree): ValOrDefDef
    }
    class DefStart(sym: Symbol) extends ValOrDefStart(sym) {
      def ===(body: Tree) = DefDef(sym, body)
    }
    class ValStart(sym: Symbol) extends ValOrDefStart(sym) {
      def ===(body: Tree) = ValDef(sym, body)
    }
    /** Builder for `IF (c) THEN t ELSE e` and `IF (c) THEN t ENDIF` (else-branch `EmptyTree`). */
    class IfStart(cond: Tree, thenp: Tree) {
      def THEN(x: Tree) = new IfStart(cond, x)
      def ELSE(elsep: Tree) = If(cond, thenp, elsep)
      def ENDIF = If(cond, thenp, EmptyTree)
    }
    /** Builder for `TRY (body) CATCH (cases...) FINALLY (fin)`, or `... ENDTRY` without a finalizer. */
    class TryStart(body: Tree, catches: List[CaseDef], fin: Tree) {
      def CATCH(xs: CaseDef*) = new TryStart(body, xs.toList, fin)
      def FINALLY(x: Tree)    = Try(body, catches, x)
      def ENDTRY              = Try(body, catches, fin)
    }

    /** Starts a case clause matching `pat`, initially without a guard. */
    def CASE(pat: Tree): CaseStart  = new CaseStart(pat, EmptyTree)
    /** Starts a catch-all `case _` clause. */
    def DEFAULT: CaseStart          = new CaseStart(WILD(), EmptyTree)

    /** Pattern binder on a name: `name BIND pat` is `name @ pat`. */
    class NameMethods(target: Name) {
      def BIND(body: Tree) = Bind(target, body)
    }

    /** Helpers available on a `Symbol`. */
    class SymbolMethods(target: Symbol) {
      /** Pattern binder: `sym BIND pat` is `sym @ pat`. */
      def BIND(body: Tree) = Bind(target, body)

      // Option
      /** `TRUE` when the symbol's type is statically `Some[_]`, otherwise `!sym.isEmpty`. */
      def IS_DEFINED() =
        if (target.tpe.typeSymbol == SomeClass) TRUE   // is Some[_]
        else NOT(ID(target) DOT nme.isEmpty)           // is Option[_]

      /** `sym.get` */
      def GET() = fn(ID(target), nme.get)

      // name of nth indexed argument to a method (first parameter list), defaults to 1st
      def ARG(idx: Int = 0) = Ident(target.paramss.head(idx))
      /** The parameters of the first parameter list. */
      def ARGS = target.paramss.head
      /** `Ident`s referring to the parameters of the first parameter list. */
      def ARGNAMES = ARGS map Ident
    }

    /** Top level accessible. */

    /** `throw new sym()` or, when `msg` is given, `throw new sym(msg.toString())`. */
    def THROW(sym: Symbol, msg: Tree = null) = {
      val arg: List[Tree] = if (msg == null) Nil else List(msg.TOSTRING())
      Throw(New(TypeTree(sym.tpe), List(arg)))
    }
    /** `new tpe(args...)` with a single argument list (which may be empty). */
    def NEW(tpe: Tree, args: Tree*)   = New(tpe, List(args.toList))
    /** `new sym(args...)`; with no arguments, a bare `New` with no argument list is built. */
    def NEW(sym: Symbol, args: Tree*) =
      if (args.isEmpty) New(TypeTree(sym.tpe))
      else New(TypeTree(sym.tpe), List(args.toList))

    def VAL(sym: Symbol) = new ValStart(sym)
    def DEF(sym: Symbol) = new DefStart(sym)

    /** Conjunction of `guards`, left-associated.
     *
     *  `EmptyTree` operands are ignored, consistent with `TreeMethods.AND`, so no
     *  `EmptyTree && x` tree is ever built; if nothing remains the result is `EmptyTree`.
     */
    def AND(guards: Tree*): Tree = {
      val present = guards filterNot (_ == EmptyTree)
      if (present.isEmpty) EmptyTree
      else present reduceLeft gen.mkAnd
    }

    /** Disjunction of `guards`, left-associated.
     *
     *  `EmptyTree` operands are ignored, consistent with `TreeMethods.OR`; if
     *  nothing remains the result is `EmptyTree`.
     */
    def OR(guards: Tree*): Tree = {
      val present = guards filterNot (_ == EmptyTree)
      if (present.isEmpty) EmptyTree
      else present reduceLeft gen.mkOr
    }

    def IF(tree: Tree)    = new IfStart(tree, EmptyTree)
    def TRY(tree: Tree)   = new TryStart(tree, Nil, EmptyTree)

    /** `{ xs(0); ...; xs(n-2); xs(n-1) }`: the last tree is the block's result.
     *  With no arguments this yields the unit-valued block `{ () }` instead of
     *  failing on `init`/`last` of an empty sequence.
     */
    def BLOCK(xs: Tree*)  =
      if (xs.isEmpty) Block(Nil, UNIT)
      else Block(xs.init.toList, xs.last)

    /** `!tree`, through the `unary_!` member of `BooleanClass`. */
    def NOT(tree: Tree)   = Select(tree, getMember(BooleanClass, nme.UNARY_!))

    // A fresh `scala.Some` reference per call: sharing one tree instance across
    // results runs into the tree re-use problem described above TRUE/FALSE/ZERO.
    private def _SOME     = scalaDot(nme.Some)
    /** `Some(x)` for one argument, `Some((x1, ..., xn))` for several, `Some(())` for none. */
    def SOME(xs: Tree*)   = Apply(_SOME, List(makeTupleTerm(xs.toList, true)))

    /** Typed trees from symbols. */
    def THIS(sym: Symbol)             = gen.mkAttributedThis(sym)
    def ID(sym: Symbol)               = gen.mkAttributedIdent(sym)
    def REF(sym: Symbol)              = gen.mkAttributedRef(sym)
    def REF(pre: Type, sym: Symbol)   = gen.mkAttributedRef(pre, sym)

    /** Some of this is basically verbatim from TreeBuilder, but we do not want
     *  to get involved with him because he's an untyped only sort.
     */
    private def tupleName(count: Int, f: (String) => Name = newTermName(_: String)) =
      scalaDot(f("Tuple" + count))

    /** Tuple expression: `()` for no trees, the tree itself for one tree when
     *  `flattenUnary` is set, otherwise `TupleN(trees...)`.
     */
    def makeTupleTerm(trees: List[Tree], flattenUnary: Boolean): Tree = trees match {
      case Nil                        => UNIT
      case List(tree) if flattenUnary => tree
      case _                          => Apply(tupleName(trees.length), trees)
    }
    /** Tuple type: `gen.scalaUnitConstr` for no trees, the tree itself for one tree
     *  when `flattenUnary` is set, otherwise `TupleN[trees...]`.
     */
    def makeTupleType(trees: List[Tree], flattenUnary: Boolean): Tree = trees match {
      case Nil                        => gen.scalaUnitConstr
      case List(tree) if flattenUnary => tree
      case _                          => AppliedTypeTree(tupleName(trees.length, newTypeName), trees)
    }

    /** Implicits - some of these should probably disappear **/
    // Symbols, names and strings are wrapped in an `Ident` before gaining TreeMethods.
    implicit def mkTreeMethods(target: Tree): TreeMethods = new TreeMethods(target)
    implicit def mkTreeMethodsFromSymbol(target: Symbol): TreeMethods = new TreeMethods(Ident(target))
    implicit def mkTreeMethodsFromName(target: Name): TreeMethods = new TreeMethods(Ident(target))
    implicit def mkTreeMethodsFromString(target: String): TreeMethods = new TreeMethods(Ident(target))

    implicit def mkNameMethodsFromName(target: Name): NameMethods = new NameMethods(target)
    implicit def mkNameMethodsFromString(target: String): NameMethods = new NameMethods(target)

    implicit def mkSymbolMethodsFromSymbol(target: Symbol): SymbolMethods = new SymbolMethods(target)

    /** (foo DOT bar) might be simply a Select, but more likely it is to be immediately
     *  followed by an Apply.  We don't want to add an actual apply method to arbitrary
     *  trees, so SelectStart is created with an apply - and if apply is not the next
     *  thing called, the implicit from SelectStart -> Tree will provide the tree.
     */
    implicit def mkTreeFromSelectStart(ss: SelectStart): Select = ss.tree
    implicit def mkTreeMethodsFromSelectStart(ss: SelectStart): TreeMethods = mkTreeMethods(ss.tree)
  }
}