aboutsummaryrefslogblamecommitdiff
path: root/src/async/library/scala/async/Async.scala
blob: 60b9f44eaf2a60edb67de803c50028f4eec3df34 (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11
12
13












                                                             


                         
                                                          


                     
               
  






                                                                                    
  









                                                        





























                                                                                                                                                                                                                                                            
   
















                                              





                                                                                                   
















































































                                                                                                         

 
 








                                                       
                                                                                     

                       
                                            
                                 
    

                                

                                                                                 






                                                                                          
                                                          













                                                                                          
                                                                                                

                                                    
                                                                           





                                                                                                                                       
                                                                                        













                                                                                         


                                                                                      
                                                                     
         







































                                                                
/**
 * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
 */
package scala.async

import language.experimental.macros
import scala.reflect.macros.Context
import scala.reflect.runtime.universe

import scala.concurrent.{ Future, Promise }
import scala.util.control.NonFatal
import scala.collection.mutable.ListBuffer

/*
 * @author Philipp Haller
 */
class ExprBuilder[C <: Context with Singleton](val c: C) {
  builder =>

  import c.universe._
  import Flag._

  /* Make a partial function literal handling case #num:
   *
   *     {
   *       case any if any == num => rhs
   *     }
   *
   * `rhs` is duplicated and stripped of its attributes (symbols and types) so that it can be
   * typechecked again at the place where the returned handler is spliced. The tree itself is
   * assembled by hand in `mkHandlerTree`.
   *
   * @param num  the (only) state number the resulting partial function is defined at
   * @param rhs  the code to run for that state
   * @return     an untyped expression of type `PartialFunction[Int, Unit]`
   */
  def mkHandler(num: Int, rhs: c.Expr[Unit]): c.Expr[PartialFunction[Int, Unit]] = {
    val rhsTree = c.resetAllAttrs(rhs.tree.duplicate)
    val handlerTree = mkHandlerTree(num, rhsTree)
    c.Expr(handlerTree).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]]
  }

  /* Build, by hand, the tree of
   *
   *     {
   *       final class $anon extends PartialFunction[Int, Unit] {
   *         def <init>() = { super.<init>(); () }
   *         def isDefinedAt(x$1: Int) = x$1 == num
   *         def apply(x$1: Int) = x$1 match {
   *           case any @ (_: Int) if any == num => rhs
   *         }
   *       }
   *       new $anon()
   *     }
   *
   * `rhs` is inserted as-is, so callers should pass an attribute-free tree (see `mkHandler`).
   * `apply` has no fallback case; it is only meant to be invoked where `isDefinedAt` holds.
   */
  def mkHandlerTree(num: Int, rhs: c.Tree): c.Tree = {
    val partFunClass = c.mirror.staticClass("scala.PartialFunction")
    val partFunIdent = Ident(partFunClass)
    val intIdent = Ident(definitions.IntClass)
    val unitIdent = Ident(definitions.UnitClass)

    val anonName = newTypeName("$anon")
    val paramName = newTermName("x$1")
    val boundName = newTermName("any")

    // `<name> == num`
    def equalsNum(name: TermName): Tree =
      Apply(Select(Ident(name), newTermName("$eq$eq")), List(Literal(Constant(num))))

    // `(x$1: Int)`; a fresh ValDef for every method that declares it.
    def intParamList: List[List[ValDef]] =
      List(List(ValDef(Modifiers(PARAM), paramName, intIdent, EmptyTree)))

    val constructor =
      DefDef(Modifiers(), nme.CONSTRUCTOR, List(), List(List()), TypeTree(),
        Block(List(Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), List())), Literal(Constant(()))))

    val isDefinedAt =
      DefDef(Modifiers(), newTermName("isDefinedAt"), List(), intParamList, TypeTree(), equalsNum(paramName))

    val apply =
      DefDef(Modifiers(), newTermName("apply"), List(), intParamList, TypeTree(),
        Match(Ident(paramName), List(
          CaseDef(Bind(boundName, Typed(Ident(nme.WILDCARD), intIdent)), equalsNum(boundName), rhs)
        ))
      )

    val anonClass =
      ClassDef(Modifiers(FINAL), anonName, List(),
        Template(List(AppliedTypeTree(partFunIdent, List(intIdent, unitIdent))), emptyValDef,
          List(constructor, isDefinedAt, apply)))

    Block(List(anonClass), Apply(Select(New(Ident(anonName)), nme.CONSTRUCTOR), List()))
  }

  /* Accumulates one state of the generated state machine.
   *
   * A state consists of the statements preceding an `await` call. Once `complete` has been
   * called, it also holds the awaited expression and the name and type of the variable that
   * receives its result. Until then `awaitable` is `null`. `mkHandlerForState` requires a
   * completed state; `lastExprTree` requires an uncompleted one.
   */
  class AsyncStateBuilder {
    /* Statements preceding an await call. */
    private val stats = ListBuffer[c.Tree]()

    /* Argument of an await call. */
    var awaitable: c.Tree = null

    /* Result name of an await call. */
    var resultName: c.universe.TermName = null

    /* Result type of an await call. */
    var resultType: c.universe.Type = null

    /* Append a statement to this state. */
    def += (stat: c.Tree): Unit =
      stats += stat

    /* Mark this state as ending with `val awaitResultName: T = await(awaitArg)`.
     *
     * The result needs to be created as a var at the beginning of the transformed method body,
     * so that it is visible in subsequent states of the state machine (see `varDefForResult`).
     *
     * @param awaitArg         the argument of await (stored duplicated and attribute-free)
     * @param awaitResultName  the name of the variable that the result of await is assigned to
     * @param awaitResultType  the type tree of that variable; its `tpe` is recorded
     */
    def complete(awaitArg: c.Tree, awaitResultName: c.universe.TermName, awaitResultType: Tree): Unit = {
      awaitable = c.resetAllAttrs(awaitArg.duplicate)
      resultName = awaitResultName
      resultType = awaitResultType.tpe
    }

    override def toString: String = {
      val statsBeforeAwait = stats.mkString("\n")
      s"ASYNC STATE:\n$statsBeforeAwait \nawaitable: $awaitable \nresult name: $resultName"
    }

    /* Make an `onComplete` invocation:
     *
     *     awaitable.onComplete {
     *       case tr =>
     *         if (tr.isSuccess) {
     *           resultName = tr.get
     *           resume()
     *         } else
     *           result.failure(tr.failed.get)
     *     }
     *
     * `resume` and `result` are not defined here; they must be in scope where the tree ends up
     * (Async.asyncImpl defines both around the spliced handlers).
     */
    def mkOnCompleteTree: c.Tree = {
      val tr = newTermName("tr")
      val assignTree =
        Assign(
          Ident(resultName.toString),
          Select(Ident(tr), newTermName("get"))
        )
      val onSuccess = Block(assignTree, Apply(Ident("resume"), List()))
      // Calling `tr.get` on a failure would rethrow inside this callback. That is outside the
      // try/catch in `resume`, so the overall promise would never be completed and the async
      // future would hang. Propagate the failure to `result` explicitly instead.
      val onFailure =
        Apply(
          Select(Ident("result"), newTermName("failure")),
          List(Select(Select(Ident(tr), newTermName("failed")), newTermName("get")))
        )
      val handlerTree =
        Match(
          EmptyTree,
          List(
            CaseDef(Bind(tr, Ident(nme.WILDCARD)), EmptyTree,
              If(Select(Ident(tr), newTermName("isSuccess")), onSuccess, onFailure) // rhs of case
            )
          )
        )
      Apply(
        Select(awaitable, newTermName("onComplete")),
        List(handlerTree)
      )
    }

    /* Make a partial function literal handling case #num:
     *
     *     {
     *       case any if any == num =>
     *         stats
     *         awaitable.onComplete { ... }   // see mkOnCompleteTree
     *     }
     *
     * Requires that `complete` has been called on this state.
     */
    def mkHandlerForState(num: Int): c.Expr[PartialFunction[Int, Unit]] = {
      assert(awaitable != null)
      val nakedStats = stats.map(stat => c.resetAllAttrs(stat.duplicate))
      val block = Block((nakedStats :+ mkOnCompleteTree): _*)
      builder.mkHandler(num, c.Expr[Unit](block))
    }

    /* The statements of the final (non-awaiting) state as a single attribute-free tree:
     * the statement itself if there is exactly one, otherwise a `Block` of all of them.
     */
    def lastExprTree: c.Tree = {
      assert(awaitable == null)
      if (stats.size == 1)
        c.resetAllAttrs(stats(0).duplicate)
      else {
        val nakedStats = stats.map(stat => c.resetAllAttrs(stat.duplicate))
        Block(nakedStats: _*)
      }
    }

    /* `var resultName: resultType = <zero>`, declared up front so that every later state can
     * read and assign the await result.
     *
     * Primitive types get their zero value, since `null` does not conform to them. All other
     * types start out as `null`.
     * NOTE(review): value classes still get `null` and will likely not typecheck.
     */
    def varDefForResult: c.Tree = {
      // Typed as Any so that each literal keeps its own type (no numeric widening).
      val zero: Any =
        if (resultType <:< definitions.IntTpe) 0
        else if (resultType <:< definitions.LongTpe) 0L
        else if (resultType <:< definitions.BooleanTpe) false
        else if (resultType <:< definitions.DoubleTpe) 0.0
        else if (resultType <:< definitions.FloatTpe) 0.0f
        else if (resultType <:< definitions.ShortTpe) 0.toShort
        else if (resultType <:< definitions.ByteTpe) 0.toByte
        else if (resultType <:< definitions.CharTpe) 0.toChar
        else if (resultType <:< definitions.UnitTpe) ()
        else null
      ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), Literal(Constant(zero)))
    }
  }

}


