package scala.async.run.late

import java.io.File

import junit.framework.Assert.assertEquals
import org.junit.{Assert, Test}

import scala.annotation.StaticAnnotation
import scala.annotation.meta.{field, getter}
import scala.async.TreeInterrogation
import scala.async.internal.AsyncId
import scala.reflect.internal.util.ScalaClassLoader.URLClassLoader
import scala.tools.nsc._
import scala.tools.nsc.plugins.{Plugin, PluginComponent}
import scala.tools.nsc.reporters.StoreReporter
import scala.tools.nsc.transform.TypingTransformers

// Tests for customized use of the async transform from a compiler plugin, which
// calls it from a new phase (`postpatmat`, see `LatePlugin`) that runs after refchecks
// and therefore after patmat.
class LateExpansion {

  @Test
  def test0(): Unit = {
    val result = wrapAndRun(
      """
        | @autoawait def id(a: String) = a
        | id("foo") + id("bar")
        | """.stripMargin)
    assertEquals("foobar", result)
  }

  @Test
  def testGuard(): Unit = {
    val result = wrapAndRun(
      """
        | @autoawait def id[A](a: A) = a
        | "" match { case _ if id(false) => ???; case _ => "okay" }
        | """.stripMargin)
    assertEquals("okay", result)
  }

  @Test
  def testExtractor(): Unit = {
    val result = wrapAndRun(
      """
        | object Extractor { @autoawait def unapply(a: String) = Some((a, a)) }
        | "" match { case Extractor(a, b) if "".isEmpty => a == b }
        | """.stripMargin)
    assertEquals(true, result)
  }

  @Test
  def testNestedMatchExtractor(): Unit = {
    val result = wrapAndRun(
      """
        | object Extractor { @autoawait def unapply(a: String) = Some((a, a)) }
        | "" match {
        |   case _ if "".isEmpty =>
        |     "" match { case Extractor(a, b) => a == b }
        | }
        | """.stripMargin)
    assertEquals(true, result)
  }

  @Test
  def testCombo(): Unit = {
    val result = wrapAndRun(
      """
        | object Extractor1 { @autoawait def unapply(a: String) = Some((a + 1, a + 2)) }
        | object Extractor2 { @autoawait def unapply(a: String) = Some(a + 3) }
        | @autoawait def id(a: String) = a
        | println("Test.test")
        | val r1 = Predef.identity("blerg") match {
        |   case x if " ".isEmpty => "case 2: " + x
        |   case Extractor1(Extractor2(x), y: String) if x == "xxx" => "case 1: " + x + ":" + y
        |     x match {
        |       case Extractor1(Extractor2(x), y: String) =>
        |       case _ =>
        |     }
        |   case Extractor2(x) => "case 3: " + x
        | }
        | r1
        | """.stripMargin)
    assertEquals("case 3: blerg3", result)
  }

  @Test
  def polymorphicMethod(): Unit = {
    val result = run(
      """
        |import scala.async.run.late.{autoawait,lateasync}
        |object Test {
        |  class C { override def toString = "C" }
        |  @autoawait def foo[A <: C](a: A): A = a
        |  @lateasync
        |  def test1[CC <: C](c: CC): (CC, CC) = {
        |    val x: (CC, CC) = 0 match { case _ if false => ???; case _ => (foo(c), foo(c)) }
        |    x
        |  }
        |  def test(): (C, C) = test1(new C)
        |}
        | """.stripMargin)
    assertEquals("(C,C)", result.toString)
  }

  // Compile-and-run smoke test: success means no compiler error and no exception.
  @Test
  def shadowing(): Unit = {
    run(
      """
        |import scala.async.run.late.{autoawait,lateasync}
        |object Test {
        |  trait Foo
        |  trait Bar extends Foo
        |  @autoawait def boundary = ""
        |  @lateasync
        |  def test: Unit = {
        |    (new Bar {}: Any) match {
        |      case foo: Bar =>
        |        boundary
        |        0 match {
        |          case _ => foo; ()
        |        }
        |        ()
        |    }
        |    ()
        |  }
        |}
        | """.stripMargin)
  }

  // Compile-and-run smoke test: success means no compiler error and no exception.
  @Test
  def shadowing0(): Unit = {
    run(
      """
        |import scala.async.run.late.{autoawait,lateasync}
        |object Test {
        |  trait Foo
        |  trait Bar
        |  def test: Any = test(new C)
        |  @autoawait def asyncBoundary: String = ""
        |  @lateasync
        |  def test(foo: Foo): Foo = foo match {
        |    case foo: Bar =>
        |      val foo2: Foo with Bar = new Foo with Bar {}
        |      asyncBoundary
        |      null match {
        |        case _ => foo2
        |      }
        |    case other => foo
        |  }
        |  class C extends Foo with Bar
        |}
        | """.stripMargin)
  }

