1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
|
package scala.async
import scala.reflect.macros.Context
/**
 * A-normal-form (ANF) rewriting used by the async macro.
 *
 * Rewrites a typed tree so that `await` calls are lifted out of nested
 * expression positions into statements of their own. The work is split
 * between two mutually recursive objects:
 *
 *  - [[inline]] turns a tree into a list of statements whose last element is
 *    the result expression, and additionally binds `await` results and
 *    non-`Unit` `if`/`match` results to fresh local vals/vars;
 *  - [[anf]] descends structurally into trees that contain an `await`,
 *    delegating sub-expressions to `inline`.
 *
 * @param c the macro context whose universe the trees belong to
 */
class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) {
  import c.universe._

  object inline {
    /**
     * Transforms `tree` into a non-empty list of statements whose last element
     * is the expression producing the value of `tree`.
     *
     * After `anf` has run, the final expression is simplified further:
     *  - an `await(...)` call is bound to a fresh `val` (prefix `await`) and
     *    replaced by a reference to it;
     *  - a non-`Unit` `if`/`match` assigns each branch's result to a fresh
     *    `var` (prefix `ifres` / `matchres`), which is then referenced;
     *  - a `Unit`-typed `if`/`match` is kept as a statement and followed by a
     *    `()` literal.
     */
    def transformToList(tree: Tree): List[Tree] = {
      val stats :+ expr = anf.transformToList(tree)
      expr match {
        case Apply(fun, args) if isAwait(fun) =>
          val valDef = defineVal("await", expr)
          stats :+ valDef :+ Ident(valDef.name)

        case If(cond, thenp, elsep) =>
          // if type of if-else is Unit don't introduce assignment,
          // but add Unit value to bring it into form expected by async transform
          if (expr.tpe =:= definitions.UnitTpe) {
            stats :+ expr :+ Literal(Constant(()))
          } else {
            val varDef = defineVar("ifres", expr.tpe)
            val ifWithAssign = If(cond, assignResultTo(varDef, thenp), assignResultTo(varDef, elsep))
            stats :+ varDef :+ ifWithAssign :+ Ident(varDef.name)
          }

        case Match(scrut, cases) =>
          // if type of match is Unit don't introduce assignment,
          // but add Unit value to bring it into form expected by async transform
          if (expr.tpe =:= definitions.UnitTpe) {
            stats :+ expr :+ Literal(Constant(()))
          } else {
            val varDef = defineVar("matchres", expr.tpe)
            val casesWithAssign = cases map {
              case CaseDef(pat, guard, body) => CaseDef(pat, guard, assignResultTo(varDef, body))
            }
            val matchWithAssign = Match(scrut, casesWithAssign)
            stats :+ varDef :+ matchWithAssign :+ Ident(varDef.name)
          }

        case _ =>
          stats :+ expr
      }
    }

    /**
     * Transforms each tree in order and concatenates the resulting statement
     * lists. Uses `flatMap` instead of one recursive call per element, so long
     * statement lists do not grow the stack. Trees are still processed left to
     * right, so fresh names are generated in the same order as before.
     */
    def transformToList(trees: List[Tree]): List[Tree] =
      trees flatMap ((t: Tree) => transformToList(t))

    /** Like `transformToList`, but packages the result as a `Block(stats, expr)`. */
    def transformToBlock(tree: Tree): Block = {
      // transformToList never yields an empty list, so this pattern always matches
      val stats :+ expr = transformToList(tree)
      Block(stats, expr)
    }

    def liftedName(prefix: String) = c.fresh(prefix + "$")

    /**
     * Rewrites an `if` branch or `case` body so that its value is assigned to
     * the variable declared by `varDef`. For a `Block`, only the final
     * expression is wrapped and the leading statements stay as they are.
     */
    private def assignResultTo(varDef: ValDef, branch: Tree): Tree = branch match {
      case Block(branchStats, branchExpr) => Block(branchStats, Assign(Ident(varDef.name), branchExpr))
      case _                              => Assign(Ident(varDef.name), branch)
    }

    /** A mutable local of type `tp`, initialised with `defaultValue(tp)`. */
    private def defineVar(prefix: String, tp: Type): ValDef =
      ValDef(Modifiers(Flag.MUTABLE), liftedName(prefix), TypeTree(tp), defaultValue(tp))

    /** An immutable local bound to `lhs`; its type tree is left empty. */
    private def defineVal(prefix: String, lhs: Tree): ValDef =
      ValDef(NoMods, liftedName(prefix), TypeTree(), lhs)
  }

  object anf {
    /**
     * Rewrites `tree` into a list of statements whose last element computes
     * the value of `tree`.
     *
     * When the tree contains an `await`, its sub-expressions in `Select`,
     * `Apply`, `TypeApply`, `ValDef`, `Assign`, `If` and `Match` positions are
     * lifted out via `inline`. Blocks are always flattened, even without an
     * `await`. Any other tree is returned unchanged as a single-element list.
     */
    def transformToList(tree: Tree): List[Tree] = {
      def containsAwait = tree exists isAwait
      tree match {
        case Select(qual, sel) if containsAwait =>
          val stats :+ expr = inline.transformToList(qual)
          stats :+ Select(expr, sel).setSymbol(tree.symbol)

        case Apply(fun, args) if containsAwait =>
          // we can assume that no await call appears in a by-name argument position,
          // this has already been checked.
          // NOTE(review): arguments (and a function qualifier) without an await that
          // `inline` leaves as plain expressions (e.g. a call `g()`) stay inside the
          // rebuilt Apply, so they now run after any await lifted out of a later
          // argument. Confirm this reordering is acceptable for side-effecting args.
          val funStats :+ simpleFun = inline.transformToList(fun)
          val argLists = args map inline.transformToList
          val allArgStats = argLists flatMap (_.init)
          val simpleArgs = argLists map (_.last)
          funStats ++ allArgStats :+ Apply(simpleFun, simpleArgs).setSymbol(tree.symbol)

        case Block(stats, expr) => // TODO figure out why adding a guard `if containsAwait` breaks LocalClasses0Spec.
          inline.transformToList(stats :+ expr)

        case ValDef(mods, name, tpt, rhs) if containsAwait =>
          if (rhs exists isAwait) {
            val stats :+ expr = inline.transformToList(rhs)
            stats :+ ValDef(mods, name, tpt, expr).setSymbol(tree.symbol)
          } else List(tree)

        case Assign(lhs, rhs) if containsAwait =>
          val stats :+ expr = inline.transformToList(rhs)
          stats :+ Assign(lhs, expr)

        case If(cond, thenp, elsep) if containsAwait =>
          val stats :+ expr = inline.transformToList(cond)
          val thenBlock = inline.transformToBlock(thenp)
          val elseBlock = inline.transformToBlock(elsep)
          stats :+ c.typeCheck(If(expr, thenBlock, elseBlock))

        case Match(scrut, cases) if containsAwait =>
          val scrutStats :+ scrutExpr = inline.transformToList(scrut)
          val caseDefs = cases map {
            case CaseDef(pat, guard, body) =>
              val block = inline.transformToBlock(body)
              CaseDef(pat, guard, block)
          }
          scrutStats :+ c.typeCheck(Match(scrutExpr, caseDefs))

        case TypeApply(fun, targs) if containsAwait =>
          val funStats :+ simpleFun = inline.transformToList(fun)
          funStats :+ TypeApply(simpleFun, targs).setSymbol(tree.symbol)

        case _ =>
          List(tree)
      }
    }
  }
}
|