/*
 * @author Philipp Haller
 */
object Async extends AsyncUtils {

  /* Run `body` asynchronously and return a future of its result.
   *
   * Only `await` calls that form the right-hand side of a top-level `val` in `body`
   * (`val x = await(f)`) suspend the computation. The macro turns `body` into a state machine
   * whose states are separated by those definitions.
   */
  def async[T](body: T): Future[T] = macro asyncImpl[T]

  /* Mark a suspension point inside `async`.
   *
   * There is no runtime implementation: the macro removes the supported uses, and any call that
   * is actually executed throws `NotImplementedError` (via `???`).
   */
  def await[T](awaitable: Future[T]): T = ???

  /* Macro implementation of `async`.
   *
   * Generated shape (sketch):
   *
   *     {
   *       val result = Promise[T]()
   *       var state = 0
   *       <one var per await result>
   *       def resume(): Unit = {
   *         state += 1
   *         val handler = <handlers for states 1..n chained with orElse>
   *         try (handler orElse <handler for state n + 1>)(state)
   *         catch { case NonFatal(t) => result.failure(t) }
   *       }
   *       resume()
   *       result.future
   *     }
   *
   * `await` calls nested inside other expressions are not recognised and are left in place.
   */
  def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = {
    import c.universe._

    val builder = new ExprBuilder[c.type](c)
    val awaitMethod = awaitSym(c)

    // The generated code below has exactly this many slots for await-result variables.
    val maxAwaits = 5

    // The parser does not wrap a single-expression body in a `Block`. Treat such a body as a
    // block without statements before its result expression.
    val (stats, expr) = body.tree match {
      case Block(blockStats, blockExpr) => (blockStats, blockExpr)
      case single                       => (List.empty[Tree], single)
    }

    // Split the statements into states: each top-level `val x = await(f)` closes the current
    // state, and everything after the last such definition forms the final state.
    val asyncStates = ListBuffer[builder.AsyncStateBuilder]()
    var stateBuilder = new builder.AsyncStateBuilder // current state builder

    for (stat <- stats) stat match {
      // the val name = await(..) pattern
      case ValDef(_, name, tpt, Apply(fun, args)) if fun.symbol == awaitMethod =>
        stateBuilder.complete(args(0), name, tpt)
        asyncStates += stateBuilder
        stateBuilder = new builder.AsyncStateBuilder

      case _ =>
        stateBuilder += stat
    }
    // complete last state builder (representing the expressions after the last await)
    stateBuilder += expr
    asyncStates += stateBuilder

    vprintln("states of current method:")
    asyncStates foreach vprintln

    // Every state except the last one ends with an await.
    val awaitingStates = asyncStates.init.toList
    val indexOfLastState = awaitingStates.size

    if (awaitingStates.size > maxAwaits)
      c.abort(c.enclosingPosition,
        s"async blocks with more than $maxAwaits top-level await calls are not supported")

    // State k (k >= 1) runs asyncStates(k - 1). The handlers of all awaiting states are
    // chained as ((h1 orElse h2) orElse h3) ...
    val handlerExpr: c.Expr[PartialFunction[Int, Unit]] =
      if (awaitingStates.isEmpty)
        // Nothing to chain. A handler for state 0 never fires, because `resume` increments
        // `state` before dispatching on it.
        builder.mkHandler(0, c.Expr[Unit](Literal(Constant(()))))
      else {
        val handlers = awaitingStates.zipWithIndex.map { case (st, i) => st.mkHandlerForState(i + 1) }
        handlers.tail.foldLeft(handlers.head) { (chained, next) =>
          val chainedNaked = c.resetAllAttrs(chained.tree.duplicate)
          reify {
            c.Expr(chainedNaked).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]].splice orElse next.splice
          }
        }
      }

    vprintln(s"GENERATED handler expr ($indexOfLastState):")
    vprintln(handlerExpr)

    // One var per await result, padded with empty trees to fill the fixed slots below.
    val localVarDefs: List[Tree] =
      awaitingStates.map(_.varDefForResult) ++ List.fill(maxAwaits - indexOfLastState)(EmptyTree)

    val handlerForLastState: c.Expr[PartialFunction[Int, Unit]] = {
      val tree = Apply(Select(Ident("result"), c.universe.newTermName("success")),
                       List(asyncStates.last.lastExprTree))
      builder.mkHandler(indexOfLastState + 1, c.Expr[Unit](tree))
    }

    vprintln("GENERATED handler for last state:")
    vprintln(handlerForLastState)

    reify {
      val result = Promise[T]()
      var state = 0

      c.Expr(localVarDefs(0)).splice
      c.Expr(localVarDefs(1)).splice
      c.Expr(localVarDefs(2)).splice
      c.Expr(localVarDefs(3)).splice
      c.Expr(localVarDefs(4)).splice

      def resume(): Unit = {
        state += 1

        val handler: PartialFunction[Int, Unit] =
          handlerExpr.splice

        try {
          (handler orElse handlerForLastState.splice)(state)
        } catch {
          case NonFatal(t) => result.failure(t)
        }
      }

      resume()
      result.future
    }
  }

}