  // Compile-and-run smoke test: success means no compiler error and no exception.
  @Test
  def shadowing2(): Unit = {
    run(
      """
        |import scala.async.run.late.{autoawait,lateasync}
        |object Test {
        |  trait Base; trait Foo[T <: Base] { @autoawait def func: Option[Foo[T]] = None }
        |  class Sub extends Base
        |  trait Bar extends Foo[Sub]
        |  def test: Any = test(new Bar {})
        |  @lateasync
        |  def test[T <: Base](foo: Foo[T]): Foo[T] = foo match {
        |    case foo: Bar =>
        |      val res = foo.func
        |      res match {
        |        case _ =>
        |      }
        |      foo
        |    case other => foo
        |  }
        |  test(new Bar {})
        |}
        | """.stripMargin)
  }

  // Compile-and-run smoke test: success means no compiler error and no exception.
  @Test
  def patternAlternative(): Unit = {
    wrapAndRun(
      """
        | @autoawait def one = 1
        |
        | @lateasync def test = {
        |   Option(true) match {
        |     case null | None => false
        |     case Some(v) => one; v
        |   }
        | }
        | """.stripMargin)
  }

  // Compile-and-run smoke test: success means no compiler error and no exception.
  @Test
  def patternAlternativeBothAnnotations(): Unit = {
    wrapAndRun(
      """
        |import scala.async.run.late.{autoawait,lateasync}
        |object Test {
        |  @autoawait def func1() = "hello"
        |  @lateasync def func(a: Option[Boolean]) = a match {
        |    case null | None => func1 + " world"
        |    case _ => "okay"
        |  }
        |  def test: Any = func(None)
        |}
        | """.stripMargin)
  }

  @Test
  def shadowingRefinedTypes(): Unit = {
    val result = run(
      s"""
         |import scala.async.run.late.{autoawait,lateasync}
         |trait Base
         |class Sub extends Base
         |trait Foo[T <: Base] {
         |  @autoawait def func: Option[Foo[T]] = None
         |}
         |trait Bar extends Foo[Sub]
         |object Test {
         |  @lateasync def func[T <: Base](foo: Foo[T]): Foo[T] = foo match { // the whole pattern match will be wrapped with async{ }
         |    case foo: Bar =>
         |      val res = foo.func // will be rewritten into: await(foo.func)
         |      res match {
         |        case Some(v) => v // this will report type mismtach
         |        case other => foo
         |      }
         |    case other => foo
         |  }
         |  def test: Any = { val b = new Bar{}; func(b) == b }
         |}""".stripMargin)
    assertEquals(true, result)
  }

  // Compile-and-run smoke test: success means no compiler error and no exception.
  @Test
  def testMatchEndIssue(): Unit = {
    run(
      """
        |import scala.async.run.late.{autoawait,lateasync}
        |sealed trait Subject
        |final class Principal(val name: String) extends Subject
        |object Principal {
        |  def unapply(p: Principal): Option[String] = Some(p.name)
        |}
        |object Test {
        |  @autoawait @lateasync
        |  def containsPrincipal(search: String, value: Subject): Boolean = value match {
        |    case Principal(name) if name == search => true
        |    case Principal(name) => containsPrincipal(search, value)
        |    case other => false
        |  }
        |
        |  @lateasync
        |  def test = containsPrincipal("test", new Principal("test"))
        |}
        | """.stripMargin)
  }

  // Compile-only regression test. The trivial `object Test` gives `run` something to load,
  // instead of picking up a stale `Test` class compiled by some other test.
  @Test
  def testGenericTypeBoundaryIssue(): Unit = {
    run(
      """
        import scala.async.run.late.{autoawait,lateasync}
        trait InstrumentOfValue
        trait Security[T <: InstrumentOfValue] extends InstrumentOfValue
        class Bound extends Security[Bound]
        class Futures extends Security[Futures]
        object TestGenericTypeBoundIssue {
          @autoawait @lateasync def processBound(bound: Bound): Unit = { println("process Bound") }
          @autoawait @lateasync def processFutures(futures: Futures): Unit = { println("process Futures") }
          @autoawait @lateasync def doStuff(sec: Security[_]): Unit = {
            sec match {
              case bound: Bound => processBound(bound)
              case futures: Futures => processFutures(futures)
              case _ => throw new Exception("Unknown Security type: " + sec)
            }
          }
        }
        object Test {
          def test: Unit = ()
        }
      """.stripMargin)
  }

