aboutsummaryrefslogblamecommitdiff
path: root/src/main/scala/scala/async/AsyncAnalysis.scala
blob: 8bb5bcd86dbc3c96a4373cb95e8f0aa6cffecd97 (plain) (tree)
1
2
3
4
5
6
7
8
9



                                                             


                                   
                               
 
                                                                   

                     
                                       
 

                
















                                                                                             
                                                                   

                                                 

                                                                                                          

   
                                                                 

                                                                          

                                                                                                      

                                                                                                          
       


                                                  



                                                                                                

     



                                                     







                                                         


                                             
                                             
                                                   
                              
                                             
                                                                     
                                             



                              




                                                                                       





                                                                               
                        
     

   
                                                                   





                                                            


























                                                                                   


                                         
                                                                 
                                                  
                                                                 
                                           

                                                                 
                                                                 

                              
                                                                 

                                                      


                                                                          
                                
                                                                                                                    
 
                                                                        
                                                                                                                                                                                             
                               

                              
                                                                 





                                                                   
                                                                                      








                                                   
 
 
/*
 * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
 */

package scala.async

import scala.reflect.macros.Context
import scala.collection.mutable

private[async] final case class AsyncAnalysis[C <: Context](c: C) {
  import c.universe._

  val utils = TransformUtils[c.type](c)

  import utils._

  /**
   * Analyze the contents of an `async` block in order to:
   * - Report unsupported `await` calls under nested templates, functions, by-name arguments
   *   and `try`/`catch` blocks.
   * - Reject `return` expressions, local case classes and local objects.
   *
   * Errors are reported through the macro context `c`; a `return` aborts the expansion.
   *
   * Must be called on the original tree, not on the ANF transformed tree.
   */
  def reportUnsupportedAwaits(tree: Tree): Unit = {
    new UnsupportedAwaitAnalyzer().traverse(tree)
  }

  /**
   * Analyze the contents of an `async` block in order to find the definitions that must be
   * lifted to fields of the state machine:
   * - local `ValDef`-s that are referenced from a different chunk (roughly, a different state)
   *   than the one that defines them, that are initialized with or assigned the result of an
   *   `await`, that are pattern binders, or that are referenced from a nested method or function;
   * - all nested methods (`DefDef`-s).
   *
   * Must be called on the ANF transformed tree.
   *
   * @return the definitions to lift, without duplicates.
   */
  def defTreesUsedInSubsequentStates(tree: Tree): List[DefTree] = {
    val analyzer = new AsyncDefinitionUseAnalyzer
    analyzer.traverse(tree)
    (analyzer.valDefsToLift ++ analyzer.nestedMethodsToLift).toList.distinct
  }

  /**
   * Traverser that reports `await` calls in positions the state machine transformation cannot
   * handle, as well as other constructs that are illegal within an `async` block.
   */
  private class UnsupportedAwaitAnalyzer extends AsyncTraverser {
    override def nestedClass(classDef: ClassDef): Unit = {
      val classSym = classDef.symbol.asClass
      val kind = if (classSym.isTrait) "trait" else "class"
      if (!reportUnsupportedAwait(classDef, s"nested $kind")) {
        // do not allow local class definitions, because of SI-5467 (specific to case classes, though)
        if (classSym.isCaseClass)
          c.error(classDef.pos, s"Local case class ${classDef.name.decoded} illegal within `async` block")
      }
    }

    override def nestedModule(module: ModuleDef): Unit = {
      if (!reportUnsupportedAwait(module, "nested object")) {
        // local object definitions lead to spurious type errors (because of resetAllAttrs?)
        c.error(module.pos, s"Local object ${module.name.decoded} illegal within `async` block")
      }
    }

    override def nestedMethod(defDef: DefDef): Unit = {
      reportUnsupportedAwait(defDef, "nested method")
    }

    override def byNameArgument(arg: Tree): Unit = {
      reportUnsupportedAwait(arg, "by-name argument")
    }

    override def function(function: Function): Unit = {
      reportUnsupportedAwait(function, "nested function")
    }

    override def traverse(tree: Tree): Unit = {
      def containsAwait = tree exists isAwait
      tree match {
        case Try(_, _, _) if containsAwait =>
          reportUnsupportedAwait(tree, "try/catch")
          super.traverse(tree)
        case Return(_)                     =>
          c.abort(tree.pos, "return is illegal within a async block")
        case _                             =>
          super.traverse(tree)
      }
    }

    /**
     * Reports an error at every `await` reference found anywhere inside `tree`.
     *
     * @param whyUnsupported the enclosing construct, as named in the error message.
     * @return true, if the tree contained an unsupported await.
     */
    private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String): Boolean = {
      val badAwaits: List[RefTree] = tree collect {
        case rt: RefTree if isAwait(rt) => rt
      }
      badAwaits foreach { awaitRef =>
        c.error(awaitRef.pos, s"await must not be used under a $whyUnsupported.")
      }
      badAwaits.nonEmpty
    }
  }

  /**
   * Traverser that splits the tree into chunks and records which local `ValDef`-s and nested
   * methods must be lifted to fields of the state machine.
   *
   * A new chunk starts after each `await` call, and after each branch of an `if`, `match`
   * or label definition that contains an `await`.
   */
  private class AsyncDefinitionUseAnalyzer extends AsyncTraverser {
    private var chunkId = 0

    private def nextChunk(): Unit = chunkId += 1

    /** For each local `ValDef` seen so far (keyed by symbol): the definition and its chunk. */
    private var valDefChunkId = Map[Symbol, (ValDef, Int)]()

    val valDefsToLift      : mutable.Set[ValDef] = mutable.Set()
    val nestedMethodsToLift: mutable.Set[DefDef] = mutable.Set()

    override def nestedMethod(defDef: DefDef): Unit = {
      nestedMethodsToLift += defDef
      liftValsReferencedIn(defDef.rhs) // lift all vals referred to by nested methods.
    }

    override def function(function: Function): Unit = {
      liftValsReferencedIn(function) // lift all vals referred to by nested functions.
    }

    override def traverse(tree: Tree): Unit = {
      tree match {
        case If(cond, thenp, elsep) if tree exists isAwait =>
          traverseChunks(List(cond, thenp, elsep))
        case Match(selector, cases) if tree exists isAwait =>
          traverseChunks(selector :: cases)
        case LabelDef(_, _, rhs) if rhs exists isAwait     =>
          traverseChunks(rhs :: Nil)
        case Apply(fun, _) if isAwait(fun)                 =>
          super.traverse(tree)
          nextChunk()
        case vd: ValDef                                    =>
          super.traverse(tree)
          valDefChunkId += (vd.symbol -> (vd, chunkId))
          // `name` here is the name-helper object brought in by `import utils._`.
          val isPatternBinder = vd.name.toString.contains(name.bindSuffix)
          if (isAwait(vd.rhs) || isPatternBinder) valDefsToLift += vd
        case as: Assign                                    =>
          if (isAwait(as.rhs)) {
            assert(as.lhs.symbol != null, "internal error: null symbol for Assign tree:" + as + " " + as.lhs.symbol)

            // TODO test the orElse case, try to remove the restriction.
            val (vd, _) = valDefChunkId.getOrElse(as.lhs.symbol, c.abort(as.pos, s"await may only be assigned to a var/val defined in the async block. ${as.lhs} ${as.lhs.symbol}"))
            valDefsToLift += vd
          }
          super.traverse(tree)
        case rt: RefTree                                   =>
          valDefChunkId.get(rt.symbol) match {
            // referenced from a later chunk than its definition: must survive across states
            case Some((vd, defChunkId)) if defChunkId != chunkId =>
              valDefsToLift += vd
            case _                                               =>
          }
          super.traverse(tree)
        case _                                             =>
          super.traverse(tree)
      }
    }

    /** Traverses each tree in turn, starting a new chunk after each one. */
    private def traverseChunks(trees: List[Tree]): Unit = {
      trees foreach { t =>
        traverse(t)
        nextChunk()
      }
    }

    /** Marks for lifting every already-recorded local `ValDef` referenced anywhere in `tree`. */
    private def liftValsReferencedIn(tree: Tree): Unit = {
      tree foreach {
        case rt: RefTree =>
          valDefChunkId.get(rt.symbol) foreach { case (vd, _) => valDefsToLift += vd }
        case _           =>
      }
    }
  }

}