aboutsummaryrefslogtreecommitdiff
path: root/src/async/library/scala/async/Async.scala
blob: 2c81bc3cce1d05791bcbf86f8e3944ab4aa65d98 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
/**
 * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
 */
package scala.async

import language.experimental.macros

import scala.reflect.macros.Context
import scala.collection.mutable.ListBuffer
import scala.concurrent.{ Future, Promise, ExecutionContext, future }
import ExecutionContext.Implicits.global
import scala.util.control.NonFatal
import scala.util.continuations.{ shift, reset, cpsParam }

/**
 * Signals that the state-machine translation of an `async` body must be
 * abandoned in favour of the continuations-based (CPS) fallback.
 * `Async.asyncImpl` catches it and switches to rewriting `await` into `awaitCps`.
 *
 * As a `ControlThrowable` it is not matched by `NonFatal`, and by default it
 * does not fill in a stack trace, which keeps throwing it cheap.
 */
class FallbackToCpsException() extends scala.util.control.ControlThrowable

/**
 * Entry points of the `async`/`await` macro library.
 *
 * `async` is a macro implemented by `asyncImpl`, which rewrites the body at
 * compile time. Calls to `await` inside that body are located by the macro
 * (through `awaitSym`) and replaced. They are not meant to run as ordinary calls.
 *
 * NOTE(review): the generated `future { ... }` blocks and the `onComplete`
 * callback in `awaitCps` use the global `ExecutionContext` imported at the top
 * of this file. Callers cannot supply their own.
 *
 * @author Philipp Haller
 */
object Async extends AsyncUtils {

  /**
   * Returns a future of the value computed by `body`.
   *
   * This is a macro; `asyncImpl` transforms `body` at compile time.
   *
   * @tparam T the result type of `body`
   * @param body the computation, which may contain `await` calls
   * @return a future completed with the value of `body`
   */
  def async[T](body: T): Future[T] = macro asyncImpl[T]

  /**
   * Marks a point inside an `async` block where the result of `awaitable` is needed.
   *
   * Only meaningful inside `async`, whose macro rewrites every call. Invoking it
   * directly throws `NotImplementedError` (its body is `???`).
   */
  def await[T](awaitable: Future[T]): T = ???

  /**
   * Fallback for `await` when it is called at a position the state-machine
   * translation does not support. `asyncImpl` substitutes calls to this method
   * for calls to `await` and runs the body under `reset`.
   *
   * Captures the remainder of the computation, up to the enclosing `reset`, as
   * `k`. Once `awaitable` completes, `p` is completed with `k` applied to the
   * awaited value. If `awaitable` fails, or `k` throws a non-fatal exception,
   * `p` is failed with that exception rather than being left uncompleted.
   *
   * @param awaitable the future whose value is needed
   * @param p the promise that receives the final result of the `async` body
   */
  def awaitCps[T, U](awaitable: Future[T], p: Promise[U]): T @cpsParam[U, Unit] =
    shift {
      (k: (T => U)) =>
        awaitable onComplete {
          // `Try.map` keeps a failure unchanged and turns a non-fatal exception
          // thrown by `k` into a failure. `p` is therefore always completed,
          // where `tr.get` would have thrown inside the callback and left it pending.
          case tr => p.complete(tr.map(k))
        }
    }

