package scala.async import scala.reflect.macros.Context class AnfTransform[C <: Context](val c: C) { import c.universe._ import AsyncUtils._ object inline { //TODO: DRY private def defaultValue(tpe: Type): Literal = { val defaultValue: Any = if (tpe <:< definitions.BooleanTpe) false else if (definitions.ScalaNumericValueClasses.exists(tpe <:< _.toType)) 0 else if (tpe <:< definitions.AnyValTpe) 0 else null Literal(Constant(defaultValue)) } def transformToList(tree: Tree): List[Tree] = { val stats :+ expr = anf.transformToList(tree) expr match { case Apply(fun, args) if { vprintln("check fun.toString: " + fun.toString); fun.toString.startsWith("scala.async.Async.await") } => vprintln("found await!!") val liftedName = c.fresh("await$") stats :+ ValDef(NoMods, liftedName, TypeTree(), expr) :+ Ident(liftedName) case If(cond, thenp, elsep) => val liftedName = c.fresh("ifres$") val varDef = ValDef(Modifiers(Flag.MUTABLE), liftedName, TypeTree(expr.tpe), defaultValue(expr.tpe)) val thenWithAssign = thenp match { case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(liftedName), thenExpr)) case _ => Assign(Ident(liftedName), thenp) } val elseWithAssign = elsep match { case Block(elseStats, elseExpr) => Block(elseStats, Assign(Ident(liftedName), elseExpr)) case _ => Assign(Ident(liftedName), elsep) } val ifWithAssign = If(cond, thenWithAssign, elseWithAssign) stats :+ varDef :+ ifWithAssign :+ Ident(liftedName) case _ => vprintln("found something else") stats :+ expr } } def transformToList(trees: List[Tree]): List[Tree] = trees match { case fst :: rest => transformToList(fst) ++ transformToList(rest) case Nil => Nil } } object anf { def transformToList(tree: Tree): List[Tree] = tree match { case Select(qual, sel) => val stats :+ expr = inline.transformToList(qual) stats :+ Select(expr, sel) case Apply(fun, args) => val funStats :+ simpleFun = inline.transformToList(fun) val argLists = args map inline.transformToList val allArgStats = argLists flatMap (_.init) val simpleArgs = argLists map (_.last) funStats ++ allArgStats :+ Apply(simpleFun, simpleArgs) case Block(stats, expr) => inline.transformToList(stats) ++ inline.transformToList(expr) case ValDef(mods, name, tpt, rhs) => val stats :+ expr = inline.transformToList(rhs) stats :+ ValDef(mods, name, tpt, expr) case Assign(name, rhs) => val stats :+ expr = inline.transformToList(rhs) stats :+ Assign(name, expr) case If(cond, thenp, elsep) => val stats :+ expr = inline.transformToList(cond) val thenStats :+ thenExpr = inline.transformToList(thenp) val elseStats :+ elseExpr = inline.transformToList(elsep) stats :+ c.typeCheck(If(expr, Block(thenStats, thenExpr), Block(elseStats, elseExpr)), lub(List(thenp.tpe, elsep.tpe))) //TODO case Literal(_) | Ident(_) | This(_) | Match(_, _) | New(_) | Function(_, _) => List(tree) case TypeApply(fun, targs) => val funStats :+ simpleFun = inline.transformToList(fun) funStats :+ TypeApply(simpleFun, targs) //TODO case DefDef(mods, name, tparams, vparamss, tpt, rhs) => List(tree) case ClassDef(mods, name, tparams, impl) => List(tree) case ModuleDef(mods, name, impl) => List(tree) case _ => println("do not handle tree "+tree) ??? } } }