aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/scala/async/Async.scala
blob: c480b627c9330391743d8e95c99c27334052acb8 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
/*
 * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
 */

package scala.async

import scala.language.experimental.macros
import scala.reflect.macros.Context
import scala.reflect.internal.annotations.compileTimeOnly

/**
 * The `async` macro bound to `scala.concurrent.Future`, using
 * `ScalaConcurrentFutureSystem` as the future system.
 */
object Async extends AsyncBase {

  import scala.concurrent.Future

  lazy val futureSystem = ScalaConcurrentFutureSystem
  type FS = ScalaConcurrentFutureSystem.type

  /**
   * Macro entry point; expansion is performed by [[asyncImpl]].
   *
   * The result type is declared explicitly rather than inferred from the
   * macro implementation, so the public signature is visible at the
   * declaration site.
   *
   * @param body the block of code to transform.
   * @tparam T   the type of the value computed by `body`.
   * @return     a future of the value computed by `body`.
   */
  def async[T](body: T): Future[T] = macro asyncImpl[T]

  // Forwarder to the inherited macro implementation, with the result type narrowed to
  // `Future[T]`. The enclosing base class documents this forwarder as a workaround for SI-6650.
  override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = super.asyncImpl[T](c)(body)
}

/**
 * The `async` macro bound to `IdentityFutureSystem`: the macro's result type
 * is the plain value type `T` rather than a future wrapper.
 */
object AsyncId extends AsyncBase {
  lazy val futureSystem = IdentityFutureSystem
  type FS = IdentityFutureSystem.type

  /**
   * Macro entry point; expansion is performed by [[asyncImpl]].
   *
   * The result type is declared explicitly rather than inferred from the
   * macro implementation.
   *
   * @param body the block of code to transform.
   * @tparam T   the type of the value computed by `body`.
   * @return     the value computed by `body`.
   */
  def async[T](body: T): T = macro asyncImpl[T]

  // Forwarder to the inherited macro implementation, with the result type narrowed to `T`.
  // The enclosing base class documents this forwarder as a workaround for SI-6650.
  override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[T] = super.asyncImpl[T](c)(body)
}

/**
 * A base class for the `async` macro. Subclasses must provide:
 *
 * - Concrete types for a given future system
 * - Tree manipulations to create and complete the equivalent of Future and Promise
 * in that system.
 * - The `async` macro declaration itself, and a forwarder for the macro implementation.
 * (The latter is temporarily needed to work around bug SI-6650 in the macro system)
 *
 * The default implementation, [[scala.async.Async]], binds the macro to `scala.concurrent._`.
 */
abstract class AsyncBase {
  self =>

  /** The type of the future system this macro is bound to; fixed by each concrete subclass. */
  type FS <: FutureSystem

  /** The future system whose `Fut`/`Prom` types and tree operations are used by the expansion. */
  val futureSystem: FS

  /**
   * A call to `await` must be nested in an enclosing `async` block.
   *
   * A call to `await` does not block the current thread, rather it is a delimiter
   * used by the enclosing `async` macro. Code following the `await`
   * call is executed asynchronously, when the argument of `await` has been completed.
   *
   * @param awaitable the future from which a value is awaited.
   * @tparam T        the type of that value.
   * @return          the value.
   */
  @compileTimeOnly("`await` must be enclosed in an `async` block")
  def await[T](awaitable: futureSystem.Fut[T]): T = ???

  // Defaults to `false`.
  // NOTE(review): not read anywhere in this class; presumably consulted by a collaborator
  // that receives this instance (e.g. the analyzer) -- confirm.
  protected[async] def fallbackEnabled: Boolean = false