  /**
   * Macro implementation of `async`.
   *
   * If the body is a `Block`, it is translated into a state machine built by
   * `ExprBuilder`. The generated code:
   *  - declares a `Promise` named `result` and a `var state = 0`;
   *  - defines a local `resume()` that applies the combined state handlers to
   *    `state` and fails `result` on any non-fatal exception;
   *  - calls `resume()` inside `future { ... }` and returns `result.future`.
   *
   * If a `FallbackToCpsException` is thrown while building the state machine
   * (presumably by `ExprBuilder` when `await` occurs at an unsupported
   * position; the throw site is not visible here), the macro falls back to CPS.
   * Every `await` call is rewritten to `awaitCps(..., p)` and the body runs
   * under `reset`.
   *
   * A body that is not a `Block` expands to code that throws at run time via
   * `sys.error`. It does not produce a compile-time error.
   */
  def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = {
    import c.universe._
    import Flag._
    
    val builder = new ExprBuilder[c.type](c)
    // Symbol of `Async.await`, used by the CPS fallback to recognise its call sites.
    val awaitMethod = awaitSym(c)
    
    try {
      body.tree match {
        case Block(stats, expr) =>
          // NOTE(review): the meaning of the numeric arguments and the empty Map
          // is defined by ExprBuilder.AsyncBlockBuilder, which is not visible here.
          val asyncBlockBuilder = new builder.AsyncBlockBuilder(stats, expr, 0, 1000, 1000, Map())

          vprintln(s"states of current method:")
          asyncBlockBuilder.asyncStates foreach vprintln

          val handlerExpr = asyncBlockBuilder.mkCombinedHandlerExpr()

          vprintln(s"GENERATED handler expr:")
          vprintln(handlerExpr)

          // The last state completes `result` with its body's value. `result` is
          // referenced by name and binds to the `val result` in the reified code below.
          val handlerForLastState: c.Expr[PartialFunction[Int, Unit]] = {
            val tree = Apply(Select(Ident("result"), newTermName("success")),
              List(asyncBlockBuilder.asyncStates.last.body))
            builder.mkHandler(asyncBlockBuilder.asyncStates.last.state, c.Expr[Unit](tree))
          }

          vprintln("GENERATED handler for last state:")
          vprintln(handlerForLastState)

          // Variable definitions from every state except the last. They are emitted
          // in the same block as `resume()` (see the reify below) so the handlers
          // can refer to them.
          val localVarTrees = asyncBlockBuilder.asyncStates.init.flatMap(_.allVarDefs).toList
          
          /*
            def resume(): Unit = {
              try {
                (handlerExpr.splice orElse handlerForLastState.splice)(state)
              } catch {
                case NonFatal(t) => result.failure(t)
              }
            }
           */
          val resumeFunTree: c.Tree = DefDef(Modifiers(), newTermName("resume"), List(), List(List()), Ident(definitions.UnitClass),
            Try(Apply(Select(
              Apply(Select(handlerExpr.tree, newTermName("orElse")), List(handlerForLastState.tree)),
              newTermName("apply")), List(Ident(newTermName("state")))),
              List(
                CaseDef(
                  Apply(Select(Select(Select(Ident(newTermName("scala")), newTermName("util")), newTermName("control")), newTermName("NonFatal")), List(Bind(newTermName("t"), Ident(nme.WILDCARD)))),
                  EmptyTree,
                  Block(List(
                    Apply(Select(Ident(newTermName("result")), newTermName("failure")), List(Ident(newTermName("t"))))),
                    Literal(Constant(()))))), EmptyTree))
          
          // `result` and `state` must keep these exact names. The generated trees
          // above refer to them through untyped `Ident`s.
          reify {
            val result = Promise[T]()
            var state = 0
            future {
              c.Expr(Block(
                localVarTrees :+ resumeFunTree,
                Apply(Ident(newTermName("resume")), List()))).splice
            }
            result.future
          }

        case _ =>
          // issue error message
          // NOTE(review): this only fails at run time, when the expansion executes.
          // `c.abort(c.enclosingPosition, ...)` would report it at compile time;
          // left unchanged to avoid breaking code that currently compiles.
          reify {
            sys.error("expression not supported by async")
          }
      }
    } catch {
      case _: FallbackToCpsException =>
        // replace `await` invocations with `awaitCps` invocations
        val awaitReplacer = new Transformer {
          val awaitCpsMethod = awaitCpsSym(c)
          override def transform(tree: Tree): Tree = tree match {
            case Apply(fun @ TypeApply(_, List(futArgTpt)), args) if fun.symbol == awaitMethod =>
              // awaitCps[<awaited type>, <type of the whole async body>](args..., p)
              val typeApp = treeCopy.TypeApply(fun, Ident(awaitCpsMethod), List(TypeTree(futArgTpt.tpe), TypeTree(body.tree.tpe)))
              treeCopy.Apply(tree, typeApp, args.map(arg => c.resetAllAttrs(arg.duplicate)) :+ Ident(newTermName("p")))
              
            case _ =>
              super.transform(tree)
          }
        }
        
        val newBody = awaitReplacer.transform(body.tree)
        
        // `p` must keep this exact name: the rewritten `awaitCps` calls refer to it
        // through an untyped `Ident`.
        reify {
          val p = Promise[T]()
          future {
            // Code up to the first suspension runs synchronously inside `reset`.
            // Without this handler an exception there would only fail the
            // discarded `future`, and `p` would never be completed.
            try {
              reset {
                c.Expr(c.resetAllAttrs(newBody.duplicate)).asInstanceOf[c.Expr[T]].splice
              }
            } catch {
              case NonFatal(t) => p.tryFailure(t)
            }
          }
          p.future
        }
    }
  }

}