diff options
Diffstat (limited to 'src/main/scala/scala/async/TransformUtils.scala')
-rw-r--r-- | src/main/scala/scala/async/TransformUtils.scala | 62 |
1 files changed, 56 insertions, 6 deletions
diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala index b79be87..8838bb3 100644 --- a/src/main/scala/scala/async/TransformUtils.scala +++ b/src/main/scala/scala/async/TransformUtils.scala @@ -1,4 +1,4 @@ -/** +/* * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> */ package scala.async @@ -9,11 +9,34 @@ import reflect.ClassTag /** * Utilities used in both `ExprBuilder` and `AnfTransform`. */ -class TransformUtils[C <: Context](val c: C) { +private[async] final case class TransformUtils[C <: Context](val c: C) { import c.universe._ - protected def defaultValue(tpe: Type): Literal = { + object name { + def suffix(string: String) = string + "$async" + + def suffixedName(prefix: String) = newTermName(suffix(prefix)) + + val state = suffixedName("state") + val result = suffixedName("result") + val resume = suffixedName("resume") + val execContext = suffixedName("execContext") + + // TODO do we need to freshen any of these? + val tr = newTermName("tr") + val onCompleteHandler = suffixedName("onCompleteHandler") + + val matchRes = "matchres" + val ifRes = "ifres" + val await = "await" + + def fresh(name: TermName): TermName = newTermName(fresh(name.toString)) + + def fresh(name: String): String = if (name.toString.contains("$")) name else c.fresh("" + name + "$") + } + + def defaultValue(tpe: Type): Literal = { val defaultValue: Any = if (tpe <:< definitions.BooleanTpe) false else if (definitions.ScalaNumericValueClasses.exists(tpe <:< _.toType)) 0 @@ -22,9 +45,22 @@ class TransformUtils[C <: Context](val c: C) { Literal(Constant(defaultValue)) } - protected def isAwait(fun: Tree) = + def isAwait(fun: Tree) = fun.symbol == defn.Async_await + /** Replace all `Ident` nodes referring to one of the keys n `renameMap` with a node + * referring to the corresponding new name + */ + def substituteNames(tree: Tree, renameMap: Map[Symbol, Name]): Tree = { + val renamer = new Transformer { + override def transform(tree: Tree) = tree match { + case Ident(_) => (renameMap get tree.symbol).fold(tree)(Ident(_)) + case _ => super.transform(tree) + } + } + renamer.transform(tree) + } + /** Descends into the regions of the tree that are subject to the * translation to a state machine by `async`. When a nested template, * function, or by-name argument is encountered, the descend stops, @@ -37,6 +73,9 @@ class TransformUtils[C <: Context](val c: C) { def nestedModule(module: ModuleDef) { } + def nestedMethod(module: DefDef) { + } + def byNameArgument(arg: Tree) { } @@ -47,6 +86,7 @@ class TransformUtils[C <: Context](val c: C) { tree match { case cd: ClassDef => nestedClass(cd) case md: ModuleDef => nestedModule(md) + case dd: DefDef => nestedMethod(dd) case fun: Function => function(fun) case Apply(fun, args) => val isInByName = isByName(fun) @@ -68,7 +108,7 @@ class TransformUtils[C <: Context](val c: C) { Set(Boolean_&&, Boolean_||) } - protected def isByName(fun: Tree): (Int => Boolean) = { + def isByName(fun: Tree): (Int => Boolean) = { if (Boolean_ShortCircuits contains fun.symbol) i => true else fun.tpe match { case MethodType(params, _) => @@ -78,7 +118,16 @@ class TransformUtils[C <: Context](val c: C) { } } - private[async] object defn { + def statsAndExpr(tree: Tree): (List[Tree], Tree) = tree match { + case Block(stats, expr) => (stats, expr) + case _ => (List(tree), Literal(Constant(()))) + } + + def mkVarDefTree(resultType: Type, resultName: TermName): c.Tree = { + ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), defaultValue(resultType)) + } + + object defn { def mkList_apply[A](args: List[Expr[A]]): Expr[List[A]] = { c.Expr(Apply(Ident(definitions.List_apply), args.map(_.tree))) } @@ -157,4 +206,5 @@ class TransformUtils[C <: Context](val c: C) { def ValDef(tree: Tree)(mods: Modifiers, name: TermName, tpt: Tree, rhs: Tree): ValDef = copyAttach(tree, c.universe.ValDef(mods, name, tpt, rhs)) } + } |