aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/scala/async/TransformUtils.scala
diff options
context:
space:
mode:
authorJason Zaugg <jzaugg@gmail.com>2012-11-22 17:50:50 +0100
committerJason Zaugg <jzaugg@gmail.com>2012-11-22 17:50:50 +0100
commit087d1e4e138eccf4b2d420298affb4289632bf73 (patch)
treefd0fc1c034f4cbc2d92fa7958c6b03c59e23aa92 /src/main/scala/scala/async/TransformUtils.scala
parent1c91fec998d09e31c2c52760452af1771a092182 (diff)
downloadscala-async-087d1e4e138eccf4b2d420298affb4289632bf73.tar.gz
scala-async-087d1e4e138eccf4b2d420298affb4289632bf73.tar.bz2
scala-async-087d1e4e138eccf4b2d420298affb4289632bf73.zip
Support match as an expression.
- corrects detection of await calls in the ANF transform. - Split AsyncAnalyzer into two parts. Unsupported await detection must happen prior to the async transform to prevent the ANF lifting out by-name arguments to vals and hence changing the semantics.
Diffstat (limited to 'src/main/scala/scala/async/TransformUtils.scala')
-rw-r--r--src/main/scala/scala/async/TransformUtils.scala98
1 files changed, 98 insertions, 0 deletions
diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala
index d36c277..b8b21a3 100644
--- a/src/main/scala/scala/async/TransformUtils.scala
+++ b/src/main/scala/scala/async/TransformUtils.scala
@@ -9,6 +9,7 @@ import scala.reflect.macros.Context
* Utilities used in both `ExprBuilder` and `AnfTransform`.
*/
class TransformUtils[C <: Context](val c: C) {
+
import c.universe._
protected def defaultValue(tpe: Type): Literal = {
@@ -19,4 +20,101 @@ class TransformUtils[C <: Context](val c: C) {
else null
Literal(Constant(defaultValue))
}
+
+ protected def isAwait(fun: Tree) =
+ fun.symbol == defn.Async_await
+
+ /** Descends into the regions of the tree that are subject to the
+ * translation to a state machine by `async`. When a nested template,
+ * function, or by-name argument is encountered, the descend stops,
+ * and `nestedClass` etc are invoked.
+ */
+ trait AsyncTraverser extends Traverser {
+ def nestedClass(classDef: ClassDef) {
+ }
+
+ def nestedModule(module: ModuleDef) {
+ }
+
+ def byNameArgument(arg: Tree) {
+ }
+
+ def function(function: Function) {
+ }
+
+ override def traverse(tree: Tree) {
+ tree match {
+ case cd: ClassDef => nestedClass(cd)
+ case md: ModuleDef => nestedModule(md)
+ case fun: Function => function(fun)
+ case Apply(fun, args) =>
+ val isInByName = isByName(fun)
+ for ((arg, index) <- args.zipWithIndex) {
+ if (!isInByName(index)) traverse(arg)
+ else byNameArgument(arg)
+ }
+ traverse(fun)
+ case _ => super.traverse(tree)
+ }
+ }
+ }
+
+ private lazy val Boolean_ShortCircuits: Set[Symbol] = {
+ import definitions.BooleanClass
+ def BooleanTermMember(name: String) = BooleanClass.typeSignature.member(newTermName(name).encodedName)
+ val Boolean_&& = BooleanTermMember("&&")
+ val Boolean_|| = BooleanTermMember("||")
+ Set(Boolean_&&, Boolean_||)
+ }
+
+ protected def isByName(fun: Tree): (Int => Boolean) = {
+ if (Boolean_ShortCircuits contains fun.symbol) i => true
+ else fun.tpe match {
+ case MethodType(params, _) =>
+ val isByNameParams = params.map(_.asTerm.isByNameParam)
+ (i: Int) => isByNameParams.applyOrElse(i, (_: Int) => false)
+ case _ => Map()
+ }
+ }
+
+ private[async] object defn {
+ def mkList_apply[A](args: List[Expr[A]]): Expr[List[A]] = {
+ c.Expr(Apply(Ident(definitions.List_apply), args.map(_.tree)))
+ }
+
+ def mkList_contains[A](self: Expr[List[A]])(elem: Expr[Any]) = reify(self.splice.contains(elem.splice))
+
+ def mkFunction_apply[A, B](self: Expr[Function1[A, B]])(arg: Expr[A]) = reify {
+ self.splice.apply(arg.splice)
+ }
+
+ def mkAny_==(self: Expr[Any])(other: Expr[Any]) = reify {
+ self.splice == other.splice
+ }
+
+ def mkTry_get[A](self: Expr[util.Try[A]]) = reify {
+ self.splice.get
+ }
+
+ val Try_get = methodSym(reify((null: scala.util.Try[Any]).get))
+
+ val TryClass = c.mirror.staticClass("scala.util.Try")
+ val TryAnyType = appliedType(TryClass.toType, List(definitions.AnyTpe))
+ val NonFatalClass = c.mirror.staticModule("scala.util.control.NonFatal")
+
+ val Async_await = {
+ val asyncMod = c.mirror.staticModule("scala.async.Async")
+ val tpe = asyncMod.moduleClass.asType.toType
+ tpe.member(c.universe.newTermName("await"))
+ }
+ }
+
+
+ /** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */
+ private def methodSym(apply: c.Expr[Any]): Symbol = {
+ val tree2: Tree = c.typeCheck(apply.tree)
+ tree2.collect {
+ case s: SymTree if s.symbol.isMethod => s.symbol
+ }.headOption.getOrElse(sys.error(s"Unable to find a method symbol in ${apply.tree}"))
+ }
}