diff options
Diffstat (limited to 'src/test/scala/scala/async/run/late/LateExpansion.scala')
-rw-r--r-- | src/test/scala/scala/async/run/late/LateExpansion.scala | 170 |
1 files changed, 170 insertions, 0 deletions
diff --git a/src/test/scala/scala/async/run/late/LateExpansion.scala b/src/test/scala/scala/async/run/late/LateExpansion.scala new file mode 100644 index 0000000..b866527 --- /dev/null +++ b/src/test/scala/scala/async/run/late/LateExpansion.scala @@ -0,0 +1,170 @@ +package scala.async.run.late + +import java.io.File + +import junit.framework.Assert.assertEquals +import org.junit.Test + +import scala.annotation.StaticAnnotation +import scala.async.internal.{AsyncId, AsyncMacro} +import scala.reflect.internal.util.ScalaClassLoader.URLClassLoader +import scala.tools.nsc._ +import scala.tools.nsc.plugins.{Plugin, PluginComponent} +import scala.tools.nsc.reporters.StoreReporter +import scala.tools.nsc.transform.TypingTransformers + +// Tests for customized use of the async transform from a compiler plugin, which +// calls it from a new phase that runs after patmat. +class LateExpansion { + @Test def test0(): Unit = { + val result = wrapAndRun( + """ + | @autoawait def id(a: String) = a + | id("foo") + id("bar") + | """.stripMargin) + assertEquals("foobar", result) + } + @Test def testGuard(): Unit = { + val result = wrapAndRun( + """ + | @autoawait def id[A](a: A) = a + | "" match { case _ if id(false) => ???; case _ => "okay" } + | """.stripMargin) + assertEquals("okay", result) + } + + @Test def testExtractor(): Unit = { + val result = wrapAndRun( + """ + | object Extractor { @autoawait def unapply(a: String) = Some((a, a)) } + | "" match { case Extractor(a, b) if "".isEmpty => a == b } + | """.stripMargin) + assertEquals(true, result) + } + + @Test def testNestedMatchExtractor(): Unit = { + val result = wrapAndRun( + """ + | object Extractor { @autoawait def unapply(a: String) = Some((a, a)) } + | "" match { + | case _ if "".isEmpty => + | "" match { case Extractor(a, b) => a == b } + | } + | """.stripMargin) + assertEquals(true, result) + } + + @Test def testCombo(): Unit = { + val result = wrapAndRun( + """ + | object Extractor1 { @autoawait def unapply(a: String) = Some((a + 1, a + 2)) } + | object Extractor2 { @autoawait def unapply(a: String) = Some(a + 3) } + | @autoawait def id(a: String) = a + | println("Test.test") + | val r1 = Predef.identity("blerg") match { + | case x if " ".isEmpty => "case 2: " + x + | case Extractor1(Extractor2(x), y: String) if x == "xxx" => "case 1: " + x + ":" + y + | x match { + | case Extractor1(Extractor2(x), y: String) => + | case _ => + | } + | case Extractor2(x) => "case 3: " + x + | } + | r1 + | """.stripMargin) + assertEquals("case 3: blerg3", result) + } + + def wrapAndRun(code: String): Any = { + run( + s""" + |import scala.async.run.late.{autoawait,lateasync} + |object Test { + | @lateasync + | def test: Any = { + | $code + | } + |} + | """.stripMargin) + } + + def run(code: String): Any = { + val reporter = new StoreReporter + val settings = new Settings(println(_)) + settings.outdir.value = sys.props("java.io.tmpdir") + settings.embeddedDefaults(getClass.getClassLoader) + val isInSBT = !settings.classpath.isSetByUser + if (isInSBT) settings.usejavacp.value = true + val global = new Global(settings, reporter) { + self => + + object late extends { + val global: self.type = self + } with LatePlugin + + override protected def loadPlugins(): List[Plugin] = late :: Nil + } + import global._ + + val run = new Run + val source = newSourceFile(code) + run.compileSources(source :: Nil) + assert(!reporter.hasErrors, reporter.infos.mkString("\n")) + val loader = new URLClassLoader(Seq(new File(settings.outdir.value).toURI.toURL), global.getClass.getClassLoader) + val cls = loader.loadClass("Test") + cls.getMethod("test").invoke(null) + } +} + +abstract class LatePlugin extends Plugin { + import global._ + + override val components: List[PluginComponent] = List(new PluginComponent with TypingTransformers { + val global: LatePlugin.this.global.type = LatePlugin.this.global + + lazy val asyncIdSym = symbolOf[AsyncId.type] + lazy val asyncSym = asyncIdSym.info.member(TermName("async")) + lazy val awaitSym = asyncIdSym.info.member(TermName("await")) + lazy val autoAwaitSym = symbolOf[autoawait] + lazy val lateAsyncSym = symbolOf[lateasync] + + def newTransformer(unit: CompilationUnit) = new TypingTransformer(unit) { + override def transform(tree: Tree): Tree = { + super.transform(tree) match { + case ap@Apply(fun, args) if fun.symbol.hasAnnotation(autoAwaitSym) => + localTyper.typed(Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, awaitSym), TypeTree(ap.tpe) :: Nil), ap :: Nil)) + case dd: DefDef if dd.symbol.hasAnnotation(lateAsyncSym) => atOwner(dd.symbol) { + val expandee = localTyper.context.withMacrosDisabled( + localTyper.typed(Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, asyncSym), TypeTree(dd.rhs.tpe) :: Nil), List(dd.rhs))) + ) + val c = analyzer.macroContext(localTyper, gen.mkAttributedRef(asyncIdSym), expandee) + val asyncMacro = AsyncMacro(c, AsyncId)(dd.rhs) + val code = asyncMacro.asyncTransform[Any](localTyper.typed(Literal(Constant(()))))(c.weakTypeTag[Any]) + deriveDefDef(dd)(_ => localTyper.typed(code)) + } + case x => x + } + } + } + + override def newPhase(prev: Phase): Phase = new StdPhase(prev) { + override def apply(unit: CompilationUnit): Unit = { + val translated = newTransformer(unit).transformUnit(unit) + //println(show(unit.body)) + translated + } + } + + override val runsAfter: List[String] = "patmat" :: Nil + override val phaseName: String = "postpatmat" + + }) + override val description: String = "postpatmat" + override val name: String = "postpatmat" +} + +// Methods with this annotation are translated to having the RHS wrapped in `AsyncId.async { <original RHS> }` +final class lateasync extends StaticAnnotation + +// Calls to methods with this annotation are translated to `AsyncId.await(<call>)` +final class autoawait extends StaticAnnotation |