  /**
   * The macro implementation behind `async`.
   *
   * Steps, in order:
   *  1. report unsupported `await` calls in `body`;
   *  2. transform `body` to A-normal form and re-typecheck it;
   *  3. compute fresh names for definitions used in subsequent states;
   *  4. build the async states from the ANF tree;
   *  5. assemble a state machine class and the code that instantiates and spawns it
   *     (or, when there is only a single state, just spawn `body` directly).
   *
   * @param c    the macro context.
   * @param body the expression passed to `async`.
   * @tparam T   the type of the value computed by `body`.
   * @return     an expression of the future system's future type for `T`.
   */
  def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = {
    import c.universe._

    val analyzer = AsyncAnalysis[c.type](c, this)
    val utils = TransformUtils[c.type](c)
    import utils.{name, defn}

    analyzer.reportUnsupportedAwaits(body.tree)

    // Transform to A-normal form:
    //  - no await calls in qualifiers or arguments,
    //  - if/match only used in statement position.
    val anfTree: Block = {
      val anf = AnfTransform[c.type](c)
      val restored = utils.restorePatternMatchingFunctions(body.tree)
      // The last tree produced by the transform becomes the block's result expression.
      val stats1 :+ expr1 = anf(restored)
      val block = Block(stats1, expr1)
      c.typeCheck(block).asInstanceOf[Block]
    }

    // Analyze the block to find locals that will be accessed from multiple
    // states of our generated state machine, e.g. a value assigned before
    // an `await` and read afterwards.
    val renameMap: Map[Symbol, TermName] = {
      analyzer.defTreesUsedInSubsequentStates(anfTree).map {
        vd =>
          (vd.symbol, name.fresh(vd.name.toTermName))
      }.toMap
    }

    val builder = ExprBuilder[c.type, futureSystem.type](c, self.futureSystem, anfTree)
    import builder.futureSystemOps
    val asyncBlock: builder.AsyncBlock = builder.build(anfTree, renameMap)
    import asyncBlock.asyncStates
    logDiagnostics(c)(anfTree, asyncStates.map(_.toString))

    // Members of the state machine class that stand in for the renamed definitions:
    // a var definition per renamed val, and a renamed copy of each renamed def.
    // Important to retain the original declaration order here!
    val localVarTrees = anfTree.collect {
      case vd@ValDef(_, _, tpt, _) if renameMap contains vd.symbol                            =>
        utils.mkVarDefTree(tpt.tpe, renameMap(vd.symbol))
      case dd@DefDef(mods, name, tparams, vparamss, tpt, rhs) if renameMap contains dd.symbol =>
        DefDef(mods, renameMap(dd.symbol), tparams, vparamss, tpt, c.resetAllAttrs(utils.substituteNames(rhs, renameMap)))
    }

    val resumeFunTree = asyncBlock.resumeFunTree[T]

    val stateMachineType = futureSystemOps.stateMachineType[T]

    // Lazy: only forced on the non-simple path below.
    lazy val stateMachine: ClassDef = {
      val body: List[Tree] = {
        // Mutable Int field holding the current state, initialised to 0.
        val stateVar = ValDef(Modifiers(Flag.MUTABLE), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0)))
        // Field holding the future system's promise for the overall result.
        val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T]), futureSystemOps.createProm[T].tree)
        // Field holding the execution context; its type is left for the typer to infer.
        val execContext = ValDef(NoMods, name.execContext, TypeTree(), futureSystemOps.execContext.tree)
        // One-argument `apply`: takes a value of the future system's result type and
        // runs the on-complete handler.
        val applyDefDef: DefDef = {
          val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(futureSystemOps.resultType[Any]), EmptyTree)))
          val applyBody = asyncBlock.onCompleteHandler
          DefDef(NoMods, name.apply, Nil, applyVParamss, TypeTree(definitions.UnitTpe), applyBody)
        }
        // Zero-argument `apply`: simply calls the resume method.
        val apply0DefDef: DefDef = {
          // We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`.
          // See SI-1247 for the optimization that avoids creating an extra thunk for such an argument.
          DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.resume), Nil))
        }
        List(utils.emptyConstructor, stateVar, result, execContext) ++ localVarTrees ++ List(resumeFunTree, applyDefDef, apply0DefDef)
      }
      val template = {
        Template(List(TypeTree(stateMachineType)), emptyValDef, body)
      }
      ClassDef(NoMods, name.stateMachineT, Nil, template)
    }

    // Selects a member of the local value that holds the state machine instance.
    def selectStateMachine(selection: TermName) = Select(Ident(name.stateMachine), selection)

    val code: c.Expr[futureSystem.Fut[T]] = {
      val isSimple = asyncStates.size == 1
      val tree =
        if (isSimple)
          Block(Nil, futureSystemOps.spawn(body.tree)) // generate lean code for the simple case of `async { 1 + 1 }`
        else {
          // Define the class, instantiate it, spawn its zero-argument `apply`,
          // and yield the future obtained from its result promise.
          Block(List[Tree](
            stateMachine,
            ValDef(NoMods, name.stateMachine, TypeTree(stateMachineType), Apply(Select(New(Ident(name.stateMachineT)), nme.CONSTRUCTOR), Nil)),
            futureSystemOps.spawn(Apply(selectStateMachine(name.apply), Nil))
          ),
          futureSystemOps.promiseToFuture(c.Expr[futureSystem.Prom[T]](selectStateMachine(name.result))).tree)
        }
      c.Expr[futureSystem.Fut[T]](tree)
    }

    AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code.tree}")
    code
  }

  /**
   * Emits diagnostics about a macro expansion through `AsyncUtils.vprintln`:
   * the location of the macro application, the application itself, the ANF tree,
   * and one line per state.
   *
   * @param c       the macro context.
   * @param anfTree the A-normal form of the `async` body.
   * @param states  string renderings of the async states.
   */
  def logDiagnostics(c: Context)(anfTree: c.Tree, states: Seq[String]): Unit = {
    // Prefer the source file path; fall back to the position's string form when the
    // position does not support that operation.
    def location = try {
      c.macroApplication.pos.source.path
    } catch {
      case _: UnsupportedOperationException =>
        c.macroApplication.pos.toString
    }

    AsyncUtils.vprintln(s"In file '$location':")
    AsyncUtils.vprintln(s"${c.macroApplication}")
    AsyncUtils.vprintln(s"ANF transform expands to:\n $anfTree")
    states foreach (s => AsyncUtils.vprintln(s))
  }
}

/**
 * Internal class used by the `async` macro; client code should not extend it manually.
 *
 * An instance is both a one-argument function from `scala.util.Try[Any]` to `Unit`
 * and a zero-argument function returning `Unit`.
 *
 * @tparam Result the type returned by `result$async`.
 * @tparam EC     the type returned by `execContext$async`.
 */
abstract class StateMachine[Result, EC]
  extends Function1[scala.util.Try[Any], Unit]
  with Function0[Unit] {

  /** The execution context associated with this state machine. */
  def execContext$async: EC

  /** The result holder associated with this state machine. */
  def result$async: Result
}