summaryrefslogtreecommitdiff
path: root/src/continuations/plugin/scala/tools/selectivecps/SelectiveCPSTransform.scala
blob: 0210ad3459ac37b9097717d9deef635e4a269bf8 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
// $Id$

package scala.tools.selectivecps

import scala.tools.nsc.transform._
import scala.tools.nsc.plugins._
import scala.tools.nsc.ast._

/**
 * In methods marked @cps, CPS-transform assignments introduced by ANF-transform phase.
 *
 * This phase consumes the output of the selective ANF phase:
 *
 *  - value definitions whose symbol carries `MarkerCPSSym` are turned into
 *    `map`/`flatMap` calls on the context computed by their right-hand side, with
 *    the rest of the enclosing block as the function body (see `transBlock`);
 *  - calls to `shift`, `shiftUnit` and `reify` are redirected to `MethShiftR`,
 *    `MethShiftUnitR` and `MethReifyR`;
 *  - `try` expressions whose type carries an answer-type annotation are rewritten
 *    so their handlers are also passed to the context's `flatMapCatch`;
 *  - types carrying a CPS answer-type annotation are rewritten by
 *    `transformCPSType`, both in symbol infos (`transformInfo`) and on every
 *    transformed tree.
 *
 * All of this is skipped when `cpsEnabled` is false.
 */