  // Compile-only regression test. The trivial `object Test` gives `run` something to load,
  // instead of picking up a stale `Test` class compiled by some other test.
  @Test
  def testReturnTupleIssue(): Unit = {
    run(
      """
        import scala.async.run.late.{autoawait,lateasync}
        class TestReturnExprIssue(str: String) {
          @autoawait @lateasync def getTestValue = Some(42)
          @autoawait @lateasync def doStuff: Int = {
            val opt: Option[Int] = getTestValue // here we have an async method invoke
            opt match {
              case Some(li) => li // use the result somehow
              case None =>
            }
            42 // type mismatch; found : AnyVal required: Int
          }
        }
        object Test {
          def test: Unit = ()
        }
      """.stripMargin)
  }

  // Compile-only regression test: `Test.test` does nothing at runtime.
  @Test
  def testAfterRefchecksIssue(): Unit = {
    run(
      """
        import scala.async.run.late.{autoawait,lateasync}
        trait Factory[T] { def create: T }
        sealed trait TimePoint
        class TimeLine[TP <: TimePoint](val tpInitial: Factory[TP]) {
          @autoawait @lateasync private[TimeLine] val tp: TP = tpInitial.create
          @autoawait @lateasync def timePoint: TP = tp
        }
        object Test {
          def test: Unit = ()
        }
      """)
  }

  // Compile-and-run smoke test: success means no compiler error and no exception.
  @Test
  def testArrayIndexOutOfBoundIssue(): Unit = {
    run(
      """
        import scala.async.run.late.{autoawait,lateasync}
        sealed trait Result
        case object A extends Result
        case object B extends Result
        case object C extends Result

        object Test {
          protected def doStuff(res: Result) = {
            class C {
              @autoawait def needCheck = false

              @lateasync def m = {
                if (needCheck) "NO"
                else {
                  res match {
                    case A => 1
                    case _ => 2
                  }
                }
              }
            }
          }

          @lateasync
          def test() = doStuff(B)
        }
      """)
  }

  /**
   * Embeds `code` as the body of an `@lateasync def test: Any` inside `object Test`,
   * then compiles and runs it with [[run]].
   *
   * @param code statements (already margin-stripped) forming the body of `Test.test`
   * @return the value returned by `Test.test`
   */
  def wrapAndRun(code: String): Any = {
    run(
      s"""
         |import scala.async.run.late.{autoawait,lateasync}
         |object Test {
         |  @lateasync
         |  def test: Any = {
         |    $code
         |  }
         |}
         | """.stripMargin)
  }

  // Compile-and-run smoke test: success means no compiler error and no exception.
  @Test
  def testNegativeArraySizeException(): Unit = {
    run(
      """
        import scala.async.run.late.{autoawait,lateasync}
        object Test {
          def foo(foo: Any, bar: Any) = ()
          @autoawait def getValue = 4.2
          @lateasync def func(f: Any) = {
            foo(f match { case _ if "".isEmpty => 2 }, getValue);
          }

          @lateasync
          def test() = func(4)
        }
      """)
  }

  // Compile-and-run smoke test: success means no compiler error and no exception.
  @Test
  def testNegativeArraySizeExceptionFine1(): Unit = {
    run(
      """
        import scala.async.run.late.{autoawait,lateasync}
        case class FixedFoo(foo: Int)
        class Foobar(val foo: Int, val bar: Double) {
          @autoawait @lateasync def getValue = 4.2
          @autoawait @lateasync def func(f: Any) = {
            new Foobar(foo = f match {
              case FixedFoo(x) => x
              case _ => 2
            }, bar = getValue)
          }
        }
        object Test {
          @lateasync def test() = new Foobar(0, 0).func(4)
        }
      """)
  }

  /**
   * Compiles `code` with a fresh `Global` that has [[LatePlugin]] installed, then loads
   * the class `Test` from the output and reflectively invokes its static `test()` method.
   *
   * Each call compiles into its own new temporary directory. A test whose snippet forgets
   * to define `Test` therefore fails with `ClassNotFoundException`, rather than silently
   * running a stale `Test` class left behind by another test.
   *
   * @param code complete Scala source that must define `object Test` with a no-arg `test` method
   * @return the value returned by `Test.test()`
   * @throws AssertionError if compilation reports errors (the message lists the reporter's infos)
   * @throws ClassNotFoundException if no `Test` class was produced (the message lists the output files)
   */
  def run(code: String): Any = {
    val reporter = new StoreReporter
    val settings = new Settings(println(_))
    //settings.processArgumentString("-Xprint:patmat,postpatmat,jvm -Ybackend:GenASM -nowarn")
    // A private output directory per compilation: classes from other tests cannot leak in.
    val outDir: File = java.nio.file.Files.createTempDirectory("late-expansion").toFile
    settings.outdir.value = outDir.getAbsolutePath
    settings.embeddedDefaults(getClass.getClassLoader)
    val isInSBT = !settings.classpath.isSetByUser
    if (isInSBT) settings.usejavacp.value = true

    val global = new Global(settings, reporter) {
      self =>

      object late extends {
        val global: self.type = self
      } with LatePlugin

      override protected def loadPlugins(): List[Plugin] = late :: Nil
    }
    import global._

    val run = new Run
    val source = newSourceFile(code)
    //    TreeInterrogation.withDebug {
    run.compileSources(source :: Nil)
    //    }
    Assert.assertTrue(reporter.infos.mkString("\n"), !reporter.hasErrors)

    val loader = new URLClassLoader(Seq(outDir.toURI.toURL), global.getClass.getClassLoader)
    val cls = try {
      loader.loadClass("Test")
    } catch {
      case ex: ClassNotFoundException =>
        throw new ClassNotFoundException(outDir.list().mkString(", "), ex)
    }
    cls.getMethod("test").invoke(null)
  }
}

