package scala.async.run.late

import java.io.File

import junit.framework.Assert.assertEquals
import org.junit.Test

import scala.annotation.StaticAnnotation
import scala.async.internal.{AsyncId, AsyncMacro}
import scala.reflect.internal.util.ScalaClassLoader.URLClassLoader
import scala.tools.nsc._
import scala.tools.nsc.plugins.{Plugin, PluginComponent}
import scala.tools.nsc.reporters.StoreReporter
import scala.tools.nsc.transform.TypingTransformers

// Tests for customized use of the async transform from a compiler plugin, which
// calls it from a new phase that runs after patmat.
class LateExpansion {

  /** Two `@autoawait` calls combined in a single expression. */
  @Test def test0(): Unit = {
    val result = wrapAndRun(
      """
        | @autoawait def id(a: String) = a
        | id("foo") + id("bar")
        | """.stripMargin)
    assertEquals("foobar", result)
  }

  /** An `@autoawait` call used as a pattern guard. */
  @Test def testGuard(): Unit = {
    val result = wrapAndRun(
      """
        | @autoawait def id[A](a: A) = a
        | "" match { case _ if id(false) => ???; case _ => "okay" }
        | """.stripMargin)
    assertEquals("okay", result)
  }

  /** An extractor whose `unapply` is `@autoawait`, combined with an ordinary guard. */
  @Test def testExtractor(): Unit = {
    val result = wrapAndRun(
      """
        | object Extractor { @autoawait def unapply(a: String) = Some((a, a)) }
        | "" match { case Extractor(a, b) if "".isEmpty => a == b }
        | """.stripMargin)
    assertEquals(true, result)
  }

  /** An `@autoawait` extractor used in a match nested inside another match's case body. */
  @Test def testNestedMatchExtractor(): Unit = {
    val result = wrapAndRun(
      """
        | object Extractor { @autoawait def unapply(a: String) = Some((a, a)) }
        | "" match {
        |   case _ if "".isEmpty =>
        |     "" match { case Extractor(a, b) => a == b }
        | }
        | """.stripMargin)
    assertEquals(true, result)
  }

  /** Nested `@autoawait` extractors, guards and a nested match in one snippet. */
  @Test def testCombo(): Unit = {
    val result = wrapAndRun(
      """
        | object Extractor1 { @autoawait def unapply(a: String) = Some((a + 1, a + 2)) }
        | object Extractor2 { @autoawait def unapply(a: String) = Some(a + 3) }
        | @autoawait def id(a: String) = a
        | println("Test.test")
        | val r1 = Predef.identity("blerg") match {
        |   case x if " ".isEmpty => "case 2: " + x
        |   case Extractor1(Extractor2(x), y: String) if x == "xxx" => "case 1: " + x + ":" + y
        |     x match {
        |       case Extractor1(Extractor2(x), y: String) =>
        |       case _ =>
        |     }
        |   case Extractor2(x) => "case 3: " + x
        | }
        | r1
        | """.stripMargin)
    assertEquals("case 3: blerg3", result)
  }

  /**
   * Wraps `code` as the body of `@lateasync def test: Any` inside `object Test`,
   * then compiles and runs it via [[run]].
   *
   * @param code statements/expression forming the body of `Test.test`
   * @return whatever `Test.test` returns
   */
  def wrapAndRun(code: String): Any = {
    run(
      s"""
         |import scala.async.run.late.{autoawait,lateasync}
         |object Test {
         |  @lateasync
         |  def test: Any = {
         |    $code
         |  }
         |}
         | """.stripMargin)
  }

