aboutsummaryrefslogtreecommitdiff
path: root/src/test/scala/scala/async/run/late/LateExpansion.scala
blob: b8665271057f8fd1a9c458ccb3961f7a5af52927 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
package scala.async.run.late

import java.io.File

import junit.framework.Assert.assertEquals
import org.junit.Test

import scala.annotation.StaticAnnotation
import scala.async.internal.{AsyncId, AsyncMacro}
import scala.reflect.internal.util.ScalaClassLoader.URLClassLoader
import scala.tools.nsc._
import scala.tools.nsc.plugins.{Plugin, PluginComponent}
import scala.tools.nsc.reporters.StoreReporter
import scala.tools.nsc.transform.TypingTransformers

/** Tests for customized use of the async transform from a compiler plugin ([[LatePlugin]]),
 *  which calls it from a new phase that runs after `patmat`.
 *
 *  Each test wraps a snippet as the body of a `@lateasync` method `Test.test`, compiles it
 *  in-process with the plugin installed, and checks the value returned by invoking that method.
 */
class LateExpansion {
  // An `@autoawait` call in ordinary expression position.
  @Test def test0(): Unit = {
    val result = wrapAndRun(
      """
        | @autoawait def id(a: String) = a
        | id("foo") + id("bar")
        | """.stripMargin)
    assertEquals("foobar", result)
  }
  // An `@autoawait` call inside a pattern guard.
  @Test def testGuard(): Unit = {
    val result = wrapAndRun(
      """
        | @autoawait def id[A](a: A) = a
        | "" match { case _ if id(false) => ???; case _ => "okay" }
        | """.stripMargin)
    assertEquals("okay", result)
  }

  // An extractor whose `unapply` is `@autoawait`.
  @Test def testExtractor(): Unit = {
    val result = wrapAndRun(
      """
        | object Extractor { @autoawait def unapply(a: String) = Some((a, a)) }
        | "" match { case Extractor(a, b) if "".isEmpty => a == b }
        | """.stripMargin)
    assertEquals(true, result)
  }

  // An `@autoawait` extractor used in a match nested inside another match's case body.
  @Test def testNestedMatchExtractor(): Unit = {
    val result = wrapAndRun(
      """
        | object Extractor { @autoawait def unapply(a: String) = Some((a, a)) }
        | "" match {
        |   case _ if "".isEmpty =>
        |     "" match { case Extractor(a, b) => a == b }
        | }
        | """.stripMargin)
    assertEquals(true, result)
  }

  // Guards, nested `@autoawait` extractors and type patterns combined in one match.
  @Test def testCombo(): Unit = {
    val result = wrapAndRun(
      """
        | object Extractor1 { @autoawait def unapply(a: String) = Some((a + 1, a + 2)) }
        | object Extractor2 { @autoawait def unapply(a: String) = Some(a + 3) }
        | @autoawait def id(a: String) = a
        | println("Test.test")
        | val r1 = Predef.identity("blerg") match {
        |   case x if " ".isEmpty => "case 2: " + x
        |   case Extractor1(Extractor2(x), y: String) if x == "xxx" => "case 1: " + x + ":" + y
        |     x match {
        |       case Extractor1(Extractor2(x), y: String) =>
        |       case _ =>
        |     }
        |   case Extractor2(x) => "case 3: " + x
        | }
        | r1
        | """.stripMargin)
    assertEquals("case 3: blerg3", result)
  }

  /** Wraps `code` as the body of `@lateasync def test: Any` in `object Test`, then compiles
   *  and runs it via [[run]].
   *
   *  @param code statements forming the method body; the last one is its result
   *  @return the value returned by `Test.test`
   */
  def wrapAndRun(code: String): Any = {
    run(
      s"""
         |import scala.async.run.late.{autoawait,lateasync}
         |object Test {
         |  @lateasync
         |  def test: Any = {
         |    $code
         |  }
         |}
         | """.stripMargin)
  }