/**
 * Compiler plugin that installs a single phase, `postpatmat`, scheduled after `refchecks`.
 *
 * The phase performs two rewrites:
 *  - It wraps the right-hand side of every method or val annotated with [[lateasync]] in
 *    `AsyncId.async[T](rhs)`.
 *  - It wraps every call to, or value reference of, a member annotated with [[autoawait]] in
 *    `AsyncId.await[T](...)`.
 */
abstract class LatePlugin extends Plugin {
  import global._

  override val components: List[PluginComponent] = List(new PluginComponent with TypingTransformers {
    val global: LatePlugin.this.global.type = LatePlugin.this.global

    lazy val asyncIdSym = symbolOf[AsyncId.type]
    lazy val asyncSym = asyncIdSym.info.member(TermName("async"))
    lazy val awaitSym = asyncIdSym.info.member(TermName("await"))
    lazy val autoAwaitSym = symbolOf[autoawait]
    lazy val lateAsyncSym = symbolOf[lateasync]

    def newTransformer(unit: CompilationUnit) = new TypingTransformer(unit) {
      // Builds the untyped tree `AsyncId.<sym>[T](tree)`, where `T` is `tree.tpe`.
      private def wrapIn(sym: Symbol, tree: Tree): Tree =
        Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, sym), TypeTree(tree.tpe) :: Nil), tree :: Nil)

      // True for method and polymorphic types, i.e. a reference that still expects
      // (type) arguments. Such a Select is left alone: the enclosing Apply case awaits it.
      private def isUnapplied(tp: Type): Boolean = tp match {
        case _: MethodType | _: PolyType => true
        case _ => false
      }

      override def transform(tree: Tree): Tree = {
        super.transform(tree) match {
          // Call of an @autoawait method: `f(args)` becomes `AsyncId.await(f(args))`.
          case ap @ Apply(fun, _) if fun.symbol.hasAnnotation(autoAwaitSym) =>
            localTyper.typed(wrapIn(awaitSym, ap))

          // Reference to a parameterless @autoawait member: `x.m` becomes `AsyncId.await(x.m)`.
          case sel @ Select(_, _) if sel.symbol.hasAnnotation(autoAwaitSym) && !isUnapplied(sel.tpe) =>
            localTyper.typed(wrapIn(awaitSym, sel))

          // @lateasync method: the body becomes `AsyncId.async(body)`.
          case dd: DefDef if dd.symbol.hasAnnotation(lateAsyncSym) =>
            atOwner(dd.symbol) {
              deriveDefDef(dd) { rhs: Tree =>
                localTyper.typed(atPos(dd.pos)(wrapIn(asyncSym, rhs)))
              }
            }

          // @lateasync val: the initializer becomes `AsyncId.async(rhs)`.
          case vd: ValDef if vd.symbol.hasAnnotation(lateAsyncSym) =>
            atOwner(vd.symbol) {
              deriveValDef(vd) { rhs: Tree =>
                localTyper.typed(atPos(vd.pos)(wrapIn(asyncSym, rhs)))
              }
            }

          case other => other
        }
      }
    }

    override def newPhase(prev: Phase): Phase = new StdPhase(prev) {
      override def apply(unit: CompilationUnit): Unit = {
        // transformUnit rewrites `unit.body` in place.
        newTransformer(unit).transformUnit(unit)
        //println(show(unit.body))
      }
    }

    override val runsAfter: List[String] = "refchecks" :: Nil
    override val phaseName: String = "postpatmat"
  })
  override val description: String = "postpatmat"
  override val name: String = "postpatmat"
}

// Methods (and vals) with this annotation are translated to having the RHS wrapped in `AsyncId.async { }`
@field
final class lateasync extends StaticAnnotation

// Calls to methods with this annotation are translated to `AsyncId.await()`
@getter
final class autoawait extends StaticAnnotation