  /**
   * Compiles `code` with an in-process compiler whose only plugin is [[LatePlugin]],
   * fails if the reporter recorded errors, then loads class `Test` from the output
   * directory and reflectively invokes its static `test` method.
   *
   * Class files are written to a fresh temporary directory that is deleted when this
   * method returns, so runs neither pollute the shared temp dir nor see each other's
   * `Test` class.
   *
   * @param code complete source of a compilation unit defining `object Test { def test }`
   * @return the value returned by `Test.test`
   */
  def run(code: String): Any = {
    val outDir: File = java.nio.file.Files.createTempDirectory("late-expansion").toFile
    try {
      val reporter = new StoreReporter
      val settings = new Settings(println(_))
      settings.outdir.value = outDir.getAbsolutePath
      settings.embeddedDefaults(getClass.getClassLoader)
      // No explicit classpath was configured (named `isInSBT` in the original code):
      // fall back to the JVM's own classpath.
      val isInSBT = !settings.classpath.isSetByUser
      if (isInSBT) settings.usejavacp.value = true
      val global = new Global(settings, reporter) {
        self =>

        // Early initializer so `global` is set before LatePlugin's body runs.
        object late extends {
          val global: self.type = self
        } with LatePlugin

        override protected def loadPlugins(): List[Plugin] = late :: Nil
      }
      import global._

      val run = new Run
      val source = newSourceFile(code)
      run.compileSources(source :: Nil)
      assert(!reporter.hasErrors, reporter.infos.mkString("\n"))
      val loader = new URLClassLoader(Seq(outDir.toURI.toURL), global.getClass.getClassLoader)
      val cls = loader.loadClass("Test")
      cls.getMethod("test").invoke(null)
    } finally {
      // Best-effort cleanup of the generated class files.
      scala.reflect.io.Path(outDir).deleteRecursively()
    }
  }
}

/**
 * Compiler plugin contributing a single phase, "postpatmat", that runs after "patmat".
 *
 * The phase wraps every call to an `@autoawait` method in `AsyncId.await` and replaces
 * the right-hand side of every `@lateasync` method with the result of running the async
 * transform on it.
 */
abstract class LatePlugin extends Plugin {

  import global._

  override val components: List[PluginComponent] = List(new PluginComponent with TypingTransformers {
    val global: LatePlugin.this.global.type = LatePlugin.this.global

    lazy val asyncIdSym = symbolOf[AsyncId.type]
    lazy val asyncSym = asyncIdSym.info.member(TermName("async"))
    lazy val awaitSym = asyncIdSym.info.member(TermName("await"))
    lazy val autoAwaitSym = symbolOf[autoawait]
    lazy val lateAsyncSym = symbolOf[lateasync]

    def newTransformer(unit: CompilationUnit) = new TypingTransformer(unit) {
      override def transform(tree: Tree): Tree = {
        // Children are transformed first, so by the time a `@lateasync` DefDef is
        // handled, the `@autoawait` calls in its body are already wrapped in `await`.
        super.transform(tree) match {
          case ap@Apply(fun, _) if fun.symbol.hasAnnotation(autoAwaitSym) =>
            // Rewrite `f(args)` to `AsyncId.await[T](f(args))`.
            localTyper.typed(Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, awaitSym), TypeTree(ap.tpe) :: Nil), ap :: Nil))
          case dd: DefDef if dd.symbol.hasAnnotation(lateAsyncSym) => atOwner(dd.symbol) {
            // Type `AsyncId.async[T](rhs)` with macros disabled so the call is not
            // expanded here; it only serves as the expandee for the macro context.
            val expandee = localTyper.context.withMacrosDisabled(
              localTyper.typed(Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, asyncSym), TypeTree(dd.rhs.tpe) :: Nil), List(dd.rhs)))
            )
            val c = analyzer.macroContext(localTyper, gen.mkAttributedRef(asyncIdSym), expandee)
            val asyncMacro = AsyncMacro(c, AsyncId)(dd.rhs)
            val code = asyncMacro.asyncTransform[Any](localTyper.typed(Literal(Constant(()))))(c.weakTypeTag[Any])
            deriveDefDef(dd)(_ => localTyper.typed(code))
          }
          case x => x
        }
      }
    }

    override def newPhase(prev: Phase): Phase = new StdPhase(prev) {
      override def apply(unit: CompilationUnit): Unit = {
        newTransformer(unit).transformUnit(unit)
        // Uncomment to inspect the transformed tree:
        //println(show(unit.body))
      }
    }

    override val runsAfter: List[String] = "patmat" :: Nil
    override val phaseName: String = "postpatmat"
  })
  override val description: String = "postpatmat"
  override val name: String = "postpatmat"
}

// Methods with this annotation are translated to having the RHS wrapped in `AsyncId.async { }`
final class lateasync extends StaticAnnotation

// Calls to methods with this annotation are translated to `AsyncId.await()`
final class autoawait extends StaticAnnotation