  /** Compiles `code` with [[LatePlugin]] installed, then reflectively invokes `Test.test`.
   *
   *  Class files go to a fresh temporary directory rather than directly into `java.io.tmpdir`,
   *  so concurrent runs (or unrelated files named `Test.class` there) cannot interfere.
   *
   *  @param code a complete compilation unit defining `object Test` with a parameterless `test` method
   *  @return the value returned by `Test.test`
   *  @throws AssertionError if the compiler reports any errors (the message lists them)
   */
  def run(code: String): Any = {
    val outDir = java.nio.file.Files.createTempDirectory("late-expansion").toFile
    // Clean up at JVM exit rather than on return: classes compiled here may still be loaded
    // lazily (e.g. while the caller inspects the returned value) after `run` completes.
    sys.addShutdownHook(deleteRecursively(outDir))

    val reporter = new StoreReporter
    // Errors in the settings themselves are printed to stdout.
    val settings = new Settings(println(_))
    settings.outdir.value = outDir.getPath
    settings.embeddedDefaults(getClass.getClassLoader)
    // Without an explicit -classpath (as when run from sbt), compile against the JVM's classpath.
    val isInSBT = !settings.classpath.isSetByUser
    if (isInSBT) settings.usejavacp.value = true
    val global = new Global(settings, reporter) {
      self =>

      object late extends {
        val global: self.type = self
      } with LatePlugin

      // Install exactly this plugin instance instead of discovering plugins from the settings.
      override protected def loadPlugins(): List[Plugin] = late :: Nil
    }
    import global._

    val compilerRun = new Run
    val source = newSourceFile(code)
    compilerRun.compileSources(source :: Nil)
    assert(!reporter.hasErrors, reporter.infos.mkString("\n"))
    val loader = new URLClassLoader(Seq(outDir.toURI.toURL), global.getClass.getClassLoader)
    val cls = loader.loadClass("Test")
    // `object Test` compiles to class `Test` with a static forwarder for `test`, so no receiver is needed.
    cls.getMethod("test").invoke(null)
  }

  /** Best-effort recursive delete of `f` and everything beneath it; failures are ignored. */
  private def deleteRecursively(f: File): Unit = {
    // `listFiles` returns null for non-directories (and on I/O errors).
    Option(f.listFiles).foreach(_.foreach(deleteRecursively))
    f.delete()
    ()
  }
}

/** Compiler plugin that applies the async transform late, in a new `postpatmat` phase that
 *  runs after `patmat`.
 *
 *  The phase rewrites
 *   - calls to methods annotated [[autoawait]] into `AsyncId.await(<call>)`, and
 *   - the right-hand side of methods annotated [[lateasync]] into the result of running the
 *     async macro's transform over `AsyncId.async(<RHS>)`.
 *
 *  Concrete instances supply `global` (the compiler instance the plugin is installed into).
 */
abstract class LatePlugin extends Plugin {
  import global._

  override val components: List[PluginComponent] = List(new PluginComponent with TypingTransformers {
    val global: LatePlugin.this.global.type = LatePlugin.this.global

    lazy val asyncIdSym = symbolOf[AsyncId.type]
    lazy val asyncSym = asyncIdSym.info.member(TermName("async"))
    lazy val awaitSym = asyncIdSym.info.member(TermName("await"))
    lazy val autoAwaitSym = symbolOf[autoawait]
    lazy val lateAsyncSym = symbolOf[lateasync]

    def newTransformer(unit: CompilationUnit) = new TypingTransformer(unit) {
      override def transform(tree: Tree): Tree = {
        // Children are transformed first, so the rewrites below see already-rewritten subtrees.
        super.transform(tree) match {
          // `Apply.symbol` delegates to `fun.symbol`, which is null for trees that carry no
          // symbol; guard before querying annotations to avoid an NPE.
          case ap @ Apply(fun, _) if fun.symbol != null && fun.symbol.hasAnnotation(autoAwaitSym) =>
            localTyper.typed(Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, awaitSym), TypeTree(ap.tpe) :: Nil), ap :: Nil))
          case dd: DefDef if dd.symbol.hasAnnotation(lateAsyncSym) => atOwner(dd.symbol) {
            // Type `AsyncId.async[T](rhs)` with macro expansion disabled; the expansion is
            // driven manually below via a synthesized macro context.
            val expandee = localTyper.context.withMacrosDisabled(
              localTyper.typed(Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, asyncSym), TypeTree(dd.rhs.tpe) :: Nil), List(dd.rhs)))
            )
            val c = analyzer.macroContext(localTyper, gen.mkAttributedRef(asyncIdSym), expandee)
            val asyncMacro = AsyncMacro(c, AsyncId)(dd.rhs)
            // NOTE(review): the `()` tree is presumably a placeholder for the execution-context
            // argument, which AsyncId does not need — confirm against AsyncMacro.asyncTransform.
            val code = asyncMacro.asyncTransform[Any](localTyper.typed(Literal(Constant(()))))(c.weakTypeTag[Any])
            deriveDefDef(dd)(_ => localTyper.typed(code))
          }
          case x => x
        }
      }
    }

    // Runs the transformer over each compilation unit; `transformUnit` replaces `unit.body` in place.
    override def newPhase(prev: Phase): Phase = new StdPhase(prev) {
      override def apply(unit: CompilationUnit): Unit =
        newTransformer(unit).transformUnit(unit)
    }

    override val runsAfter: List[String] = "patmat" :: Nil
    override val phaseName: String = "postpatmat"

  })
  override val description: String = "postpatmat"
  override val name: String = "postpatmat"
}

/** Marks a method whose right-hand side gets wrapped as `AsyncId.async { <original RHS> }`
 *  and async-transformed.
 */
final class lateasync() extends StaticAnnotation

/** Marks a method whose call sites are each rewritten to `AsyncId.await(<call>)`. */
final class autoawait() extends StaticAnnotation