package scala.async
import scala.reflect.macros.Context
/**
 * Flattens typed macro-argument trees into a list of statements whose last element is
 * the resulting expression. `await` calls and `if` expressions in result position are
 * hoisted into fresh local definitions (the class name suggests the target is
 * A-normal form).
 *
 * @param c the macro context whose universe the transformed trees belong to
 */
class AnfTransform[C <: Context](val c: C) {
  import c.universe._
  import AsyncUtils._

  /**
   * Runs [[anf]] on a tree and then makes its final expression "simple": an `await`
   * call is bound to a fresh `val`, and an `if` is rewritten to assign its result to a
   * fresh `var`. In both cases the result is replaced by a reference to that local.
   */
  object inline {
    //TODO: DRY
    /**
     * Initial value for the fresh `var` that receives an `if` expression's result
     * before either branch assigns it. The value is `false` for Boolean subtypes, `0`
     * for numeric value classes, `()` for Unit, `0` for any other value class and
     * `null` otherwise.
     */
    private def defaultValue(tpe: Type): Literal = {
      val value: Any =
        if (tpe <:< definitions.BooleanTpe) false
        else if (definitions.ScalaNumericValueClasses.exists(tpe <:< _.toType)) 0
        // A Unit-typed var should start at `()`, not at an Int literal that the
        // typer would otherwise have to discard.
        else if (tpe <:< definitions.UnitTpe) ()
        else if (tpe <:< definitions.AnyValTpe) 0
        else null
      Literal(Constant(value))
    }

    /**
     * Reports whether `fun` is a call target whose printed form starts with
     * `scala.async.Async.await`. The check is string-based, and the tree is logged via
     * `vprintln` first.
     */
    private def isAwait(fun: Tree): Boolean = {
      vprintln("check fun.toString: " + fun.toString)
      fun.toString.startsWith("scala.async.Async.await")
    }

    /**
     * Transforms `tree` with [[anf]] and lifts the final expression when it is an
     * `await` call or an `if`.
     *
     * @return preceding statements followed by a simple result expression
     */
    def transformToList(tree: Tree): List[Tree] = {
      val stats :+ expr = anf.transformToList(tree)
      expr match {
        case Apply(fun, args) if isAwait(fun) =>
          vprintln("found await!!")
          val liftedName = c.fresh("await$")
          stats :+ ValDef(NoMods, liftedName, TypeTree(), expr) :+ Ident(liftedName)

        case If(cond, thenp, elsep) =>
          val liftedName = c.fresh("ifres$")
          val varDef =
            ValDef(Modifiers(Flag.MUTABLE), liftedName, TypeTree(expr.tpe), defaultValue(expr.tpe))
          // Make a branch store its result in the lifted var. For a block, only the
          // block's final expression is assigned.
          def assignResult(branch: Tree): Tree = branch match {
            case Block(branchStats, branchExpr) =>
              Block(branchStats, Assign(Ident(liftedName), branchExpr))
            case _ =>
              Assign(Ident(liftedName), branch)
          }
          val ifWithAssign = If(cond, assignResult(thenp), assignResult(elsep))
          stats :+ varDef :+ ifWithAssign :+ Ident(liftedName)

        case _ =>
          vprintln("found something else")
          stats :+ expr
      }
    }

    /**
     * Transforms each tree in order and concatenates the results. `flatMap` keeps this
     * stack-safe for long statement lists, unlike the previous non-tail recursion.
     */
    def transformToList(trees: List[Tree]): List[Tree] =
      trees.flatMap(tree => transformToList(tree))
  }

  /**
   * Structural pass: recursively flattens sub-expressions (via [[inline]]) so that
   * each node's operands are simple. The statements they need are hoisted in front
   * of the rebuilt node.
   */
  object anf {
    /**
     * @return the hoisted statements followed by the rebuilt tree. The list is never
     *         empty for the handled shapes.
     */
    def transformToList(tree: Tree): List[Tree] = tree match {
      case Select(qual, sel) =>
        val stats :+ expr = inline.transformToList(qual)
        stats :+ Select(expr, sel)
      case Apply(fun, args) =>
        val funStats :+ simpleFun = inline.transformToList(fun)
        val argLists = args map inline.transformToList
        // Hoist every argument's statements (function first, then arguments left to
        // right), then apply the simplified function to the simplified arguments.
        val allArgStats = argLists flatMap (_.init)
        val simpleArgs = argLists map (_.last)
        funStats ++ allArgStats :+ Apply(simpleFun, simpleArgs)
      case Block(stats, expr) =>
        inline.transformToList(stats) ++ inline.transformToList(expr)
      case ValDef(mods, name, tpt, rhs) =>
        val stats :+ expr = inline.transformToList(rhs)
        stats :+ ValDef(mods, name, tpt, expr)
      case Assign(name, rhs) =>
        val stats :+ expr = inline.transformToList(rhs)
        stats :+ Assign(name, expr)
      case If(cond, thenp, elsep) =>
        val stats :+ expr = inline.transformToList(cond)
        val thenStats :+ thenExpr = inline.transformToList(thenp)
        val elseStats :+ elseExpr = inline.transformToList(elsep)
        // Re-typecheck the rebuilt `if` against the lub of the original branch types
        // so that `inline` can read `expr.tpe` when it lifts the result.
        stats :+
          c.typeCheck(If(expr, Block(thenStats, thenExpr), Block(elseStats, elseExpr)),
            lub(List(thenp.tpe, elsep.tpe)))
      //TODO: these shapes are currently passed through without transformation
      case Literal(_) | Ident(_) | This(_) | Match(_, _) | New(_) | Function(_, _) => List(tree)
      case TypeApply(fun, targs) =>
        val funStats :+ simpleFun = inline.transformToList(fun)
        funStats :+ TypeApply(simpleFun, targs)
      //TODO: nested definitions are currently passed through without transformation
      case DefDef(mods, name, tparams, vparamss, tpt, rhs) => List(tree)
      case ClassDef(mods, name, tparams, impl) => List(tree)
      case ModuleDef(mods, name, impl) => List(tree)
      case _ =>
        // Unsupported tree shape: print it, then fail (`???` throws NotImplementedError).
        println("do not handle tree "+tree)
        ???
    }
  }
}