abstract class SelectiveCPSTransform extends PluginComponent with
  InfoTransform with TypingTransformers with CPSUtils with TreeDSL {
  // inherits abstract value `global` and class `Phase` from Transform

  import global._                  // the global environment
  import definitions._             // standard classes and methods
  import typer.atOwner             // methods to type trees

  override def description = "@cps-driven transform of selectiveanf assignments"

  /** the following two members override abstract members in Transform */
  val phaseName: String = "selectivecps"

  protected def newTransformer(unit: CompilationUnit): Transformer =
    new CPSTransformer(unit)

  /** This class does not change linearization */
  override def changesBaseClasses = false

  /** Returns the info of `sym` after this phase.
   *
   *  @param sym the symbol whose info is transformed; only used in debug output
   *  @param tp  the info of `sym` before this phase
   *  @return    `tp` unchanged when `cpsEnabled` is false, otherwise
   *             `transformCPSType(tp)`
   */
  def transformInfo(sym: Symbol, tp: Type): Type =
    if (!cpsEnabled) tp
    else {
      val newtp = transformCPSType(tp)

      if (newtp != tp)
        debuglog("transformInfo changed type for " + sym + " to " + newtp)

      if (sym == MethReifyR)
        debuglog("transformInfo (not)changed type for " + sym + " to " + newtp)

      newtp
    }

  /** Replaces CPS answer-type annotations by explicit `Context` types.
   *
   *  The following are rewritten recursively:
   *   - the result types of polymorphic, method and nullary-method types;
   *   - the type arguments of type references.
   *
   *  Parameter types, type-parameter symbols and prefixes are left alone.
   *
   *  Any other type `T` that has an external answer-type annotation `(res, outer)`
   *  becomes `Context.tpeHK` applied to `(T', res, outer)`, where `T'` is `T`
   *  without its CPS annotations. Types without such an annotation only have their
   *  CPS annotations removed.
   */
  def transformCPSType(tp: Type): Type = {  // TODO: use a TypeMap? need to handle more cases?
    tp match {
      case PolyType(params, res)    => PolyType(params, transformCPSType(res))
      case NullaryMethodType(res)   => NullaryMethodType(transformCPSType(res))
      case MethodType(params, res)  => MethodType(params, transformCPSType(res))
      case TypeRef(pre, sym, args)  => TypeRef(pre, sym, args.map(transformCPSType(_)))
      case _ =>
        getExternalAnswerTypeAnn(tp) match {
          case Some((res, outer)) =>
            appliedType(Context.tpeHK, List(removeAllCPSAnnotations(tp), res, outer))
          case _ =>
            removeAllCPSAnnotations(tp)
        }
    }
  }


  /** Performs this phase's tree rewriting for one compilation unit. */
  class CPSTransformer(unit: CompilationUnit) extends TypingTransformer(unit) {
    // Applied to the tree synthesized for CPS try/catch expressions in `mainTransform`.
    private val patmatTransformer = patmat.newTransformer(unit)

    /** Rewrites `tree` with `mainTransform`, then its type with `transformCPSType`.
     *  Returns `tree` untouched when `cpsEnabled` is false.
     */
    override def transform(tree: Tree): Tree =
      if (!cpsEnabled) tree
      else postTransform(mainTransform(tree))

    /** Replaces the type of `tree` (in place) by its CPS-transformed type and
     *  returns `tree`.
     */
    def postTransform(tree: Tree): Tree = {
      tree.setType(transformCPSType(tree.tpe))
    }

    /** Rebuilds the call `tree`, of the shape `fun[targs](args)`, as a call to
     *  `target` with the same type arguments and recursively transformed value
     *  arguments. The new call is typed with the CPS-transformed type of `tree`.
     */
    private def redirectCall(tree: Tree, target: Symbol, targs: List[Tree], args: List[Tree]): Tree =
      atPos(tree.pos) {
        val funR = gen.mkAttributedRef(target) // TODO: correct?
        debuglog("funR.tpe: " + funR.tpe)
        Apply(
            TypeApply(funR, targs).setType(appliedType(funR.tpe, targs.map((t:Tree) => t.tpe))),
            args.map(transform(_))
        ).setType(transformCPSType(tree.tpe))
      }


    /** Rewrites a single tree node. Sub-trees are transformed recursively through
     *  `transform`; nodes not handled below go to the default traversal.
     */
    def mainTransform(tree: Tree): Tree = {
      tree match {

        // TODO: can we generalize this?

        // shift[targs](args)  ~~>  MethShiftR[targs](args')
        case Apply(TypeApply(fun, targs), args)
        if (fun.symbol == MethShift) =>
          debuglog("found shift: " + tree)
          // Earlier alternatives for referencing MethShiftR:
          //gen.mkAttributedSelect(gen.mkAttributedSelect(gen.mkAttributedSelect(gen.mkAttributedIdent(ScalaPackage),
          //ScalaPackage.tpe.member("util")), ScalaPackage.tpe.member("util").tpe.member("continuations")), MethShiftR)
          //gen.mkAttributedRef(ModCPS.tpe,  MethShiftR) // TODO: correct?
          redirectCall(tree, MethShiftR, targs, args)

        // shiftUnit[targs](args)  ~~>  MethShiftUnitR[targs(0), targs(1)](args')
        // The result is typed Context.tpeHK applied to (targs(0), targs(1), targs(1)).
        case Apply(TypeApply(fun, targs), args)
        if (fun.symbol == MethShiftUnit) =>
          debuglog("found shiftUnit: " + tree)
          atPos(tree.pos) {
            val funR = gen.mkAttributedRef(MethShiftUnitR) // TODO: correct?
            debuglog("funR.tpe: " + funR.tpe)
            Apply(
                TypeApply(funR, List(targs(0), targs(1))).setType(appliedType(funR.tpe,
                    List(targs(0).tpe, targs(1).tpe))),
                args.map(transform(_))
            ).setType(appliedType(Context.tpeHK, List(targs(0).tpe,targs(1).tpe,targs(1).tpe)))
          }

        // reify[targs](args)  ~~>  MethReifyR[targs](args')
        case Apply(TypeApply(fun, targs), args)
        if (fun.symbol == MethReify) =>
          debuglog("found reify: " + tree)
          redirectCall(tree, MethReifyR, targs, args)

        case Try(block, catches, finalizer) =>
          // currently duplicates the catch block into a partial function.
          // this is kinda risky, but we don't expect there will be lots
          // of try/catches inside catch blocks (exp. blowup unlikely).

          // CAVEAT: finalizers are surprisingly tricky!
          // the problem is that they cannot easily be removed
          // from the regular control path and hence will
          // also be invoked after creating the Context object.

          /*
          object Test {
            def foo1 = {
              throw new Exception("in sub")
              shift((k:Int=>Int) => k(1))
              10
            }
            def foo2 = {
              shift((k:Int=>Int) => k(2))
              20
            }
            def foo3 = {
              shift((k:Int=>Int) => k(3))
              throw new Exception("in sub")
              30
            }
            def foo4 = {
              shift((k:Int=>Int) => 4)
              throw new Exception("in sub")
              40
            }
            def bar(x: Int) = try {
              if (x == 1)
                foo1
              else if (x == 2)
                foo2
              else if (x == 3)
                foo3
              else //if (x == 4)
                foo4
            } catch {
              case _ =>
                println("exception")
                0
            } finally {
              println("done")
            }
          }

          reset(Test.bar(1)) // should print: exception,done,0
          reset(Test.bar(2)) // should print: done,20 <-- but prints: done,done,20
          reset(Test.bar(3)) // should print: exception,done,0 <-- but prints: done,exception,done,0
          reset(Test.bar(4)) // should print: 4 <-- but prints: done,4
          */

          val block1 = transform(block)
          val catches1 = transformCaseDefs(catches)
          val finalizer1 = transform(finalizer)

          if (hasAnswerTypeAnn(tree.tpe)) {
            // Rewrite (names in <...> stand for the cpsNames used below)
            //
            //   try { stms; expr } catch { cases } finally { fin }
            // ~~>
            //   { val <catches>: PartialFunction[Throwable, T] = { cases }
            //     try { stms; expr.flatMapCatch(<catches>) }
            //     catch { case <ex> => if (<catches>.isDefinedAt(<ex>)) <catches>(<ex>) else throw <ex> }
            //     finally { fin } }
            //
            // where T is the CPS-transformed type of the try. Exceptions thrown while
            // evaluating the block go through the synthesized catch-all case;
            // `flatMapCatch` (defined on the context type, not in this file) is
            // handed the same handler, presumably so exceptions raised later inside
            // the continuation are handled too.
            //vprintln("CPS Transform: " + tree + "/" + tree.tpe + "/" + block1.tpe)

            val (stms, expr1) = block1 match {
              case Block(stms, expr) => (stms, expr)
              case expr => (Nil, expr)
            }

            val targettp = transformCPSType(tree.tpe)

            // NOTE(review): `catches.head` throws if this `try` has no catch cases
            // (a plain try/finally); confirm the ANF phase never produces that shape.
            val pos = catches.head.pos
            val funSym = currentOwner.newValueParameter(cpsNames.catches, pos).setInfo(appliedType(PartialFunctionClass, ThrowableTpe, targettp))
            val funDef = localTyper.typedPos(pos) {
              ValDef(funSym, Match(EmptyTree, catches1))
            }
            val expr2 = localTyper.typedPos(pos) {
              Apply(Select(expr1, expr1.tpe.member(cpsNames.flatMapCatch)), List(Ident(funSym)))
            }

            val exSym = currentOwner.newValueParameter(cpsNames.ex, pos).setInfo(ThrowableTpe)

            import CODE._
            // generate a case that is supported directly by the back-end
            val catchIfDefined = CaseDef(
                  Bind(exSym, Ident(nme.WILDCARD)),
                  EmptyTree,
                  IF ((REF(funSym) DOT nme.isDefinedAt)(REF(exSym))) THEN (REF(funSym) APPLY (REF(exSym))) ELSE Throw(REF(exSym))
                )

            val catch2 = localTyper.typedCases(List(catchIfDefined), ThrowableTpe, targettp)
            //typedCases(tree, catches, ThrowableTpe, pt)

            // The synthesized matches are passed through the pattern-matcher
            // transformer before the result is returned.
            patmatTransformer.transform(localTyper.typed(Block(List(funDef), treeCopy.Try(tree, treeCopy.Block(block1, stms, expr2), catch2, finalizer1))))


  /*
            disabled for now - see notes above

            val expr3 = if (!finalizer.isEmpty) {
              val pos = finalizer.pos
              val finalizer2 = duplicateTree(finalizer1)
              val fun = Function(List(), finalizer2)
              val expr3 = localTyper.typedPos(pos) { Apply(Select(expr2, expr2.tpe.member("mapFinally")), List(fun)) }

              val chown = new ChangeOwnerTraverser(currentOwner, fun.symbol)
              chown.traverse(finalizer2)

              expr3
            } else
              expr2
  */
          } else {
            treeCopy.Try(tree, block1, catches1, finalizer1)
          }

        case Block(stms, expr) =>
          val (stms1, expr1) = transBlock(stms, expr)
          treeCopy.Block(tree, stms1, expr1)

        case _ =>
          super.transform(tree)
      }
    }



    /** Transforms the statements `stms` and result expression `expr` of a block.
     *
     *  Statements are transformed one by one until a `ValDef` whose symbol is
     *  annotated with `MarkerCPSSym` is reached. The statements after it and the
     *  result expression then become the body of a function applied to the context
     *  produced by the ValDef's right-hand side:
     *
     *  {{{
     *  val x = <rhs>; rest   ~~>   <rhs'>.map(x => rest')        // rest' is not a Context
     *                              <rhs'>.flatMap(x => rest')    // rest' is a Context
     *  }}}
     *
     *  When the result of the rest is a call to the current method yielding a
     *  Context (an explicit tail call), the rewrite instead binds the context to a
     *  local and tests `isTrivial`. On that branch the trivial value is bound
     *  directly, without building a closure; otherwise it falls back to the
     *  combinator form above.
     *
     *  A `TypeError` raised while building the rewrite is reported on the unit, and
     *  the marked ValDef is dropped from the result.
     *
     *  @return the transformed statements and result expression
     */
    def transBlock(stms: List[Tree], expr: Tree): (List[Tree], Tree) = {

      stms match {
        case Nil =>
          (Nil, transform(expr))

        case stm::rest =>

          stm match {
            case vd @ ValDef(mods, name, tpt, rhs)
            if (vd.symbol.hasAnnotation(MarkerCPSSym)) =>

              debuglog("found marked ValDef "+name+" of type " + vd.symbol.tpe)

              val tpe = vd.symbol.tpe
              // Transform the rhs as owned by the ValDef's symbol, then re-own it to
              // the enclosing owner, because the ValDef itself is not kept in the output.
              val rhs1 = atOwner(vd.symbol) { transform(rhs) }
              rhs1.changeOwner(vd.symbol -> currentOwner) // TODO: don't traverse twice

              debuglog("valdef symbol " + vd.symbol + " has type " + tpe)
              debuglog("right hand side " + rhs1 + " has type " + rhs1.tpe)

              debuglog("currentOwner: " + currentOwner)
              debuglog("currentMethod: " + currentMethod)

              val (bodyStms, bodyExpr) = transBlock(rest, expr)
              // FIXME: result will later be traversed again by TreeSymSubstituter and
              // ChangeOwnerTraverser => exp. running time.
              // Should be changed to fuse traversals into one.

              val specialCaseTrivial = bodyExpr match {
                case Apply(fun, args) =>
                  // for now, look for explicit tail calls only.
                  // are there other cases that could profit from specializing on
                  // trivial contexts as well?
                  (bodyExpr.tpe.typeSymbol == Context) && (currentMethod == fun.symbol)
                case _ => false
              }

              // Replaces `vd.symbol` by `ctxValSym` in `body` and types the result.
              // It is an error if the result is not a Context.
              def applyTrivial(ctxValSym: Symbol, body: Tree) = {

                val body1 = (new TreeSymSubstituter(List(vd.symbol), List(ctxValSym)))(body)

                val body2 = localTyper.typedPos(vd.symbol.pos) { body1 }

                // in theory it would be nicer to look for an @cps annotation instead
                // of testing for Context
                if ((body2.tpe == null) || !(body2.tpe.typeSymbol == Context)) {
                  //println(body2 + "/" + body2.tpe)
                  unit.error(rhs.pos, "cannot compute type for CPS-transformed function result")
                }
                body2
              }

              // Builds `ctxR.map(x => body)`, or `ctxR.flatMap(x => body)` when the
              // typed body is itself a Context. `x` is a fresh parameter replacing
              // `vd.symbol`.
              def applyCombinatorFun(ctxR: Tree, body: Tree) = {
                val arg = currentOwner.newValueParameter(name, ctxR.pos).setInfo(tpe)
                val body1 = (new TreeSymSubstituter(List(vd.symbol), List(arg)))(body)
                val fun = localTyper.typedPos(vd.symbol.pos) { Function(List(ValDef(arg)), body1) } // types body as well
                arg.owner = fun.symbol
                body1.changeOwner(currentOwner -> fun.symbol)

                // see note about multiple traversals above

                debuglog("fun.symbol: "+fun.symbol)
                debuglog("fun.symbol.owner: "+fun.symbol.owner)
                debuglog("arg.owner: "+arg.owner)

                debuglog("fun.tpe:"+fun.tpe)
                debuglog("return type of fun:"+body1.tpe)

                val methodName =
                  if (body1.tpe == null) {
                    unit.error(rhs.pos, "cannot compute type for CPS-transformed function result")
                    nme.map
                  }
                  else if (body1.tpe.typeSymbol == Context) nme.flatMap
                  else nme.map

                debuglog("will use method:"+methodName)

                localTyper.typedPos(vd.symbol.pos) {
                  Apply(Select(ctxR, ctxR.tpe.member(methodName)), List(fun))
                }
              }

              // TODO use gen.mkBlock after 2.11.0-M6. Why wait? It allows us to still build in development
              // mode with `ant -DskipLocker=1`
              def mkBlock(stms: List[Tree], expr: Tree) = if (stms.nonEmpty) Block(stms, expr) else expr

              try {
                if (specialCaseTrivial) {
                  debuglog("will optimize possible tail call: " + bodyExpr)

                  // FIXME: flatMap impl has become more complicated due to
                  // exceptions. do we need to put a try/catch in the then part??

                  // val ctx = <rhs>
                  // if (ctx.isTrivial)
                  //   val <lhs> = ctx.getTrivialValue; ...    <--- TODO: try/catch ??? don't bother for the moment...
                  // else
                  //   ctx.flatMap { <lhs> => ... }
                  val ctxSym = currentOwner.newValue(newTermName("" + vd.symbol.name + cpsNames.shiftSuffix)).setInfo(rhs1.tpe)
                  val ctxDef = localTyper.typed(ValDef(ctxSym, rhs1))
                  def ctxRef = localTyper.typed(Ident(ctxSym))
                  val argSym = currentOwner.newValue(vd.symbol.name.toTermName).setInfo(tpe)
                  val argDef = localTyper.typed(ValDef(argSym, Select(ctxRef, ctxRef.tpe.member(cpsNames.getTrivialValue))))
                  val switchExpr = localTyper.typedPos(vd.symbol.pos) {
                    val body2 = mkBlock(bodyStms, bodyExpr).duplicate // dup before typing!
                    If(Select(ctxRef, ctxSym.tpe.member(cpsNames.isTrivial)),
                      applyTrivial(argSym, mkBlock(argDef::bodyStms, bodyExpr)),
                      applyCombinatorFun(ctxRef, body2))
                  }
                  (List(ctxDef), switchExpr)
                } else {
                  // ctx.flatMap { <lhs> => ... }
                  //     or
                  // ctx.map { <lhs> => ... }
                  (Nil, applyCombinatorFun(rhs1, mkBlock(bodyStms, bodyExpr)))
                }
              } catch {
                case ex:TypeError =>
                  unit.error(ex.pos, ex.msg)
                  (bodyStms, bodyExpr)
              }

            case _ =>
              val stm1 = transform(stm)
              val (a, b) = transBlock(rest, expr)
              (stm1::a, b)
          }
      }
    }


  }
}