aboutsummaryrefslogtreecommitdiff
path: root/src/test/scala/scala/async/run/late/LateExpansion.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/test/scala/scala/async/run/late/LateExpansion.scala')
-rw-r--r--src/test/scala/scala/async/run/late/LateExpansion.scala170
1 files changed, 170 insertions, 0 deletions
diff --git a/src/test/scala/scala/async/run/late/LateExpansion.scala b/src/test/scala/scala/async/run/late/LateExpansion.scala
new file mode 100644
index 0000000..b866527
--- /dev/null
+++ b/src/test/scala/scala/async/run/late/LateExpansion.scala
@@ -0,0 +1,170 @@
+package scala.async.run.late
+
+import java.io.File
+
+import junit.framework.Assert.assertEquals
+import org.junit.Test
+
+import scala.annotation.StaticAnnotation
+import scala.async.internal.{AsyncId, AsyncMacro}
+import scala.reflect.internal.util.ScalaClassLoader.URLClassLoader
+import scala.tools.nsc._
+import scala.tools.nsc.plugins.{Plugin, PluginComponent}
+import scala.tools.nsc.reporters.StoreReporter
+import scala.tools.nsc.transform.TypingTransformers
+
+// Tests for customized use of the async transform from a compiler plugin, which
+// calls it from a new phase that runs after patmat.
+class LateExpansion {
+ @Test def test0(): Unit = {
+ val result = wrapAndRun(
+ """
+ | @autoawait def id(a: String) = a
+ | id("foo") + id("bar")
+ | """.stripMargin)
+ assertEquals("foobar", result)
+ }
+ @Test def testGuard(): Unit = {
+ val result = wrapAndRun(
+ """
+ | @autoawait def id[A](a: A) = a
+ | "" match { case _ if id(false) => ???; case _ => "okay" }
+ | """.stripMargin)
+ assertEquals("okay", result)
+ }
+
+ @Test def testExtractor(): Unit = {
+ val result = wrapAndRun(
+ """
+ | object Extractor { @autoawait def unapply(a: String) = Some((a, a)) }
+ | "" match { case Extractor(a, b) if "".isEmpty => a == b }
+ | """.stripMargin)
+ assertEquals(true, result)
+ }
+
+ @Test def testNestedMatchExtractor(): Unit = {
+ val result = wrapAndRun(
+ """
+ | object Extractor { @autoawait def unapply(a: String) = Some((a, a)) }
+ | "" match {
+ | case _ if "".isEmpty =>
+ | "" match { case Extractor(a, b) => a == b }
+ | }
+ | """.stripMargin)
+ assertEquals(true, result)
+ }
+
+ @Test def testCombo(): Unit = {
+ val result = wrapAndRun(
+ """
+ | object Extractor1 { @autoawait def unapply(a: String) = Some((a + 1, a + 2)) }
+ | object Extractor2 { @autoawait def unapply(a: String) = Some(a + 3) }
+ | @autoawait def id(a: String) = a
+ | println("Test.test")
+ | val r1 = Predef.identity("blerg") match {
+ | case x if " ".isEmpty => "case 2: " + x
+ | case Extractor1(Extractor2(x), y: String) if x == "xxx" => "case 1: " + x + ":" + y
+ | x match {
+ | case Extractor1(Extractor2(x), y: String) =>
+ | case _ =>
+ | }
+ | case Extractor2(x) => "case 3: " + x
+ | }
+ | r1
+ | """.stripMargin)
+ assertEquals("case 3: blerg3", result)
+ }
+
+ def wrapAndRun(code: String): Any = {
+ run(
+ s"""
+ |import scala.async.run.late.{autoawait,lateasync}
+ |object Test {
+ | @lateasync
+ | def test: Any = {
+ | $code
+ | }
+ |}
+ | """.stripMargin)
+ }
+
+ def run(code: String): Any = {
+ val reporter = new StoreReporter
+ val settings = new Settings(println(_))
+ settings.outdir.value = sys.props("java.io.tmpdir")
+ settings.embeddedDefaults(getClass.getClassLoader)
+ val isInSBT = !settings.classpath.isSetByUser
+ if (isInSBT) settings.usejavacp.value = true
+ val global = new Global(settings, reporter) {
+ self =>
+
+ object late extends {
+ val global: self.type = self
+ } with LatePlugin
+
+ override protected def loadPlugins(): List[Plugin] = late :: Nil
+ }
+ import global._
+
+ val run = new Run
+ val source = newSourceFile(code)
+ run.compileSources(source :: Nil)
+ assert(!reporter.hasErrors, reporter.infos.mkString("\n"))
+ val loader = new URLClassLoader(Seq(new File(settings.outdir.value).toURI.toURL), global.getClass.getClassLoader)
+ val cls = loader.loadClass("Test")
+ cls.getMethod("test").invoke(null)
+ }
+}
+
+abstract class LatePlugin extends Plugin {
+ import global._
+
+ override val components: List[PluginComponent] = List(new PluginComponent with TypingTransformers {
+ val global: LatePlugin.this.global.type = LatePlugin.this.global
+
+ lazy val asyncIdSym = symbolOf[AsyncId.type]
+ lazy val asyncSym = asyncIdSym.info.member(TermName("async"))
+ lazy val awaitSym = asyncIdSym.info.member(TermName("await"))
+ lazy val autoAwaitSym = symbolOf[autoawait]
+ lazy val lateAsyncSym = symbolOf[lateasync]
+
+ def newTransformer(unit: CompilationUnit) = new TypingTransformer(unit) {
+ override def transform(tree: Tree): Tree = {
+ super.transform(tree) match {
+ case ap@Apply(fun, args) if fun.symbol.hasAnnotation(autoAwaitSym) =>
+ localTyper.typed(Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, awaitSym), TypeTree(ap.tpe) :: Nil), ap :: Nil))
+ case dd: DefDef if dd.symbol.hasAnnotation(lateAsyncSym) => atOwner(dd.symbol) {
+ val expandee = localTyper.context.withMacrosDisabled(
+ localTyper.typed(Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, asyncSym), TypeTree(dd.rhs.tpe) :: Nil), List(dd.rhs)))
+ )
+ val c = analyzer.macroContext(localTyper, gen.mkAttributedRef(asyncIdSym), expandee)
+ val asyncMacro = AsyncMacro(c, AsyncId)(dd.rhs)
+ val code = asyncMacro.asyncTransform[Any](localTyper.typed(Literal(Constant(()))))(c.weakTypeTag[Any])
+ deriveDefDef(dd)(_ => localTyper.typed(code))
+ }
+ case x => x
+ }
+ }
+ }
+
+ override def newPhase(prev: Phase): Phase = new StdPhase(prev) {
+ override def apply(unit: CompilationUnit): Unit = {
+ val translated = newTransformer(unit).transformUnit(unit)
+ //println(show(unit.body))
+ translated
+ }
+ }
+
+ override val runsAfter: List[String] = "patmat" :: Nil
+ override val phaseName: String = "postpatmat"
+
+ })
+ override val description: String = "postpatmat"
+ override val name: String = "postpatmat"
+}
+
+// Methods with this annotation are translated to having the RHS wrapped in `AsyncId.async { <original RHS> }`
+final class lateasync extends StaticAnnotation
+
+// Calls to methods with this annotation are translated to `AsyncId.await(<call>)`
+final class autoawait extends StaticAnnotation