aboutsummaryrefslogtreecommitdiff
path: root/compiler/src/dotty/tools/dotc/transform/NonLocalReturns.scala
blob: 945504743c3f36607862391fd0e8141da4ac3e3d (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
package dotty.tools.dotc
package transform

import core._
import Contexts._, Symbols._, Types._, Flags._, Decorators._, StdNames._, Constants._, Phases._
import TreeTransforms._
import ast.Trees._
import NameExtractors.NonLocalReturnKeyName
import collection.mutable

object NonLocalReturns {
  import ast.tpd._

  /** Does `ret` have to be compiled as a non-local return?
   *
   *  This holds when the method that `ret` returns from is not the method
   *  enclosing the current owner, or when the current owner carries the
   *  `Lazy` flag.
   */
  def isNonLocalReturn(ret: Return)(implicit ctx: Context): Boolean = {
    val target = ret.from.symbol
    val owner = ctx.owner
    // NOTE(review): the `Lazy` case presumably exists because a lazy val's
    // initializer does not run directly in the frame of the enclosing method
    // — confirm against the LazyVals phase.
    (target != owner.enclosingMethod) || owner.is(Lazy)
  }
}

/** Implement non-local returns using NonLocalReturnControl exceptions.
 *
 *  A `return` recognized by `NonLocalReturns.isNonLocalReturn` is rewritten to
 *  throw a `NonLocalReturnControl` carrying a key that is unique to the target
 *  method. The target method's body is then wrapped in a try/catch. The handler
 *  yields the carried value when the key matches and rethrows the exception
 *  otherwise.
 */
class NonLocalReturns extends MiniPhaseTransform { thisTransformer =>
  override def phaseName = "nonLocalReturns"

  import NonLocalReturns._
  import ast.tpd._

  override def runsAfter: Set[Class[_ <: Phase]] = Set(classOf[ElimByName])

  /** Non-local return keys, indexed by the symbol of the method being returned from.
   *
   *  An entry is created by `nonLocalReturnKey` when `transformReturn` rewrites
   *  a return. It is removed again by `transformDefDef`, which uses its presence
   *  to decide whether the method body needs the catching wrapper.
   */
  private val nonLocalReturnKeys = mutable.Map[Symbol, TermSymbol]()

  /** The non-local return key for `meth`.
   *
   *  On the first request this creates a fresh synthetic `Object`-typed symbol
   *  owned by `meth` (positioned at `meth`). Later requests for the same method
   *  return that symbol.
   */
  private def nonLocalReturnKey(meth: Symbol)(implicit ctx: Context) =
    nonLocalReturnKeys.getOrElseUpdate(meth,
      ctx.newSymbol(
        meth, NonLocalReturnKeyName.fresh(), Synthetic, defn.ObjectType, coord = meth.pos))

  /** Generate a non-local return throw with given return expression from given method.
   *  I.e. for the method's non-local return key, generate:
   *
   *    throw new NonLocalReturnControl(key, expr)
   *
   *  `expr` is adapted to `Object` (via the `tpd` `ensureConforms` extension)
   *  before being passed as the exception's value argument.
   *
   *  todo: maybe clone a pre-existing exception instead?
   *  (but what to do about exceptions that miss their targets?)
   */
  private def nonLocalReturnThrow(expr: Tree, meth: Symbol)(implicit ctx: Context) =
    Throw(
      New(
        defn.NonLocalReturnControlType,
        ref(nonLocalReturnKey(meth)) :: expr.ensureConforms(defn.ObjectType) :: Nil))

  /** Wrap `body`, the right-hand side of method `meth`, so that it catches the
   *  non-local returns keyed by `key`:
   *
   *  {
   *    val key = new Object()
   *    try {
   *      body
   *    } catch {
   *      case ex: NonLocalReturnControl =>
   *        if (ex.key().eq(key)) ex.value().asInstanceOf[T]
   *        else throw ex
   *    }
   *  }
   *
   *  where `T` is the final result type of `meth`. Exceptions carrying a
   *  different key belong to some other method and are rethrown unchanged.
   */
  private def nonLocalReturnTry(body: Tree, key: TermSymbol, meth: Symbol)(implicit ctx: Context) = {
    val keyDef = ValDef(key, New(defn.ObjectType, Nil))
    val nonLocalReturnControl = defn.NonLocalReturnControlType
    val ex = ctx.newSymbol(meth, nme.ex, EmptyFlags, nonLocalReturnControl, coord = body.pos)
    val pat = BindTyped(ex, nonLocalReturnControl)
    val rhs = If(
        // Reference equality: each method's key is a distinct freshly allocated Object.
        ref(ex).select(nme.key).appliedToNone.select(nme.eq).appliedTo(ref(key)),
        ref(ex).select(nme.value).ensureConforms(meth.info.finalResultType),
        Throw(ref(ex)))
    val catches = CaseDef(pat, EmptyTree, rhs) :: Nil
    val tryCatch = Try(body, catches, EmptyTree)
    Block(keyDef :: Nil, tryCatch)
  }

  /** If some `return` targeting this method was rewritten to a throw, wrap the
   *  method body in the catching `try` and drop the key from the pending map.
   *
   *  NOTE(review): this relies on the returns inside the body being transformed
   *  before the enclosing `DefDef`. That is presumably guaranteed by the
   *  mini-phase traversal order — confirm.
   */
  override def transformDefDef(tree: DefDef)(implicit ctx: Context, info: TransformerInfo): Tree =
    nonLocalReturnKeys.remove(tree.symbol) match {
      case Some(key) => cpy.DefDef(tree)(rhs = nonLocalReturnTry(tree.rhs, key, tree.symbol))
      case _ => tree
    }

  /** Replace a non-local `return` with a `NonLocalReturnControl` throw that
   *  keeps the original position. Local returns are left untouched.
   */
  override def transformReturn(tree: Return)(implicit ctx: Context, info: TransformerInfo): Tree =
    if (isNonLocalReturn(tree)) nonLocalReturnThrow(tree.expr, tree.from.symbol).withPos(tree.pos)
    else tree
}