diff options
60 files changed, 3192 insertions, 2370 deletions
@@ -3,3 +3,4 @@ target .idea .idea_modules *.icode +project/local.sbt
\ No newline at end of file diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..7fb860d --- /dev/null +++ b/.travis.yml @@ -0,0 +1,13 @@ +language: scala +script: + - sbt ++$TRAVIS_SCALA_VERSION clean test publishLocal +scala: + - 2.10.3 + - 2.11.0-M6 +jdk: + - openjdk6 + - openjdk7 +notifications: + email: + - jason.zaugg@typesafe.com + - philipp.haller@typesafe.com
\ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..00f617f --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,192 @@ +## Building + +The async macro and its test suite can be built and run with SBT. + +## Contributing + +If you are interested in contributing code, we ask you to complete and submit +to us the Scala Contributor License Agreement, which allows us to ensure that +all code submitted to the project is unencumbered by copyrights or patents. +The form is available at: +http://www.scala-lang.org/sites/default/files/contributor_agreement.pdf + +Before submitting a pull-request, please make sure you have followed the guidelines +outlined in our [Pull Request Policy](https://github.com/scala/scala/wiki/Pull-Request-Policy). + +## Troubleshooting + - Logging of the transform can be enabled with `scalac -Dscala.async.debug=true`. + - Tracing of the ANF transform: `scalac -Dscala.async.trace=true` + - Debug the macro expansion by checking out the project and executing the application + [`scala.async.TreeInterrogation`](https://github.com/scala/async/blob/master/src/test/scala/scala/async/TreeInterrogation.scala#L59) + +## Generated Code examples + +```scala +val future = async { + val f1 = async { true } + val x = 1 + def inc(t: Int) = t + x + val t = 0 + val f2 = async { 42 } + if (await(f1)) await(f2) else { val z = 1; inc(t + z) } +} +``` + +After ANF transform. + + - await calls are moved to only appear on the RHS of a value definition. + - `if` is not used as an expression, instead each branch writes its result + to a synthetic `var`. + +```scala + { + (); + val f1: scala.concurrent.Future[Boolean] = { + scala.concurrent.Future.apply[Boolean](true)(scala.concurrent.ExecutionContext.Implicits.global) + }; + val x: Int = 1; + def inc(t: Int): Int = t.+(x); + val t: Int = 0; + val f2: scala.concurrent.Future[Int] = { + scala.concurrent.Future.apply[Int](42)(scala.concurrent.ExecutionContext.Implicits.global) + }; + val await$1: Boolean = scala.async.Async.await[Boolean](f1); + var ifres$1: Int = 0; + if (await$1) + { + val await$2: Int = scala.async.Async.await[Int](f2); + ifres$1 = await$2 + } + else + { + ifres$1 = { + val z: Int = 1; + inc(t.+(z)) + } + }; + ifres$1 +} +``` + +After async transform: + + - one class synthesized to act as the state machine. It's `apply()` method will + be used to start the computation (even the code before the first await call + is executed asynchronously), and the `apply(tr: scala.util.Try[Any])` method + will continue after each completed background task. + - each chunk of code moved into the a branch of the pattern match in `resume$async`. + - value and method definitions accessed from multiple states are lifted to be + members of `class stateMachine`. Others remain local, e.g. `val z`. + +```scala + { + class stateMachine$7 extends ... { + def <init>() = { + super.<init>(); + () + }; + var state$async: Int = 0; + val result$async: scala.concurrent.Promise[Int] = scala.concurrent.Promise.apply[Int](); + val execContext$async = scala.concurrent.ExecutionContext.Implicits.global; + var x$1: Int = 0; + def inc$1(t: Int): Int = t.$plus(x$1); + var t$1: Int = 0; + var f2$1: scala.concurrent.Future[Int] = null; + var await$1: Boolean = false; + var ifres$1: Int = 0; + var await$2: Int = 0; + def resume$async(): Unit = try { + state$async match { + case 0 => { + (); + val f1 = { + scala.concurrent.Future.apply[Boolean](true)(scala.concurrent.ExecutionContext.Implicits.global) + }; + x$1 = 1; + t$1 = 0; + f2$1 = { + scala.concurrent.Future.apply[Int](42)(scala.concurrent.ExecutionContext.Implicits.global) + }; + f1.onComplete(this)(execContext$async) + } + case 1 => { + ifres$1 = 0; + if (await$1) + { + state$async = 2; + resume$async() + } + else + { + state$async = 3; + resume$async() + } + } + case 2 => { + f2$1.onComplete(this)(execContext$async); + () + } + case 5 => { + ifres$1 = await$2; + state$async = 4; + resume$async() + } + case 3 => { + ifres$1 = { + val z = 1; + inc$1(t$1.$plus(z)) + }; + state$async = 4; + resume$async() + } + case 4 => { + result$async.complete(scala.util.Success.apply(ifres$1)); + () + } + } + } catch { + case NonFatal((tr @ _)) => { + { + result$async.complete(scala.util.Failure.apply(tr)); + () + }; + () + } + }; + def apply(tr: scala.util.Try[Any]): Unit = state$async match { + case 0 => { + if (tr.isFailure) + { + result$async.complete(tr.asInstanceOf[scala.util.Try[Int]]); + () + } + else + { + await$1 = tr.get.asInstanceOf[Boolean]; + state$async = 1; + resume$async() + }; + () + } + case 2 => { + if (tr.isFailure) + { + result$async.complete(tr.asInstanceOf[scala.util.Try[Int]]); + () + } + else + { + await$2 = tr.get.asInstanceOf[Int]; + state$async = 5; + resume$async() + }; + () + } + }; + def apply: Unit = resume$async() + }; + val stateMachine$7: StateMachine[scala.concurrent.Promise[Int], scala.concurrent.ExecutionContext] = new stateMachine$7(); + scala.concurrent.Future.apply(stateMachine$7.apply())(scala.concurrent.ExecutionContext.Implicits.global); + stateMachine$7.result$async.future +} +```
\ No newline at end of file @@ -1,35 +1,28 @@ -SCALA LICENSE +Copyright (c) 2012-2013 EPFL +Copyright (c) 2012-2013 Typesafe, Inc. -Copyright (c) 2002-2012 EPFL, Lausanne, unless otherwise specified. All rights reserved. -This software was developed by the Programming Methods Laboratory of the -Swiss Federal Institute of Technology (EPFL), Lausanne, Switzerland. - -Permission to use, copy, modify, and distribute this software in source -or binary form for any purpose with or without fee is hereby granted, -provided that the following conditions are met: - - 1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - 2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - 3. Neither the name of the EPFL nor the names of its contributors - may be used to endorse or promote products derived from this - software without specific prior written permission. - - -THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT -LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY -OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF -SUCH DAMAGE. +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + * Neither the name of the EPFL nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. @@ -1,9 +1,18 @@ # Scala Async Project +[![Build Status](https://secure.travis-ci.org/scala/async.png)](http://travis-ci.org/scala/async) + + ## Quick start - - Add `scala-async.jar` to your classpath - - Use Scala 2.10.0 +Add a dependency: + +```scala +// SBT +libraryDependencies += "org.scala-lang.modules" %% "scala-async" % "0.9.0-M4" +``` + +Write your first `async` block: ```scala import ExecutionContext.Implicits.global @@ -132,12 +141,6 @@ difficult to understand. b) completes the result Promise of the async block, if at the terminal state. - an `apply(): Unit` method that starts the computation. -## Troubleshooting - - Logging of the transform can be enabled with `scalac -Dscala.async.debug=true`. - - Tracing of the ANF transform: `scalac -Dscala.async.trace=true` - - Debug the macro expansion by checking out the project and executing the application - [`TreeInterrogation`](https://github.com/scala/async/blob/master/src/test/scala/scala/async/TreeInterrogation.scala#L59) - ## Limitations - See the [neg](https://github.com/scala/async/tree/master/src/test/scala/scala/async/neg) test cases for for constructs that are not allowed in a async block @@ -145,190 +148,3 @@ difficult to understand. to be dropped in the next milestone. - See [#13](https://github.com/scala/async/issues/13) for why `await` is not possible in closures, and for suggestions on ways to structure the code to work around this limitation. - -## Building - -The async macro and its test suite can be built and run with SBT. - -## Contributing - -If you are interested in contributing code, we ask you to complete and submit -to us the Scala Contributor License Agreement, which allows us to ensure that -all code submitted to the project is unencumbered by copyrights or patents. -The form is available at: -http://www.scala-lang.org/sites/default/files/contributor_agreement.pdf - -Before submitting a pull-request, please make sure you have followed the guidelines -outlined in our [Pull Request Policy](https://github.com/scala/scala/wiki/Pull-Request-Policy). - -### Generated Code examples - -```scala -val future = async { - val f1 = async { true } - val x = 1 - def inc(t: Int) = t + x - val t = 0 - val f2 = async { 42 } - if (await(f1)) await(f2) else { val z = 1; inc(t + z) } -} -``` - -After ANF transform. - - - await calls are moved to only appear on the RHS of a value definition. - - `if` is not used as an expression, instead each branch writes its result - to a synthetic `var`. - -```scala - { - (); - val f1: scala.concurrent.Future[Boolean] = { - scala.concurrent.Future.apply[Boolean](true)(scala.concurrent.ExecutionContext.Implicits.global) - }; - val x: Int = 1; - def inc(t: Int): Int = t.+(x); - val t: Int = 0; - val f2: scala.concurrent.Future[Int] = { - scala.concurrent.Future.apply[Int](42)(scala.concurrent.ExecutionContext.Implicits.global) - }; - val await$1: Boolean = scala.async.Async.await[Boolean](f1); - var ifres$1: Int = 0; - if (await$1) - { - val await$2: Int = scala.async.Async.await[Int](f2); - ifres$1 = await$2 - } - else - { - ifres$1 = { - val z: Int = 1; - inc(t.+(z)) - } - }; - ifres$1 -} -``` - -After async transform: - - - one class synthesized to act as the state machine. It's `apply()` method will - be used to start the computation (even the code before the first await call - is executed asynchronously), and the `apply(tr: scala.util.Try[Any])` method - will continue after each completed background task. - - each chunk of code moved into the a branch of the pattern match in `resume$async`. - - value and method definitions accessed from multiple states are lifted to be - members of `class stateMachine`. Others remain local, e.g. `val z`. - -```scala - { - class stateMachine$7 extends StateMachine[scala.concurrent.Promise[Int], scala.concurrent.ExecutionContext] { - def <init>() = { - super.<init>(); - () - }; - var state$async: Int = 0; - val result$async: scala.concurrent.Promise[Int] = scala.concurrent.Promise.apply[Int](); - val execContext$async = scala.concurrent.ExecutionContext.Implicits.global; - var x$1: Int = 0; - def inc$1(t: Int): Int = t.$plus(x$1); - var t$1: Int = 0; - var f2$1: scala.concurrent.Future[Int] = null; - var await$1: Boolean = false; - var ifres$1: Int = 0; - var await$2: Int = 0; - def resume$async(): Unit = try { - state$async match { - case 0 => { - (); - val f1 = { - scala.concurrent.Future.apply[Boolean](true)(scala.concurrent.ExecutionContext.Implicits.global) - }; - x$1 = 1; - t$1 = 0; - f2$1 = { - scala.concurrent.Future.apply[Int](42)(scala.concurrent.ExecutionContext.Implicits.global) - }; - f1.onComplete(this)(execContext$async) - } - case 1 => { - ifres$1 = 0; - if (await$1) - { - state$async = 2; - resume$async() - } - else - { - state$async = 3; - resume$async() - } - } - case 2 => { - f2$1.onComplete(this)(execContext$async); - () - } - case 5 => { - ifres$1 = await$2; - state$async = 4; - resume$async() - } - case 3 => { - ifres$1 = { - val z = 1; - inc$1(t$1.$plus(z)) - }; - state$async = 4; - resume$async() - } - case 4 => { - result$async.complete(scala.util.Success.apply(ifres$1)); - () - } - } - } catch { - case NonFatal((tr @ _)) => { - { - result$async.complete(scala.util.Failure.apply(tr)); - () - }; - () - } - }; - def apply(tr: scala.util.Try[Any]): Unit = state$async match { - case 0 => { - if (tr.isFailure) - { - result$async.complete(tr.asInstanceOf[scala.util.Try[Int]]); - () - } - else - { - await$1 = tr.get.asInstanceOf[Boolean]; - state$async = 1; - resume$async() - }; - () - } - case 2 => { - if (tr.isFailure) - { - result$async.complete(tr.asInstanceOf[scala.util.Try[Int]]); - () - } - else - { - await$2 = tr.get.asInstanceOf[Int]; - state$async = 5; - resume$async() - }; - () - } - }; - def apply: Unit = resume$async() - }; - val stateMachine$7: StateMachine[scala.concurrent.Promise[Int], scala.concurrent.ExecutionContext] = new stateMachine$7(); - scala.concurrent.Future.apply(stateMachine$7.apply())(scala.concurrent.ExecutionContext.Implicits.global); - stateMachine$7.result$async.future -} -``` @@ -1,38 +1,54 @@ -scalaVersion := "2.10.1" +scalaVersion := "2.10.3" -organization := "org.typesafe.async" // TODO new org name under scala-lang. +// Uncomment to test with a locally built copy of Scala. +// scalaHome := Some(file("/code/scala2/build/pack")) + + +organization := "org.scala-lang.modules" name := "scala-async" -version := "1.0.0-SNAPSHOT" +version := "0.9.0-SNAPSHOT" libraryDependencies <++= (scalaVersion) { sv => Seq( + // TODO we should make this provided after we rely on @compileTimeOnly in scla-library in 2.11.- + // but if we do that now, and a user doesn't have this on the classpath, they can get the + // dreaded MissingRequirementErrors when unpickling types from scala.async.Async "org.scala-lang" % "scala-reflect" % sv, - "org.scala-lang" % "scala-compiler" % sv % "test" + "org.scala-lang" % "scala-compiler" % sv % "provided" ) } libraryDependencies += "junit" % "junit-dep" % "4.10" % "test" -libraryDependencies += "com.novocode" % "junit-interface" % "0.10-M2" % "test" +libraryDependencies += "com.novocode" % "junit-interface" % "0.10" % "test" testOptions += Tests.Argument(TestFrameworks.JUnit, "-q", "-v", "-s") parallelExecution in Global := false -autoCompilerPlugins := true +scalacOptions in compile ++= Seq("-optimize", "-deprecation", "-unchecked", "-Xlint", "-feature") -libraryDependencies <<= (scalaVersion, libraryDependencies) { - (ver, deps) => - deps :+ compilerPlugin("org.scala-lang.plugins" % "continuations" % ver) +scalacOptions in Test ++= Seq("-Yrangepos") + +// Generate $name.properties to store our version as well as the scala version used to build +resourceGenerators in Compile <+= Def.task { + val props = new java.util.Properties + props.put("version.number", version.value) + props.put("scala.version.number", scalaVersion.value) + props.put("scala.binary.version.number", scalaBinaryVersion.value) + val file = (resourceManaged in Compile).value / s"${name.value}.properties" + IO.write(props, null, file) + Seq(file) } -scalacOptions += "-P:continuations:enable" +mappings in (Compile, packageBin) += { + (baseDirectory.value / s"${name.value}.properties") -> s"${name.value}.properties" +} -scalacOptions ++= Seq("-deprecation", "-unchecked", "-Xlint", "-feature") -description := "An asynchronous programming facility for Scala, in the spirit of C# await/async" +description := "An asynchronous programming facility for Scala that offers a direct API for working with Futures." homepage := Some(url("http://github.com/scala/async")) @@ -40,6 +56,25 @@ startYear := Some(2012) licenses +=("Scala license", url("https://github.com/scala/async/blob/master/LICENSE")) +// Uncomment to disable test compilation. +// (sources in Test) ~= ((xs: Seq[File]) => xs.filter(f => Seq("TreeInterrogation", "package").exists(f.name.contains))) + +// maven publishing +publishTo := { + val nexus = "https://oss.sonatype.org/" + val repo = if (version.value.trim.endsWith("SNAPSHOT")) + "snapshots" at nexus + "content/repositories/snapshots" + else + "releases" at nexus + "service/local/staging/deploy/maven2" + Some(repo) +} + +publishMavenStyle := true + +publishArtifact in Test := false + +pomIncludeRepository := { _ => false } + pomExtra := ( <developers> <developer> @@ -60,3 +95,21 @@ pomExtra := ( <connection>scm:git:git@github.com:scala/async.git</connection> </scm> ) + +osgiSettings + +val osgiVersion = version(_.replace('-', '.')) + +OsgiKeys.bundleSymbolicName := s"${organization.value}.${name.value}" + +OsgiKeys.bundleVersion := osgiVersion.value + +OsgiKeys.exportPackage := Seq(s"scala.async.*;version=${version.value}") + +// Sources should also have a nice MANIFEST file +packageOptions in packageSrc := Seq(Package.ManifestAttributes( + ("Bundle-SymbolicName", s"${organization.value}.${name.value}.source"), + ("Bundle-Name", s"${name.value} sources"), + ("Bundle-Version", osgiVersion.value), + ("Eclipse-SourceBundle", s"""${organization.value}.${name.value};version="${osgiVersion.value}";roots:="."""") + )) diff --git a/pending/run/fallback0/fallback0-manual.scala b/pending/run/fallback0/fallback0-manual.scala deleted file mode 100644 index 9bc570d..0000000 --- a/pending/run/fallback0/fallback0-manual.scala +++ /dev/null @@ -1,71 +0,0 @@ -/** - * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> - */ - -import language.{ reflectiveCalls, postfixOps } -import scala.concurrent.{ Future, ExecutionContext, future, Await, Promise } -import scala.concurrent.duration._ -import scala.async.EndTaskException -import scala.async.Async.{ async, await, awaitCps } -import scala.util.continuations.reset - -object TestManual extends App { - - Fallback0ManualSpec.check() - -} - -class TestFallback0ManualClass { - import ExecutionContext.Implicits.global - - def m1(x: Int): Future[Int] = future { - x + 2 - } - - def m2(y: Int): Future[Int] = { - val p = Promise[Int]() - future { reset { - val f = m1(y) - var z = 0 - val res = awaitCps(f, p) + 5 - if (res > 0) { - z = 2 - } else { - z = 4 - } - z - } } - p.future - } - - /* that isn't even supported by current CPS plugin - def m3(y: Int): Future[Int] = { - val p = Promise[Int]() - future { reset { - val f = m1(y) - var z = 0 - val res: Option[Int] = Some(5) - res match { - case None => z = 4 - case Some(a) => z = awaitCps(f, p) - 10 - } - z - } } - p.future - } - */ -} - - -object Fallback0ManualSpec extends MinimalScalaTest { - - "An async method" should { - "support await in a simple if-else expression" in { - val o = new TestFallback0ManualClass - val fut = o.m2(10) - val res = Await.result(fut, 2 seconds) - res mustBe(2) - } - } - -} diff --git a/pending/run/fallback0/fallback0.scala b/pending/run/fallback0/fallback0.scala deleted file mode 100644 index fdb5568..0000000 --- a/pending/run/fallback0/fallback0.scala +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> - */ - -import language.{ reflectiveCalls, postfixOps } -import scala.concurrent.{ Future, ExecutionContext, future, Await } -import scala.concurrent.duration._ -import scala.async.Async.{ async, await, awaitCps } - -object Test extends App { - - Fallback0Spec.check() - -} - -class TestFallback0Class { - import ExecutionContext.Implicits.global - - def m1(x: Int): Future[Int] = future { - x + 2 - } - - def m2(y: Int): Future[Int] = async { - val f = m1(y) - var z = 0 - val res = await(f) + 5 - if (res > 0) { - z = 2 - } else { - z = 4 - } - z - } -} - - -object Fallback0Spec extends MinimalScalaTest { - - "An async method" should { - "support await in a simple if-else expression" in { - val o = new TestFallback0Class - val fut = o.m2(10) - val res = Await.result(fut, 2 seconds) - res mustBe(2) - } - } - -} diff --git a/project/build.properties b/project/build.properties index 2b9d40c..0974fce 100644 --- a/project/build.properties +++ b/project/build.properties @@ -1 +1 @@ -sbt.version=0.12.1
\ No newline at end of file +sbt.version=0.13.0 diff --git a/project/plugins.sbt b/project/plugins.sbt new file mode 100644 index 0000000..6655ada --- /dev/null +++ b/project/plugins.sbt @@ -0,0 +1 @@ +addSbtPlugin("com.typesafe.sbt" % "sbt-osgi" % "0.6.0") diff --git a/release.sh b/release.sh new file mode 100755 index 0000000..ceee287 --- /dev/null +++ b/release.sh @@ -0,0 +1,40 @@ +#! /bin/bash -e +# +# Build, test, and release Scala Async. +# +# Requires credentials: +# +# % cat ~/.sbt/0.13/publish.sbt +# credentials += Credentials("Sonatype Nexus Repository Manager", "oss.sonatype.org", "<user>", "<pass>") +# +# Also requires the sbt-pgp plugin installed globally to provide the `publishSigned` command. +# +# % cat ~/.sbt/0.13/plugins/gpg.sbt +# addSbtPlugin("com.typesafe.sbt" % "sbt-pgp" % "0.8.1") + +function sbt211() { + sbt 'set scalaVersion := "2.11.0-M6"' 'set scalaBinaryVersion := scalaVersion.value' $@ + return $? +} +die () { + echo "$@" + exit 1 +} + +CHECK=";clean;test;publishLocal" +RELEASE=";clean;test;publishSigned" +VERSION=`gsed -rn 's/version :=.*"(.+).*"/\1/p' build.sbt` +[[ -n "$(git status --porcelain)" ]] && die "working directory is not clean!" + +sbt211 $CHECK +sbt $CHECK +sbt $RELEASE +sbt211 $RELEASE + +cat <<EOM +Released! For non-snapshot releases: + - tag: git tag -s -a v$VERSION -m "scala-async $VERSION" + - push tag: git push origin v$VERSION + - close and release the staging repository: https://oss.sonatype.org + - change the version number in build.sbt to a suitable -SNAPSHOT version +EOM
\ No newline at end of file diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala deleted file mode 100644 index 5b9901d..0000000 --- a/src/main/scala/scala/async/AnfTransform.scala +++ /dev/null @@ -1,275 +0,0 @@ - -/* - * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> - */ - -package scala.async - -import scala.reflect.macros.Context - -private[async] final case class AnfTransform[C <: Context](c: C) { - - import c.universe._ - - val utils = TransformUtils[c.type](c) - - import utils._ - - def apply(tree: Tree): List[Tree] = { - val unique = uniqueNames(tree) - // Must prepend the () for issue #31. - anf.transformToList(Block(List(c.literalUnit.tree), unique)) - } - - private def uniqueNames(tree: Tree): Tree = { - new UniqueNames(tree).transform(tree) - } - - /** Assigns unique names to all definitions in a tree, and adjusts references to use the new name. - * Only modifies names that appear more than once in the tree. - * - * This step is needed to allow us to safely merge blocks during the `inline` transform below. - */ - private final class UniqueNames(tree: Tree) extends Transformer { - val repeatedNames: Set[Symbol] = { - class DuplicateNameTraverser extends AsyncTraverser { - val result = collection.mutable.Buffer[Symbol]() - - override def traverse(tree: Tree) { - tree match { - case dt: DefTree => result += dt.symbol - case _ => super.traverse(tree) - } - } - } - val dupNameTraverser = new DuplicateNameTraverser - dupNameTraverser.traverse(tree) - dupNameTraverser.result.groupBy(x => x.name).filter(_._2.size > 1).values.flatten.toSet[Symbol] - } - - /** Stepping outside of the public Macro API to call [[scala.reflect.internal.Symbols.Symbol.name_=]] */ - val symtab = c.universe.asInstanceOf[reflect.internal.SymbolTable] - - val renamed = collection.mutable.Set[Symbol]() - - override def transform(tree: Tree): Tree = { - tree match { - case defTree: DefTree if repeatedNames(defTree.symbol) => - val trans = super.transform(defTree) - val origName = defTree.symbol.name - val sym = defTree.symbol.asInstanceOf[symtab.Symbol] - val fresh = name.fresh(sym.name.toString) - sym.name = origName match { - case _: TermName => symtab.newTermName(fresh) - case _: TypeName => symtab.newTypeName(fresh) - } - renamed += trans.symbol - val newName = trans.symbol.name - trans match { - case ValDef(mods, name, tpt, rhs) => - treeCopy.ValDef(trans, mods, newName, tpt, rhs) - case Bind(name, body) => - treeCopy.Bind(trans, newName, body) - case DefDef(mods, name, tparams, vparamss, tpt, rhs) => - treeCopy.DefDef(trans, mods, newName, tparams, vparamss, tpt, rhs) - case TypeDef(mods, name, tparams, rhs) => - treeCopy.TypeDef(tree, mods, newName, tparams, transform(rhs)) - // If we were to allow local classes / objects, we would need to rename here. - case ClassDef(mods, name, tparams, impl) => - treeCopy.ClassDef(tree, mods, newName, tparams, transform(impl).asInstanceOf[Template]) - case ModuleDef(mods, name, impl) => - treeCopy.ModuleDef(tree, mods, newName, transform(impl).asInstanceOf[Template]) - case x => super.transform(x) - } - case Ident(name) => - if (renamed(tree.symbol)) treeCopy.Ident(tree, tree.symbol.name) - else tree - case Select(fun, name) => - if (renamed(tree.symbol)) { - treeCopy.Select(tree, transform(fun), tree.symbol.name) - } else super.transform(tree) - case tt: TypeTree => - val tt1 = tt.asInstanceOf[symtab.TypeTree] - val orig = tt1.original - if (orig != null) tt1.setOriginal(transform(orig.asInstanceOf[Tree]).asInstanceOf[symtab.Tree]) - super.transform(tt) - case _ => super.transform(tree) - } - } - } - - private object trace { - private var indent = -1 - - def indentString = " " * indent - - def apply[T](prefix: String, args: Any)(t: => T): T = { - indent += 1 - def oneLine(s: Any) = s.toString.replaceAll( """\n""", "\\\\n").take(127) - try { - AsyncUtils.trace(s"${indentString}$prefix(${oneLine(args)})") - val result = t - AsyncUtils.trace(s"${indentString}= ${oneLine(result)}") - result - } finally { - indent -= 1 - } - } - } - - private object inline { - def transformToList(tree: Tree): List[Tree] = trace("inline", tree) { - val stats :+ expr = anf.transformToList(tree) - expr match { - case Apply(fun, args) if isAwait(fun) => - val valDef = defineVal(name.await, expr, tree.pos) - stats :+ valDef :+ Ident(valDef.name) - - case If(cond, thenp, elsep) => - // if type of if-else is Unit don't introduce assignment, - // but add Unit value to bring it into form expected by async transform - if (expr.tpe =:= definitions.UnitTpe) { - stats :+ expr :+ Literal(Constant(())) - } else { - val varDef = defineVar(name.ifRes, expr.tpe, tree.pos) - def branchWithAssign(orig: Tree) = orig match { - case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(varDef.name), thenExpr)) - case _ => Assign(Ident(varDef.name), orig) - } - val ifWithAssign = If(cond, branchWithAssign(thenp), branchWithAssign(elsep)) - stats :+ varDef :+ ifWithAssign :+ Ident(varDef.name) - } - - case Match(scrut, cases) => - // if type of match is Unit don't introduce assignment, - // but add Unit value to bring it into form expected by async transform - if (expr.tpe =:= definitions.UnitTpe) { - stats :+ expr :+ Literal(Constant(())) - } - else { - val varDef = defineVar(name.matchRes, expr.tpe, tree.pos) - val casesWithAssign = cases map { - case cd@CaseDef(pat, guard, Block(caseStats, caseExpr)) => - attachCopy(cd)(CaseDef(pat, guard, Block(caseStats, Assign(Ident(varDef.name), caseExpr)))) - case cd@CaseDef(pat, guard, body) => - attachCopy(cd)(CaseDef(pat, guard, Assign(Ident(varDef.name), body))) - } - val matchWithAssign = attachCopy(tree)(Match(scrut, casesWithAssign)) - stats :+ varDef :+ matchWithAssign :+ Ident(varDef.name) - } - case _ => - stats :+ expr - } - } - - def transformToList(trees: List[Tree]): List[Tree] = trees flatMap transformToList - - def transformToBlock(tree: Tree): Block = transformToList(tree) match { - case stats :+ expr => Block(stats, expr) - } - - private def defineVar(prefix: String, tp: Type, pos: Position): ValDef = { - val vd = ValDef(Modifiers(Flag.MUTABLE), name.fresh(prefix), TypeTree(tp), defaultValue(tp)) - vd.setPos(pos) - vd - } - } - - private def defineVal(prefix: String, lhs: Tree, pos: Position): ValDef = { - val vd = ValDef(NoMods, name.fresh(prefix), TypeTree(), lhs) - vd.setPos(pos) - vd - } - - private object anf { - - private[AnfTransform] def transformToList(tree: Tree): List[Tree] = trace("anf", tree) { - val containsAwait = tree exists isAwait - if (!containsAwait) { - List(tree) - } else tree match { - case Select(qual, sel) => - val stats :+ expr = inline.transformToList(qual) - stats :+ attachCopy(tree)(Select(expr, sel).setSymbol(tree.symbol)) - - case Applied(fun, targs, argss) if argss.nonEmpty => - // we an assume that no await call appears in a by-name argument position, - // this has already been checked. - val funStats :+ simpleFun = inline.transformToList(fun) - def isAwaitRef(name: Name) = name.toString.startsWith(utils.name.await + "$") - val (argStatss, argExprss): (List[List[List[Tree]]], List[List[Tree]]) = - mapArgumentss[List[Tree]](fun, argss) { - case Arg(expr, byName, _) if byName || isSafeToInline(expr) => (Nil, expr) - case Arg(expr@Ident(name), _, _) if isAwaitRef(name) => (Nil, expr) // not typed, so it eludes the check in `isSafeToInline` - case Arg(expr, _, argName) => - inline.transformToList(expr) match { - case stats :+ expr1 => - val valDef = defineVal(argName, expr1, expr.pos) - (stats :+ valDef, Ident(valDef.name)) - } - } - val core = if (targs.isEmpty) simpleFun else TypeApply(simpleFun, targs) - val newApply = argExprss.foldLeft(core)(Apply(_, _)).setSymbol(tree.symbol) - funStats ++ argStatss.flatten.flatten :+ attachCopy(tree)(newApply) - case Block(stats, expr) => - inline.transformToList(stats :+ expr) - - case ValDef(mods, name, tpt, rhs) => - if (rhs exists isAwait) { - val stats :+ expr = inline.transformToList(rhs) - stats :+ attachCopy(tree)(ValDef(mods, name, tpt, expr).setSymbol(tree.symbol)) - } else List(tree) - - case Assign(lhs, rhs) => - val stats :+ expr = inline.transformToList(rhs) - stats :+ attachCopy(tree)(Assign(lhs, expr)) - - case If(cond, thenp, elsep) => - val condStats :+ condExpr = inline.transformToList(cond) - val thenBlock = inline.transformToBlock(thenp) - val elseBlock = inline.transformToBlock(elsep) - // Typechecking with `condExpr` as the condition fails if the condition - // contains an await. `ifTree.setType(tree.tpe)` also fails; it seems - // we rely on this call to `typeCheck` descending into the branches. - // But, we can get away with typechecking a throwaway `If` tree with the - // original scrutinee and the new branches, and setting that type on - // the real `If` tree. - val ifType = c.typeCheck(If(cond, thenBlock, elseBlock)).tpe - condStats :+ - attachCopy(tree)(If(condExpr, thenBlock, elseBlock)).setType(ifType) - - case Match(scrut, cases) => - val scrutStats :+ scrutExpr = inline.transformToList(scrut) - val caseDefs = cases map { - case CaseDef(pat, guard, body) => - // extract local variables for all names bound in `pat`, and rewrite `body` - // to refer to these. - // TODO we can move this into ExprBuilder once we get rid of `AsyncDefinitionUseAnalyzer`. - val block = inline.transformToBlock(body) - val (valDefs, mappings) = (pat collect { - case b@Bind(name, _) => - val newName = newTermName(utils.name.fresh(name.toTermName + utils.name.bindSuffix)) - val vd = ValDef(NoMods, newName, TypeTree(), Ident(b.symbol)) - (vd, (b.symbol, newName)) - }).unzip - val Block(stats1, expr1) = utils.substituteNames(block, mappings.toMap).asInstanceOf[Block] - attachCopy(tree)(CaseDef(pat, guard, Block(valDefs ++ stats1, expr1))) - } - // Refer to comments the translation of `If` above. - val matchType = c.typeCheck(Match(scrut, caseDefs)).tpe - val typedMatch = attachCopy(tree)(Match(scrutExpr, caseDefs)).setType(tree.tpe) - scrutStats :+ typedMatch - - case LabelDef(name, params, rhs) => - List(LabelDef(name, params, Block(inline.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol)) - - case TypeApply(fun, targs) => - val funStats :+ simpleFun = inline.transformToList(fun) - funStats :+ attachCopy(tree)(TypeApply(simpleFun, targs).setSymbol(tree.symbol)) - - case _ => - List(tree) - } - } - } -} diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala index 35d3687..c45a9c6 100644 --- a/src/main/scala/scala/async/Async.scala +++ b/src/main/scala/scala/async/Async.scala @@ -5,181 +5,51 @@ package scala.async import scala.language.experimental.macros -import scala.reflect.macros.Context +import scala.concurrent.{Future, ExecutionContext} import scala.reflect.internal.annotations.compileTimeOnly -object Async extends AsyncBase { - - import scala.concurrent.Future - - lazy val futureSystem = ScalaConcurrentFutureSystem - type FS = ScalaConcurrentFutureSystem.type - - def async[T](body: T) = macro asyncImpl[T] - - override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = super.asyncImpl[T](c)(body) -} - -object AsyncId extends AsyncBase { - lazy val futureSystem = IdentityFutureSystem - type FS = IdentityFutureSystem.type - - def async[T](body: T) = macro asyncImpl[T] - - override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[T] = super.asyncImpl[T](c)(body) -} - /** - * A base class for the `async` macro. Subclasses must provide: + * Async blocks provide a direct means to work with [[scala.concurrent.Future]]. * - * - Concrete types for a given future system - * - Tree manipulations to create and complete the equivalent of Future and Promise - * in that system. - * - The `async` macro declaration itself, and a forwarder for the macro implementation. - * (The latter is temporarily needed to workaround bug SI-6650 in the macro system) + * For example, to use an API to that fetches as web page to fetch + * two pages and add their lengths: * - * The default implementation, [[scala.async.Async]], binds the macro to `scala.concurrent._`. + * {{{ + * import ExecutionContext.Implicits.global + * import scala.async.Async.{async, await} + * + * def fetchURL(url: URL): Future[String] = ... + * + * val sumLengths: Future[Int] = async { + * val body1 = fetchURL("http://scala-lang.org") + * val body2 = fetchURL("http://docs.scala-lang.org") + * await(body1).length + await(body2).length + * } + * }}} + * + * Note that the in the following program, the second fetch does *not* start + * until after the first. If you need to start tasks in parallel, you must do + * so before `await`-ing a result. + * + * {{{ + * val sumLengths: Future[Int] = async { + * await(fetchURL("http://scala-lang.org")).length + await(fetchURL("http://docs.scala-lang.org")).length + * } + * }}} */ -abstract class AsyncBase { - self => - - type FS <: FutureSystem - val futureSystem: FS +object Async { + /** + * Run the block of code `body` asynchronously. `body` may contain calls to `await` when the results of + * a `Future` are needed; this is translated into non-blocking code. + */ + def async[T](body: T)(implicit execContext: ExecutionContext): Future[T] = macro internal.ScalaConcurrentAsync.asyncImpl[T] /** - * A call to `await` must be nested in an enclosing `async` block. + * Non-blocking await the on result of `awaitable`. This may only be used directly within an enclosing `await` block. * - * A call to `await` does not block the current thread, rather it is a delimiter - * used by the enclosing `async` macro. Code following the `await` - * call is executed asynchronously, when the argument of `await` has been completed. - * - * @param awaitable the future from which a value is awaited. - * @tparam T the type of that value. - * @return the value. + * Internally, this will register the remainder of the code in enclosing `async` block as a callback + * in the `onComplete` handler of `awaitable`, and will *not* block a thread. */ @compileTimeOnly("`await` must be enclosed in an `async` block") - def await[T](awaitable: futureSystem.Fut[T]): T = ??? - - protected[async] def fallbackEnabled = false - - def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = { - import c.universe._ - - val analyzer = AsyncAnalysis[c.type](c, this) - val utils = TransformUtils[c.type](c) - import utils.{name, defn} - - analyzer.reportUnsupportedAwaits(body.tree) - - // Transform to A-normal form: - // - no await calls in qualifiers or arguments, - // - if/match only used in statement position. - val anfTree: Block = { - val anf = AnfTransform[c.type](c) - val restored = utils.restorePatternMatchingFunctions(body.tree) - val stats1 :+ expr1 = anf(restored) - val block = Block(stats1, expr1) - c.typeCheck(block).asInstanceOf[Block] - } - - // Analyze the block to find locals that will be accessed from multiple - // states of our generated state machine, e.g. a value assigned before - // an `await` and read afterwards. - val renameMap: Map[Symbol, TermName] = { - analyzer.defTreesUsedInSubsequentStates(anfTree).map { - vd => - (vd.symbol, name.fresh(vd.name.toTermName)) - }.toMap - } - - val builder = ExprBuilder[c.type, futureSystem.type](c, self.futureSystem, anfTree) - import builder.futureSystemOps - val asyncBlock: builder.AsyncBlock = builder.build(anfTree, renameMap) - import asyncBlock.asyncStates - logDiagnostics(c)(anfTree, asyncStates.map(_.toString)) - - // Important to retain the original declaration order here! - val localVarTrees = anfTree.collect { - case vd@ValDef(_, _, tpt, _) if renameMap contains vd.symbol => - utils.mkVarDefTree(tpt.tpe, renameMap(vd.symbol)) - case dd@DefDef(mods, name, tparams, vparamss, tpt, rhs) if renameMap contains dd.symbol => - DefDef(mods, renameMap(dd.symbol), tparams, vparamss, tpt, c.resetAllAttrs(utils.substituteNames(rhs, renameMap))) - } - - val onCompleteHandler = { - Function( - List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree)), - asyncBlock.onCompleteHandler) - } - val resumeFunTree = asyncBlock.resumeFunTree[T] - - val stateMachineType = utils.applied("scala.async.StateMachine", List(futureSystemOps.promType[T], futureSystemOps.execContextType)) - - lazy val stateMachine: ClassDef = { - val body: List[Tree] = { - val stateVar = ValDef(Modifiers(Flag.MUTABLE), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0))) - val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T]), futureSystemOps.createProm[T].tree) - val execContext = ValDef(NoMods, name.execContext, TypeTree(), futureSystemOps.execContext.tree) - val applyDefDef: DefDef = { - val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree))) - val applyBody = asyncBlock.onCompleteHandler - DefDef(NoMods, name.apply, Nil, applyVParamss, TypeTree(definitions.UnitTpe), applyBody) - } - val apply0DefDef: DefDef = { - // We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`. - // See SI-1247 for the the optimization that avoids creatio - val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree))) - val applyBody = asyncBlock.onCompleteHandler - DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.resume), Nil)) - } - List(utils.emptyConstructor, stateVar, result, execContext) ++ localVarTrees ++ List(resumeFunTree, applyDefDef, apply0DefDef) - } - val template = { - Template(List(stateMachineType), emptyValDef, body) - } - ClassDef(NoMods, name.stateMachineT, Nil, template) - } - - def selectStateMachine(selection: TermName) = Select(Ident(name.stateMachine), selection) - - val code: c.Expr[futureSystem.Fut[T]] = { - val isSimple = asyncStates.size == 1 - val tree = - if (isSimple) - Block(Nil, futureSystemOps.spawn(body.tree)) // generate lean code for the simple case of `async { 1 + 1 }` - else { - Block(List[Tree]( - stateMachine, - ValDef(NoMods, name.stateMachine, stateMachineType, Apply(Select(New(Ident(name.stateMachineT)), nme.CONSTRUCTOR), Nil)), - futureSystemOps.spawn(Apply(selectStateMachine(name.apply), Nil)) - ), - futureSystemOps.promiseToFuture(c.Expr[futureSystem.Prom[T]](selectStateMachine(name.result))).tree) - } - c.Expr[futureSystem.Fut[T]](tree) - } - - AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code.tree}") - code - } - - def logDiagnostics(c: Context)(anfTree: c.Tree, states: Seq[String]) { - def location = try { - c.macroApplication.pos.source.path - } catch { - case _: UnsupportedOperationException => - c.macroApplication.pos.toString - } - - AsyncUtils.vprintln(s"In file '$location':") - AsyncUtils.vprintln(s"${c.macroApplication}") - AsyncUtils.vprintln(s"ANF transform expands to:\n $anfTree") - states foreach (s => AsyncUtils.vprintln(s)) - } -} - -/** Internal class used by the `async` macro; should not be manually extended by client code */ -abstract class StateMachine[Result, EC] extends (scala.util.Try[Any] => Unit) with (() => Unit) { - def result$async: Result - - def execContext$async: EC + def await[T](awaitable: Future[T]): T = ??? // No implementation here, as calls to this are translated to `onComplete` by the macro. } diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala deleted file mode 100644 index 4f55f1b..0000000 --- a/src/main/scala/scala/async/AsyncAnalysis.scala +++ /dev/null @@ -1,192 +0,0 @@ -/* - * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> - */ - -package scala.async - -import scala.reflect.macros.Context -import scala.collection.mutable - -private[async] final case class AsyncAnalysis[C <: Context](c: C, asyncBase: AsyncBase) { - import c.universe._ - - val utils = TransformUtils[c.type](c) - - import utils._ - - /** - * Analyze the contents of an `async` block in order to: - * - Report unsupported `await` calls under nested templates, functions, by-name arguments. - * - * Must be called on the original tree, not on the ANF transformed tree. - */ - def reportUnsupportedAwaits(tree: Tree): Boolean = { - val analyzer = new UnsupportedAwaitAnalyzer - analyzer.traverse(tree) - analyzer.hasUnsupportedAwaits - } - - /** - * Analyze the contents of an `async` block in order to: - * - Find which local `ValDef`-s need to be lifted to fields of the state machine, based - * on whether or not they are accessed only from a single state. - * - * Must be called on the ANF transformed tree. - */ - def defTreesUsedInSubsequentStates(tree: Tree): List[DefTree] = { - val analyzer = new AsyncDefinitionUseAnalyzer - analyzer.traverse(tree) - val liftable: List[DefTree] = (analyzer.valDefsToLift ++ analyzer.nestedMethodsToLift).toList.distinct - liftable - } - - private class UnsupportedAwaitAnalyzer extends AsyncTraverser { - var hasUnsupportedAwaits = false - - override def nestedClass(classDef: ClassDef) { - val kind = if (classDef.symbol.asClass.isTrait) "trait" else "class" - if (!reportUnsupportedAwait(classDef, s"nested $kind")) { - // do not allow local class definitions, because of SI-5467 (specific to case classes, though) - if (classDef.symbol.asClass.isCaseClass) - c.error(classDef.pos, s"Local case class ${classDef.name.decoded} illegal within `async` block") - } - } - - override def nestedModule(module: ModuleDef) { - if (!reportUnsupportedAwait(module, "nested object")) { - // local object definitions lead to spurious type errors (because of resetAllAttrs?) - c.error(module.pos, s"Local object ${module.name.decoded} illegal within `async` block") - } - } - - override def nestedMethod(module: DefDef) { - reportUnsupportedAwait(module, "nested method") - } - - override def byNameArgument(arg: Tree) { - reportUnsupportedAwait(arg, "by-name argument") - } - - override def function(function: Function) { - reportUnsupportedAwait(function, "nested function") - } - - override def patMatFunction(tree: Match) { - reportUnsupportedAwait(tree, "nested function") - } - - override def traverse(tree: Tree) { - def containsAwait = tree exists isAwait - tree match { - case Try(_, _, _) if containsAwait => - reportUnsupportedAwait(tree, "try/catch") - super.traverse(tree) - case Return(_) => - c.abort(tree.pos, "return is illegal within a async block") - case ValDef(mods, _, _, _) if mods.hasFlag(Flag.LAZY) => - c.abort(tree.pos, "lazy vals are illegal within an async block") - case _ => - super.traverse(tree) - } - } - - /** - * @return true, if the tree contained an unsupported await. - */ - private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String): Boolean = { - val badAwaits: List[RefTree] = tree collect { - case rt: RefTree if isAwait(rt) => rt - } - badAwaits foreach { - tree => - reportError(tree.pos, s"await must not be used under a $whyUnsupported.") - } - badAwaits.nonEmpty - } - - private def reportError(pos: Position, msg: String) { - hasUnsupportedAwaits = true - if (!asyncBase.fallbackEnabled) - c.error(pos, msg) - } - } - - private class AsyncDefinitionUseAnalyzer extends AsyncTraverser { - private var chunkId = 0 - - private def nextChunk() = chunkId += 1 - - private var valDefChunkId = Map[Symbol, (ValDef, Int)]() - - val valDefsToLift : mutable.Set[ValDef] = collection.mutable.Set() - val nestedMethodsToLift: mutable.Set[DefDef] = collection.mutable.Set() - - override def nestedMethod(defDef: DefDef) { - nestedMethodsToLift += defDef - markReferencedVals(defDef) - } - - override def function(function: Function) { - markReferencedVals(function) - } - - override def patMatFunction(tree: Match) { - markReferencedVals(tree) - } - - private def markReferencedVals(tree: Tree) { - tree foreach { - case rt: RefTree => - valDefChunkId.get(rt.symbol) match { - case Some((vd, defChunkId)) => - valDefsToLift += vd // lift all vals referred to by nested functions. - case _ => - } - case _ => - } - } - - override def traverse(tree: Tree) = { - tree match { - case If(cond, thenp, elsep) if tree exists isAwait => - traverseChunks(List(cond, thenp, elsep)) - case Match(selector, cases) if tree exists isAwait => - traverseChunks(selector :: cases) - case LabelDef(name, params, rhs) if rhs exists isAwait => - traverseChunks(rhs :: Nil) - case Apply(fun, args) if isAwait(fun) => - super.traverse(tree) - nextChunk() - case vd: ValDef => - super.traverse(tree) - valDefChunkId += (vd.symbol -> (vd -> chunkId)) - val isPatternBinder = vd.name.toString.contains(name.bindSuffix) - if (isAwait(vd.rhs) || isPatternBinder) valDefsToLift += vd - case as: Assign => - if (isAwait(as.rhs)) { - assert(as.lhs.symbol != null, "internal error: null symbol for Assign tree:" + as + " " + as.lhs.symbol) - - // TODO test the orElse case, try to remove the restriction. - val (vd, defBlockId) = valDefChunkId.getOrElse(as.lhs.symbol, c.abort(as.pos, s"await may only be assigned to a var/val defined in the async block. ${as.lhs} ${as.lhs.symbol}")) - valDefsToLift += vd - } - super.traverse(tree) - case rt: RefTree => - valDefChunkId.get(rt.symbol) match { - case Some((vd, defChunkId)) if defChunkId != chunkId => - valDefsToLift += vd - case _ => - } - super.traverse(tree) - case _ => super.traverse(tree) - } - } - - private def traverseChunks(trees: List[Tree]) { - trees.foreach { - t => traverse(t); nextChunk() - } - } - } - -} diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala deleted file mode 100644 index ca46a83..0000000 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ /dev/null @@ -1,394 +0,0 @@ -/* - * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> - */ -package scala.async - -import scala.reflect.macros.Context -import scala.collection.mutable.ListBuffer -import collection.mutable -import language.existentials - -private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: C, futureSystem: FS, origTree: C#Tree) { - builder => - - val utils = TransformUtils[c.type](c) - - import c.universe._ - import utils._ - import defn._ - - lazy val futureSystemOps = futureSystem.mkOps(c) - - val stateAssigner = new StateAssigner - val labelDefStates = collection.mutable.Map[Symbol, Int]() - - trait AsyncState { - def state: Int - - def mkHandlerCaseForState: CaseDef - - def mkOnCompleteHandler[T: c.WeakTypeTag]: Option[CaseDef] = None - - def stats: List[Tree] - - final def body: c.Tree = stats match { - case stat :: Nil => stat - case init :+ last => Block(init, last) - } - } - - /** A sequence of statements the concludes with a unconditional transition to `nextState` */ - final class SimpleAsyncState(val stats: List[Tree], val state: Int, nextState: Int) - extends AsyncState { - - def mkHandlerCaseForState: CaseDef = - mkHandlerCase(state, stats :+ mkStateTree(nextState) :+ mkResumeApply) - - override val toString: String = - s"AsyncState #$state, next = $nextState" - } - - /** A sequence of statements with a conditional transition to the next state, which will represent - * a branch of an `if` or a `match`. - */ - final class AsyncStateWithoutAwait(val stats: List[c.Tree], val state: Int) extends AsyncState { - override def mkHandlerCaseForState: CaseDef = - mkHandlerCase(state, stats) - - override val toString: String = - s"AsyncStateWithoutAwait #$state" - } - - /** A sequence of statements that concludes with an `await` call. The `onComplete` - * handler will unconditionally transition to `nestState`.`` - */ - final class AsyncStateWithAwait(val stats: List[c.Tree], val state: Int, nextState: Int, - awaitable: Awaitable) - extends AsyncState { - - override def mkHandlerCaseForState: CaseDef = { - val callOnComplete = futureSystemOps.onComplete(c.Expr(awaitable.expr), - c.Expr(This(tpnme.EMPTY)), c.Expr(Ident(name.execContext))).tree - mkHandlerCase(state, stats :+ callOnComplete) - } - - override def mkOnCompleteHandler[T: c.WeakTypeTag]: Option[CaseDef] = { - val tryGetTree = - Assign( - Ident(awaitable.resultName), - TypeApply(Select(Select(Ident(name.tr), Try_get), newTermName("asInstanceOf")), List(TypeTree(awaitable.resultType))) - ) - - /* if (tr.isFailure) - * result$async.complete(tr.asInstanceOf[Try[T]]) - * else { - * <resultName> = tr.get.asInstanceOf[<resultType>] - * <nextState> - * <mkResumeApply> - * } - */ - val ifIsFailureTree = - If(Select(Ident(name.tr), Try_isFailure), - futureSystemOps.completeProm[T]( - c.Expr[futureSystem.Prom[T]](Ident(name.result)), - c.Expr[scala.util.Try[T]]( - TypeApply(Select(Ident(name.tr), newTermName("asInstanceOf")), - List(TypeTree(weakTypeOf[scala.util.Try[T]]))))).tree, - Block(List(tryGetTree, mkStateTree(nextState)), mkResumeApply) - ) - - Some(mkHandlerCase(state, List(ifIsFailureTree))) - } - - override val toString: String = - s"AsyncStateWithAwait #$state, next = $nextState" - } - - /* - * Builder for a single state of an async method. - */ - final class AsyncStateBuilder(state: Int, private val nameMap: Map[Symbol, c.Name]) { - /* Statements preceding an await call. */ - private val stats = ListBuffer[c.Tree]() - /** The state of the target of a LabelDef application (while loop jump) */ - private var nextJumpState: Option[Int] = None - - private def renameReset(tree: Tree) = resetInternalAttrs(substituteNames(tree, nameMap)) - - def +=(stat: c.Tree): this.type = { - assert(nextJumpState.isEmpty, s"statement appeared after a label jump: $stat") - def addStat() = stats += renameReset(stat) - stat match { - case _: DefDef => // these have been lifted. - case Apply(fun, Nil) => - labelDefStates get fun.symbol match { - case Some(nextState) => nextJumpState = Some(nextState) - case None => addStat() - } - case _ => addStat() - } - this - } - - def resultWithAwait(awaitable: Awaitable, - nextState: Int): AsyncState = { - val sanitizedAwaitable = awaitable.copy(expr = renameReset(awaitable.expr)) - val effectiveNextState = nextJumpState.getOrElse(nextState) - new AsyncStateWithAwait(stats.toList, state, effectiveNextState, sanitizedAwaitable) - } - - def resultSimple(nextState: Int): AsyncState = { - val effectiveNextState = nextJumpState.getOrElse(nextState) - new SimpleAsyncState(stats.toList, state, effectiveNextState) - } - - def resultWithIf(condTree: c.Tree, thenState: Int, elseState: Int): AsyncState = { - // 1. build changed if-else tree - // 2. insert that tree at the end of the current state - val cond = renameReset(condTree) - def mkBranch(state: Int) = Block(mkStateTree(state) :: Nil, mkResumeApply) - this += If(cond, mkBranch(thenState), mkBranch(elseState)) - new AsyncStateWithoutAwait(stats.toList, state) - } - - /** - * Build `AsyncState` ending with a match expression. - * - * The cases of the match simply resume at the state of their corresponding right-hand side. - * - * @param scrutTree tree of the scrutinee - * @param cases list of case definitions - * @param caseStates starting state of the right-hand side of the each case - * @return an `AsyncState` representing the match expression - */ - def resultWithMatch(scrutTree: c.Tree, cases: List[CaseDef], caseStates: List[Int]): AsyncState = { - // 1. build list of changed cases - val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match { - case CaseDef(pat, guard, rhs) => - val bindAssigns = rhs.children.takeWhile(isSyntheticBindVal).map { - case ValDef(_, name, _, rhs) => Assign(Ident(name), rhs) - case t => sys.error(s"Unexpected tree. Expected ValDef, found: $t") - } - CaseDef(pat, guard, Block(bindAssigns :+ mkStateTree(caseStates(num)), mkResumeApply)) - } - // 2. insert changed match tree at the end of the current state - this += Match(renameReset(scrutTree), newCases) - new AsyncStateWithoutAwait(stats.toList, state) - } - - def resultWithLabel(startLabelState: Int): AsyncState = { - this += Block(mkStateTree(startLabelState) :: Nil, mkResumeApply) - new AsyncStateWithoutAwait(stats.toList, state) - } - - override def toString: String = { - val statsBeforeAwait = stats.mkString("\n") - s"ASYNC STATE:\n$statsBeforeAwait" - } - } - - /** - * An `AsyncBlockBuilder` builds a `ListBuffer[AsyncState]` based on the expressions of a `Block(stats, expr)` (see `Async.asyncImpl`). - * - * @param stats a list of expressions - * @param expr the last expression of the block - * @param startState the start state - * @param endState the state to continue with - * @param toRename a `Map` for renaming the given key symbols to the mangled value names - */ - final private class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int, - private val toRename: Map[Symbol, c.Name]) { - val asyncStates = ListBuffer[AsyncState]() - - var stateBuilder = new AsyncStateBuilder(startState, toRename) - var currState = startState - - /* TODO Fall back to CPS plug-in if tree contains an `await` call. */ - def checkForUnsupportedAwait(tree: c.Tree) = if (tree exists { - case Apply(fun, _) if isAwait(fun) => true - case _ => false - }) c.abort(tree.pos, "await must not be used in this position") //throw new FallbackToCpsException - - def nestedBlockBuilder(nestedTree: Tree, startState: Int, endState: Int) = { - val (nestedStats, nestedExpr) = statsAndExpr(nestedTree) - new AsyncBlockBuilder(nestedStats, nestedExpr, startState, endState, toRename) - } - - import stateAssigner.nextState - - // populate asyncStates - for (stat <- stats) stat match { - // the val name = await(..) pattern - case ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) => - val afterAwaitState = nextState() - val awaitable = Awaitable(arg, toRename(stat.symbol).toTermName, tpt.tpe) - asyncStates += stateBuilder.resultWithAwait(awaitable, afterAwaitState) // complete with await - currState = afterAwaitState - stateBuilder = new AsyncStateBuilder(currState, toRename) - - case ValDef(mods, name, tpt, rhs) if toRename contains stat.symbol => - checkForUnsupportedAwait(rhs) - stateBuilder += Assign(Ident(toRename(stat.symbol).toTermName), rhs) - - case If(cond, thenp, elsep) if stat exists isAwait => - checkForUnsupportedAwait(cond) - - val thenStartState = nextState() - val elseStartState = nextState() - val afterIfState = nextState() - - asyncStates += - // the two Int arguments are the start state of the then branch and the else branch, respectively - stateBuilder.resultWithIf(cond, thenStartState, elseStartState) - - List((thenp, thenStartState), (elsep, elseStartState)) foreach { - case (branchTree, state) => - val builder = nestedBlockBuilder(branchTree, state, afterIfState) - asyncStates ++= builder.asyncStates - } - - currState = afterIfState - stateBuilder = new AsyncStateBuilder(currState, toRename) - - case Match(scrutinee, cases) if stat exists isAwait => - checkForUnsupportedAwait(scrutinee) - - val caseStates = cases.map(_ => nextState()) - val afterMatchState = nextState() - - asyncStates += - stateBuilder.resultWithMatch(scrutinee, cases, caseStates) - - for ((cas, num) <- cases.zipWithIndex) { - val (stats, expr) = statsAndExpr(cas.body) - val stats1 = stats.dropWhile(isSyntheticBindVal) - val builder = nestedBlockBuilder(Block(stats1, expr), caseStates(num), afterMatchState) - asyncStates ++= builder.asyncStates - } - - currState = afterMatchState - stateBuilder = new AsyncStateBuilder(currState, toRename) - - case ld@LabelDef(name, params, rhs) if rhs exists isAwait => - val startLabelState = nextState() - val afterLabelState = nextState() - asyncStates += stateBuilder.resultWithLabel(startLabelState) - labelDefStates(ld.symbol) = startLabelState - val builder = nestedBlockBuilder(rhs, startLabelState, afterLabelState) - asyncStates ++= builder.asyncStates - - currState = afterLabelState - stateBuilder = new AsyncStateBuilder(currState, toRename) - case _ => - checkForUnsupportedAwait(stat) - stateBuilder += stat - } - // complete last state builder (representing the expressions after the last await) - stateBuilder += expr - val lastState = stateBuilder.resultSimple(endState) - asyncStates += lastState - } - - trait AsyncBlock { - def asyncStates: List[AsyncState] - - def onCompleteHandler[T: c.WeakTypeTag]: Tree - - def resumeFunTree[T]: Tree - } - - def build(block: Block, toRename: Map[Symbol, c.Name]): AsyncBlock = { - val Block(stats, expr) = block - val startState = stateAssigner.nextState() - val endState = Int.MaxValue - - val blockBuilder = new AsyncBlockBuilder(stats, expr, startState, endState, toRename) - - new AsyncBlock { - def asyncStates = blockBuilder.asyncStates.toList - - def mkCombinedHandlerCases[T]: List[CaseDef] = { - val caseForLastState: CaseDef = { - val lastState = asyncStates.last - val lastStateBody = c.Expr[T](lastState.body) - val rhs = futureSystemOps.completeProm( - c.Expr[futureSystem.Prom[T]](Ident(name.result)), reify(scala.util.Success(lastStateBody.splice))) - mkHandlerCase(lastState.state, rhs.tree) - } - asyncStates.toList match { - case s :: Nil => - List(caseForLastState) - case _ => - val initCases = for (state <- asyncStates.toList.init) yield state.mkHandlerCaseForState - initCases :+ caseForLastState - } - } - - val initStates = asyncStates.init - - /** - * // assumes tr: Try[Any] is in scope. - * // - * state match { - * case 0 => { - * x11 = tr.get.asInstanceOf[Double]; - * state = 1; - * resume() - * } - */ - def onCompleteHandler[T: c.WeakTypeTag]: Tree = Match(Ident(name.state), initStates.flatMap(_.mkOnCompleteHandler[T]).toList) - - /** - * def resume(): Unit = { - * try { - * state match { - * case 0 => { - * f11 = exprReturningFuture - * f11.onComplete(onCompleteHandler)(context) - * } - * ... - * } - * } catch { - * case NonFatal(t) => result.failure(t) - * } - * } - */ - def resumeFunTree[T]: Tree = - DefDef(Modifiers(), name.resume, Nil, List(Nil), Ident(definitions.UnitClass), - Try( - Match(Ident(name.state), mkCombinedHandlerCases[T]), - List( - CaseDef( - Apply(Ident(defn.NonFatalClass), List(Bind(name.tr, Ident(nme.WILDCARD)))), - EmptyTree, - Block(List({ - val t = c.Expr[Throwable](Ident(name.tr)) - futureSystemOps.completeProm[T](c.Expr[futureSystem.Prom[T]](Ident(name.result)), reify(scala.util.Failure(t.splice))).tree - }), c.literalUnit.tree))), EmptyTree)) - } - } - - private def isSyntheticBindVal(tree: Tree) = tree match { - case vd@ValDef(_, lname, _, Ident(rname)) => lname.toString.contains(name.bindSuffix) - case _ => false - } - - private final case class Awaitable(expr: Tree, resultName: TermName, resultType: Type) - - private val internalSyms = origTree.collect { - case dt: DefTree => dt.symbol - } - - private def resetInternalAttrs(tree: Tree) = utils.resetInternalAttrs(tree, internalSyms) - - private def mkResumeApply = Apply(Ident(name.resume), Nil) - - private def mkStateTree(nextState: Int): c.Tree = - Assign(Ident(name.state), c.literal(nextState).tree) - - private def mkHandlerCase(num: Int, rhs: List[c.Tree]): CaseDef = - mkHandlerCase(num, Block(rhs, c.literalUnit.tree)) - - private def mkHandlerCase(num: Int, rhs: c.Tree): CaseDef = - CaseDef(c.literal(num).tree, EmptyTree, rhs) -} diff --git a/src/main/scala/scala/async/FutureSystem.scala b/src/main/scala/scala/async/FutureSystem.scala deleted file mode 100644 index a050bec..0000000 --- a/src/main/scala/scala/async/FutureSystem.scala +++ /dev/null @@ -1,157 +0,0 @@ -/* - * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> - */ -package scala.async - -import scala.language.higherKinds - -import scala.reflect.macros.Context - -/** - * An abstraction over a future system. - * - * Used by the macro implementations in [[scala.async.AsyncBase]] to - * customize the code generation. - * - * The API mirrors that of `scala.concurrent.Future`, see the instance - * [[scala.async.ScalaConcurrentFutureSystem]] for an example of how - * to implement this. - */ -trait FutureSystem { - /** A container to receive the final value of the computation */ - type Prom[A] - /** A (potentially in-progress) computation */ - type Fut[A] - /** An execution context, required to create or register an on completion callback on a Future. */ - type ExecContext - - trait Ops { - val context: reflect.macros.Context - - import context.universe._ - - /** Lookup the execution context, typically with an implicit search */ - def execContext: Expr[ExecContext] - - def promType[A: WeakTypeTag]: Type - def execContextType: Type - - /** Create an empty promise */ - def createProm[A: WeakTypeTag]: Expr[Prom[A]] - - /** Extract a future from the given promise. */ - def promiseToFuture[A: WeakTypeTag](prom: Expr[Prom[A]]): Expr[Fut[A]] - - /** Construct a future to asynchronously compute the given expression */ - def future[A: WeakTypeTag](a: Expr[A])(execContext: Expr[ExecContext]): Expr[Fut[A]] - - /** Register an call back to run on completion of the given future */ - def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[scala.util.Try[A] => U], - execContext: Expr[ExecContext]): Expr[Unit] - - /** Complete a promise with a value */ - def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] - - def spawn(tree: context.Tree): context.Tree = - future(context.Expr[Unit](tree))(execContext).tree - - def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] - } - - def mkOps(c: Context): Ops { val context: c.type } -} - -object ScalaConcurrentFutureSystem extends FutureSystem { - - import scala.concurrent._ - - type Prom[A] = Promise[A] - type Fut[A] = Future[A] - type ExecContext = ExecutionContext - - def mkOps(c: Context): Ops {val context: c.type} = new Ops { - val context: c.type = c - - import context.universe._ - - def execContext: Expr[ExecContext] = c.Expr(c.inferImplicitValue(c.weakTypeOf[ExecutionContext]) match { - case EmptyTree => c.abort(c.macroApplication.pos, "Unable to resolve implicit ExecutionContext") - case context => context - }) - - def promType[A: WeakTypeTag]: Type = c.weakTypeOf[Promise[A]] - def execContextType: Type = c.weakTypeOf[ExecutionContext] - - def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify { - Promise[A]() - } - - def promiseToFuture[A: WeakTypeTag](prom: Expr[Prom[A]]) = reify { - prom.splice.future - } - - def future[A: WeakTypeTag](a: Expr[A])(execContext: Expr[ExecContext]) = reify { - Future(a.splice)(execContext.splice) - } - - def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[scala.util.Try[A] => U], - execContext: Expr[ExecContext]): Expr[Unit] = reify { - future.splice.onComplete(fun.splice)(execContext.splice) - } - - def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] = reify { - prom.splice.complete(value.splice) - context.literalUnit.splice - } - - def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] = reify { - future.splice.asInstanceOf[Fut[A]] - } - } -} - -/** - * A trivial implementation of [[scala.async.FutureSystem]] that performs computations - * on the current thread. Useful for testing. - */ -object IdentityFutureSystem extends FutureSystem { - - class Prom[A](var a: A) - - type Fut[A] = A - type ExecContext = Unit - - def mkOps(c: Context): Ops {val context: c.type} = new Ops { - val context: c.type = c - - import context.universe._ - - def execContext: Expr[ExecContext] = c.literalUnit - - def promType[A: WeakTypeTag]: Type = c.weakTypeOf[Prom[A]] - def execContextType: Type = c.weakTypeOf[Unit] - - def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify { - new Prom(null.asInstanceOf[A]) - } - - def promiseToFuture[A: WeakTypeTag](prom: Expr[Prom[A]]) = reify { - prom.splice.a - } - - def future[A: WeakTypeTag](t: Expr[A])(execContext: Expr[ExecContext]) = t - - def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[scala.util.Try[A] => U], - execContext: Expr[ExecContext]): Expr[Unit] = reify { - fun.splice.apply(util.Success(future.splice)) - context.literalUnit.splice - } - - def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] = reify { - prom.splice.a = value.splice.get - context.literalUnit.splice - } - - def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] = ??? - } -} diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala deleted file mode 100644 index ebd546f..0000000 --- a/src/main/scala/scala/async/TransformUtils.scala +++ /dev/null @@ -1,374 +0,0 @@ -/* - * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> - */ -package scala.async - -import scala.reflect.macros.Context -import reflect.ClassTag - -/** - * Utilities used in both `ExprBuilder` and `AnfTransform`. - */ -private[async] final case class TransformUtils[C <: Context](c: C) { - - import c.universe._ - - object name { - def suffix(string: String) = string + "$async" - - def suffixedName(prefix: String) = newTermName(suffix(prefix)) - - val state = suffixedName("state") - val result = suffixedName("result") - val resume = suffixedName("resume") - val execContext = suffixedName("execContext") - val stateMachine = newTermName(fresh("stateMachine")) - val stateMachineT = stateMachine.toTypeName - val apply = newTermName("apply") - val applyOrElse = newTermName("applyOrElse") - val tr = newTermName("tr") - val matchRes = "matchres" - val ifRes = "ifres" - val await = "await" - val bindSuffix = "$bind" - - def fresh(name: TermName): TermName = newTermName(fresh(name.toString)) - - def fresh(name: String): String = if (name.toString.contains("$")) name else c.fresh("" + name + "$") - } - - def defaultValue(tpe: Type): Literal = { - val defaultValue: Any = - if (tpe <:< definitions.BooleanTpe) false - else if (definitions.ScalaNumericValueClasses.exists(tpe <:< _.toType)) 0 - else if (tpe <:< definitions.AnyValTpe) 0 - else null - Literal(Constant(defaultValue)) - } - - def isAwait(fun: Tree) = - fun.symbol == defn.Async_await - - /** Replace all `Ident` nodes referring to one of the keys n `renameMap` with a node - * referring to the corresponding new name - */ - def substituteNames(tree: Tree, renameMap: Map[Symbol, Name]): Tree = { - val renamer = new Transformer { - override def transform(tree: Tree) = tree match { - case Ident(_) => (renameMap get tree.symbol).fold(tree)(Ident(_)) - case tt: TypeTree if tt.original != EmptyTree && tt.original != null => - // We also have to apply our renaming transform on originals of TypeTrees. - // TODO 2.10.1 Can we find a cleaner way? - val symTab = c.universe.asInstanceOf[reflect.internal.SymbolTable] - val tt1 = tt.asInstanceOf[symTab.TypeTree] - tt1.setOriginal(transform(tt.original).asInstanceOf[symTab.Tree]) - super.transform(tree) - case _ => super.transform(tree) - } - } - renamer.transform(tree) - } - - /** Descends into the regions of the tree that are subject to the - * translation to a state machine by `async`. When a nested template, - * function, or by-name argument is encountered, the descent stops, - * and `nestedClass` etc are invoked. - */ - trait AsyncTraverser extends Traverser { - def nestedClass(classDef: ClassDef) { - } - - def nestedModule(module: ModuleDef) { - } - - def nestedMethod(module: DefDef) { - } - - def byNameArgument(arg: Tree) { - } - - def function(function: Function) { - } - - def patMatFunction(tree: Match) { - } - - override def traverse(tree: Tree) { - tree match { - case cd: ClassDef => nestedClass(cd) - case md: ModuleDef => nestedModule(md) - case dd: DefDef => nestedMethod(dd) - case fun: Function => function(fun) - case m@Match(EmptyTree, _) => patMatFunction(m) // Pattern matching anonymous function under -Xoldpatmat of after `restorePatternMatchingFunctions` - case Applied(fun, targs, argss) if argss.nonEmpty => - val isInByName = isByName(fun) - for ((args, i) <- argss.zipWithIndex) { - for ((arg, j) <- args.zipWithIndex) { - if (!isInByName(i, j)) traverse(arg) - else byNameArgument(arg) - } - } - traverse(fun) - case _ => super.traverse(tree) - } - } - } - - private lazy val Boolean_ShortCircuits: Set[Symbol] = { - import definitions.BooleanClass - def BooleanTermMember(name: String) = BooleanClass.typeSignature.member(newTermName(name).encodedName) - val Boolean_&& = BooleanTermMember("&&") - val Boolean_|| = BooleanTermMember("||") - Set(Boolean_&&, Boolean_||) - } - - def isByName(fun: Tree): ((Int, Int) => Boolean) = { - if (Boolean_ShortCircuits contains fun.symbol) (i, j) => true - else { - val symtab = c.universe.asInstanceOf[reflect.internal.SymbolTable] - val paramss = fun.tpe.asInstanceOf[symtab.Type].paramss - val byNamess = paramss.map(_.map(_.isByNameParam)) - (i, j) => util.Try(byNamess(i)(j)).getOrElse(false) - } - } - def argName(fun: Tree): ((Int, Int) => String) = { - val symtab = c.universe.asInstanceOf[reflect.internal.SymbolTable] - val paramss = fun.tpe.asInstanceOf[symtab.Type].paramss - val namess = paramss.map(_.map(_.name.toString)) - (i, j) => util.Try(namess(i)(j)).getOrElse(s"arg_${i}_${j}") - } - - object Applied { - val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable] - object treeInfo extends { - val global: symtab.type = symtab - } with reflect.internal.TreeInfo - - def unapply(tree: Tree): Some[(Tree, List[Tree], List[List[Tree]])] = { - val treeInfo.Applied(core, targs, argss) = tree.asInstanceOf[symtab.Tree] - Some((core.asInstanceOf[Tree], targs.asInstanceOf[List[Tree]], argss.asInstanceOf[List[List[Tree]]])) - } - } - - def statsAndExpr(tree: Tree): (List[Tree], Tree) = tree match { - case Block(stats, expr) => (stats, expr) - case _ => (List(tree), Literal(Constant(()))) - } - - def mkVarDefTree(resultType: Type, resultName: TermName): c.Tree = { - ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), defaultValue(resultType)) - } - - def emptyConstructor: DefDef = { - val emptySuperCall = Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), Nil) - DefDef(NoMods, nme.CONSTRUCTOR, List(), List(List()), TypeTree(), Block(List(emptySuperCall), c.literalUnit.tree)) - } - - def applied(className: String, types: List[Type]): AppliedTypeTree = - AppliedTypeTree(Ident(c.mirror.staticClass(className)), types.map(TypeTree(_))) - - object defn { - def mkList_apply[A](args: List[Expr[A]]): Expr[List[A]] = { - c.Expr(Apply(Ident(definitions.List_apply), args.map(_.tree))) - } - - def mkList_contains[A](self: Expr[List[A]])(elem: Expr[Any]) = reify(self.splice.contains(elem.splice)) - - def mkFunction_apply[A, B](self: Expr[Function1[A, B]])(arg: Expr[A]) = reify { - self.splice.apply(arg.splice) - } - - def mkAny_==(self: Expr[Any])(other: Expr[Any]) = reify { - self.splice == other.splice - } - - def mkTry_get[A](self: Expr[util.Try[A]]) = reify { - self.splice.get - } - - val Try_get = methodSym(reify((null: scala.util.Try[Any]).get)) - val Try_isFailure = methodSym(reify((null: scala.util.Try[Any]).isFailure)) - - val TryClass = c.mirror.staticClass("scala.util.Try") - val TryAnyType = appliedType(TryClass.toType, List(definitions.AnyTpe)) - val NonFatalClass = c.mirror.staticModule("scala.util.control.NonFatal") - - private def asyncMember(name: String) = { - val asyncMod = c.mirror.staticClass("scala.async.AsyncBase") - val tpe = asyncMod.asType.toType - tpe.member(newTermName(name)).ensuring(_ != NoSymbol) - } - - val Async_await = asyncMember("await") - } - - /** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */ - private def methodSym(apply: c.Expr[Any]): Symbol = { - val tree2: Tree = c.typeCheck(apply.tree) - tree2.collect { - case s: SymTree if s.symbol.isMethod => s.symbol - }.headOption.getOrElse(sys.error(s"Unable to find a method symbol in ${apply.tree}")) - } - - /** - * Using [[scala.reflect.api.Trees.TreeCopier]] copies more than we would like: - * we don't want to copy types and symbols to the new trees in some cases. - * - * Instead, we just copy positions and attachments. - */ - def attachCopy[T <: Tree](orig: Tree)(tree: T): tree.type = { - tree.setPos(orig.pos) - for (att <- orig.attachments.all) - tree.updateAttachment[Any](att)(ClassTag.apply[Any](att.getClass)) - tree - } - - def resetInternalAttrs(tree: Tree, internalSyms: List[Symbol]) = - new ResetInternalAttrs(internalSyms.toSet).transform(tree) - - /** - * Adaptation of [[scala.reflect.internal.Trees.ResetAttrs]] - * - * A transformer which resets symbol and tpe fields of all nodes in a given tree, - * with special treatment of: - * `TypeTree` nodes: are replaced by their original if it exists, otherwise tpe field is reset - * to empty if it started out empty or refers to local symbols (which are erased). - * `TypeApply` nodes: are deleted if type arguments end up reverted to empty - * - * `This` and `Ident` nodes referring to an external symbol are ''not'' reset. - */ - private final class ResetInternalAttrs(internalSyms: Set[Symbol]) extends Transformer { - - import language.existentials - - override def transform(tree: Tree): Tree = super.transform { - def isExternal = tree.symbol != NoSymbol && !internalSyms(tree.symbol) - - tree match { - case tpt: TypeTree => resetTypeTree(tpt) - case TypeApply(fn, args) - if args map transform exists (_.isEmpty) => transform(fn) - case EmptyTree => tree - case (_: Ident | _: This) if isExternal => tree // #35 Don't reset the symbol of Ident/This bound outside of the async block - case _ => resetTree(tree) - } - } - - private def resetTypeTree(tpt: TypeTree): Tree = { - if (tpt.original != null) - transform(tpt.original) - else if (tpt.tpe != null && tpt.asInstanceOf[symtab.TypeTree forSome {val symtab: reflect.internal.SymbolTable}].wasEmpty) { - val dupl = tpt.duplicate - dupl.tpe = null - dupl - } - else tpt - } - - private def resetTree(tree: Tree): Tree = { - val hasSymbol: Boolean = { - val reflectInternalTree = tree.asInstanceOf[symtab.Tree forSome {val symtab: reflect.internal.SymbolTable}] - reflectInternalTree.hasSymbol - } - val dupl = tree.duplicate - if (hasSymbol) - dupl.symbol = NoSymbol - dupl.tpe = null - dupl - } - } - - /** - * Replaces expressions of the form `{ new $anon extends PartialFunction[A, B] { ... ; def applyOrElse[..](...) = ... match <cases> }` - * with `Match(EmptyTree, cases`. - * - * This reverses the transformation performed in `Typers`, and works around non-idempotency of typechecking such trees. - */ - // TODO Reference JIRA issue. - final def restorePatternMatchingFunctions(tree: Tree) = - RestorePatternMatchingFunctions transform tree - - private object RestorePatternMatchingFunctions extends Transformer { - - import language.existentials - val DefaultCaseName: TermName = "defaultCase$" - - override def transform(tree: Tree): Tree = { - val SYNTHETIC = (1 << 21).toLong.asInstanceOf[FlagSet] - def isSynthetic(cd: ClassDef) = cd.mods hasFlag SYNTHETIC - - /** Is this pattern node a synthetic catch-all case, added during PartialFuction synthesis before we know - * whether the user provided cases are exhaustive. */ - def isSyntheticDefaultCase(cdef: CaseDef) = cdef match { - case CaseDef(Bind(DefaultCaseName, _), EmptyTree, _) => true - case _ => false - } - tree match { - case Block( - (cd@ClassDef(_, _, _, Template(_, _, body))) :: Nil, - Apply(Select(New(a), nme.CONSTRUCTOR), Nil)) if isSynthetic(cd) => - val restored = (body collectFirst { - case DefDef(_, /*name.apply | */ name.applyOrElse, _, _, _, Match(_, cases)) => - val nonSyntheticCases = cases.takeWhile(cdef => !isSyntheticDefaultCase(cdef)) - val transformedCases = super.transformStats(nonSyntheticCases, currentOwner).asInstanceOf[List[CaseDef]] - Match(EmptyTree, transformedCases) - }).getOrElse(c.abort(tree.pos, s"Internal Error: Unable to find original pattern matching cases in: $body")) - restored - case t => super.transform(t) - } - } - } - - def isSafeToInline(tree: Tree) = { - val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable] - object treeInfo extends { - val global: symtab.type = symtab - } with reflect.internal.TreeInfo - val castTree = tree.asInstanceOf[symtab.Tree] - treeInfo.isExprSafeToInline(castTree) - } - - /** Map a list of arguments to: - * - A list of argument Trees - * - A list of auxillary results. - * - * The function unwraps and rewraps the `arg :_*` construct. - * - * @param args The original argument trees - * @param f A function from argument (with '_*' unwrapped) and argument index to argument. - * @tparam A The type of the auxillary result - */ - private def mapArguments[A](args: List[Tree])(f: (Tree, Int) => (A, Tree)): (List[A], List[Tree]) = { - args match { - case args :+ Typed(tree, Ident(tpnme.WILDCARD_STAR)) => - val (a, argExprs :+ lastArgExpr) = (args :+ tree).zipWithIndex.map(f.tupled).unzip - val exprs = argExprs :+ Typed(lastArgExpr, Ident(tpnme.WILDCARD_STAR)).setPos(lastArgExpr.pos) - (a, exprs) - case args => - args.zipWithIndex.map(f.tupled).unzip - } - } - - case class Arg(expr: Tree, isByName: Boolean, argName: String) - - /** - * Transform a list of argument lists, producing the transformed lists, and lists of auxillary - * results. - * - * The function `f` need not concern itself with varargs arguments e.g (`xs : _*`). It will - * receive `xs`, and it's result will be re-wrapped as `f(xs) : _*`. - * - * @param fun The function being applied - * @param argss The argument lists - * @return (auxillary results, mapped argument trees) - */ - def mapArgumentss[A](fun: Tree, argss: List[List[Tree]])(f: Arg => (A, Tree)): (List[List[A]], List[List[Tree]]) = { - val isByNamess: (Int, Int) => Boolean = isByName(fun) - val argNamess: (Int, Int) => String = argName(fun) - argss.zipWithIndex.map { case (args, i) => - mapArguments[A](args) { - (tree, j) => f(Arg(tree, isByNamess(i, j), argNamess(i, j))) - } - }.unzip - } -} diff --git a/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala b/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala deleted file mode 100644 index a669cfa..0000000 --- a/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> - */ - -package scala.async -package continuations - -import scala.language.experimental.macros - -import scala.reflect.macros.Context -import scala.util.continuations._ - -trait AsyncBaseWithCPSFallback extends AsyncBase { - - /* Fall-back for `await` using CPS plugin. - * - * Note: This method is public, but is intended only for internal use. - */ - def awaitFallback[T](awaitable: futureSystem.Fut[T]): T @cps[futureSystem.Fut[Any]] - - override protected[async] def fallbackEnabled = true - - /* Implements `async { ... }` using the CPS plugin. - */ - protected def cpsBasedAsyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = { - import c.universe._ - - def lookupMember(name: String) = { - val asyncTrait = c.mirror.staticClass("scala.async.continuations.AsyncBaseWithCPSFallback") - val tpe = asyncTrait.asType.toType - tpe.member(newTermName(name)).ensuring(_ != NoSymbol) - } - - AsyncUtils.vprintln("AsyncBaseWithCPSFallback.cpsBasedAsyncImpl") - - val utils = TransformUtils[c.type](c) - val futureSystemOps = futureSystem.mkOps(c) - val awaitSym = utils.defn.Async_await - val awaitFallbackSym = lookupMember("awaitFallback") - - // replace `await` invocations with `awaitFallback` invocations - val awaitReplacer = new Transformer { - override def transform(tree: Tree): Tree = tree match { - case Apply(fun @ TypeApply(_, List(futArgTpt)), args) if fun.symbol == awaitSym => - val typeApp = treeCopy.TypeApply(fun, Ident(awaitFallbackSym), List(TypeTree(futArgTpt.tpe))) - treeCopy.Apply(tree, typeApp, args.map(arg => c.resetAllAttrs(arg.duplicate))) - case _ => - super.transform(tree) - } - } - val bodyWithAwaitFallback = awaitReplacer.transform(body.tree) - - /* generate an expression that looks like this: - reset { - val f = future { ... } - ... - val x = awaitFallback(f) - ... - future { expr } - }.asInstanceOf[Future[T]] - */ - - val bodyWithFuture = { - val tree = bodyWithAwaitFallback match { - case Block(stmts, expr) => Block(stmts, futureSystemOps.spawn(expr)) - case expr => futureSystemOps.spawn(expr) - } - c.Expr[futureSystem.Fut[Any]](c.resetAllAttrs(tree.duplicate)) - } - - val bodyWithReset: c.Expr[futureSystem.Fut[Any]] = reify { - reset { bodyWithFuture.splice } - } - val bodyWithCast = futureSystemOps.castTo[T](bodyWithReset) - - AsyncUtils.vprintln(s"CPS-based async transform expands to:\n${bodyWithCast.tree}") - bodyWithCast - } - - override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = { - AsyncUtils.vprintln("AsyncBaseWithCPSFallback.asyncImpl") - - val analyzer = AsyncAnalysis[c.type](c, this) - - if (!analyzer.reportUnsupportedAwaits(body.tree)) - super.asyncImpl[T](c)(body) // no unsupported awaits - else - cpsBasedAsyncImpl[T](c)(body) // fallback to CPS - } -} diff --git a/src/main/scala/scala/async/continuations/AsyncWithCPSFallback.scala b/src/main/scala/scala/async/continuations/AsyncWithCPSFallback.scala deleted file mode 100644 index fe6e1a6..0000000 --- a/src/main/scala/scala/async/continuations/AsyncWithCPSFallback.scala +++ /dev/null @@ -1,20 +0,0 @@ -/* - * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> - */ - -package scala.async -package continuations - -import scala.language.experimental.macros - -import scala.reflect.macros.Context -import scala.concurrent.Future - -trait AsyncWithCPSFallback extends AsyncBaseWithCPSFallback with ScalaConcurrentCPSFallback - -object AsyncWithCPSFallback extends AsyncWithCPSFallback { - - def async[T](body: T) = macro asyncImpl[T] - - override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = super.asyncImpl[T](c)(body) -} diff --git a/src/main/scala/scala/async/continuations/CPSBasedAsync.scala b/src/main/scala/scala/async/continuations/CPSBasedAsync.scala deleted file mode 100644 index 922d1ac..0000000 --- a/src/main/scala/scala/async/continuations/CPSBasedAsync.scala +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> - */ - -package scala.async -package continuations - -import scala.language.experimental.macros - -import scala.reflect.macros.Context -import scala.concurrent.Future - -trait CPSBasedAsync extends CPSBasedAsyncBase with ScalaConcurrentCPSFallback - -object CPSBasedAsync extends CPSBasedAsync { - - def async[T](body: T) = macro asyncImpl[T] - - override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = super.asyncImpl[T](c)(body) - -} diff --git a/src/main/scala/scala/async/continuations/CPSBasedAsyncBase.scala b/src/main/scala/scala/async/continuations/CPSBasedAsyncBase.scala deleted file mode 100644 index 4e8ec80..0000000 --- a/src/main/scala/scala/async/continuations/CPSBasedAsyncBase.scala +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> - */ - -package scala.async -package continuations - -import scala.language.experimental.macros - -import scala.reflect.macros.Context -import scala.util.continuations._ - -/* Specializes `AsyncBaseWithCPSFallback` to always fall back to CPS, yielding a purely CPS-based - * implementation of async/await. - */ -trait CPSBasedAsyncBase extends AsyncBaseWithCPSFallback { - - override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = - super.cpsBasedAsyncImpl[T](c)(body) - -} diff --git a/src/main/scala/scala/async/continuations/ScalaConcurrentCPSFallback.scala b/src/main/scala/scala/async/continuations/ScalaConcurrentCPSFallback.scala deleted file mode 100644 index 018ad05..0000000 --- a/src/main/scala/scala/async/continuations/ScalaConcurrentCPSFallback.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> - */ - -package scala.async -package continuations - -import scala.util.continuations._ -import scala.concurrent.{Future, Promise, ExecutionContext} - -trait ScalaConcurrentCPSFallback { - self: AsyncBaseWithCPSFallback => - - import ExecutionContext.Implicits.global - - lazy val futureSystem = ScalaConcurrentFutureSystem - type FS = ScalaConcurrentFutureSystem.type - - /* Fall-back for `await` when it is called at an unsupported position. - */ - override def awaitFallback[T](awaitable: futureSystem.Fut[T]): T @cps[Future[Any]] = - shift { - (k: (T => Future[Any])) => - val fr = Promise[Any]() - awaitable onComplete { - case tr => fr completeWith k(tr.get) - } - fr.future - } - -} diff --git a/src/main/scala/scala/async/internal/AnfTransform.scala b/src/main/scala/scala/async/internal/AnfTransform.scala new file mode 100644 index 0000000..8518cf5 --- /dev/null +++ b/src/main/scala/scala/async/internal/AnfTransform.scala @@ -0,0 +1,284 @@ + +/* + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ + +package scala.async.internal + +import scala.tools.nsc.Global +import scala.Predef._ + +private[async] trait AnfTransform { + self: AsyncMacro => + + import global._ + import reflect.internal.Flags._ + + def anfTransform(tree: Tree): Block = { + // Must prepend the () for issue #31. + val block = callSiteTyper.typedPos(tree.pos)(Block(List(Literal(Constant(()))), tree)).setType(tree.tpe) + + new SelectiveAnfTransform().transform(block) + } + + sealed abstract class AnfMode + + case object Anf extends AnfMode + + case object Linearizing extends AnfMode + + final class SelectiveAnfTransform extends MacroTypingTransformer { + var mode: AnfMode = Anf + + def blockToList(tree: Tree): List[Tree] = tree match { + case Block(stats, expr) => stats :+ expr + case t => t :: Nil + } + + def listToBlock(trees: List[Tree]): Block = trees match { + case trees @ (init :+ last) => + val pos = trees.map(_.pos).reduceLeft(_ union _) + Block(init, last).setType(last.tpe).setPos(pos) + } + + override def transform(tree: Tree): Block = { + def anfLinearize: Block = { + val trees: List[Tree] = mode match { + case Anf => anf._transformToList(tree) + case Linearizing => linearize._transformToList(tree) + } + listToBlock(trees) + } + tree match { + case _: ValDef | _: DefDef | _: Function | _: ClassDef | _: TypeDef => + atOwner(tree.symbol)(anfLinearize) + case _: ModuleDef => + atOwner(tree.symbol.moduleClass orElse tree.symbol)(anfLinearize) + case _ => + anfLinearize + } + } + + private object linearize { + def transformToList(tree: Tree): List[Tree] = { + mode = Linearizing; blockToList(transform(tree)) + } + + def transformToBlock(tree: Tree): Block = listToBlock(transformToList(tree)) + + def _transformToList(tree: Tree): List[Tree] = trace(tree) { + val stats :+ expr = anf.transformToList(tree) + def statsExprUnit = + stats :+ expr :+ localTyper.typedPos(expr.pos)(Literal(Constant(()))) + expr match { + case Apply(fun, args) if isAwait(fun) => + val valDef = defineVal(name.await, expr, tree.pos) + stats :+ valDef :+ gen.mkAttributedStableRef(valDef.symbol).setType(tree.tpe).setPos(tree.pos) + + case If(cond, thenp, elsep) => + // if type of if-else is Unit don't introduce assignment, + // but add Unit value to bring it into form expected by async transform + if (expr.tpe =:= definitions.UnitTpe) { + statsExprUnit + } else { + val varDef = defineVar(name.ifRes, expr.tpe, tree.pos) + def branchWithAssign(orig: Tree) = localTyper.typedPos(orig.pos) { + def cast(t: Tree) = mkAttributedCastPreservingAnnotations(t, varDef.symbol.tpe) + orig match { + case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(varDef.symbol), cast(thenExpr))) + case _ => Assign(Ident(varDef.symbol), cast(orig)) + } + } + val ifWithAssign = treeCopy.If(tree, cond, branchWithAssign(thenp), branchWithAssign(elsep)).setType(definitions.UnitTpe) + stats :+ varDef :+ ifWithAssign :+ gen.mkAttributedStableRef(varDef.symbol).setType(tree.tpe).setPos(tree.pos) + } + case LabelDef(name, params, rhs) => + statsExprUnit + + case Match(scrut, cases) => + // if type of match is Unit don't introduce assignment, + // but add Unit value to bring it into form expected by async transform + if (expr.tpe =:= definitions.UnitTpe) { + statsExprUnit + } + else { + val varDef = defineVar(name.matchRes, expr.tpe, tree.pos) + def typedAssign(lhs: Tree) = + localTyper.typedPos(lhs.pos)(Assign(Ident(varDef.symbol), mkAttributedCastPreservingAnnotations(lhs, varDef.symbol.tpe))) + val casesWithAssign = cases map { + case cd@CaseDef(pat, guard, body) => + val newBody = body match { + case b@Block(caseStats, caseExpr) => treeCopy.Block(b, caseStats, typedAssign(caseExpr)).setType(definitions.UnitTpe) + case _ => typedAssign(body) + } + treeCopy.CaseDef(cd, pat, guard, newBody).setType(definitions.UnitTpe) + } + val matchWithAssign = treeCopy.Match(tree, scrut, casesWithAssign).setType(definitions.UnitTpe) + require(matchWithAssign.tpe != null, matchWithAssign) + stats :+ varDef :+ matchWithAssign :+ gen.mkAttributedStableRef(varDef.symbol).setPos(tree.pos).setType(tree.tpe) + } + case _ => + stats :+ expr + } + } + + private def defineVar(prefix: String, tp: Type, pos: Position): ValDef = { + val sym = currOwner.newTermSymbol(name.fresh(prefix), pos, MUTABLE | SYNTHETIC).setInfo(uncheckedBounds(tp)) + ValDef(sym, gen.mkZero(uncheckedBounds(tp))).setType(NoType).setPos(pos) + } + } + + private object trace { + private var indent = -1 + + private def indentString = " " * indent + + def apply[T](args: Any)(t: => T): T = { + def prefix = mode.toString.toLowerCase + indent += 1 + def oneLine(s: Any) = s.toString.replaceAll("""\n""", "\\\\n").take(127) + try { + AsyncUtils.trace(s"${indentString}$prefix(${oneLine(args)})") + val result = t + AsyncUtils.trace(s"${indentString}= ${oneLine(result)}") + result + } finally { + indent -= 1 + } + } + } + + private def defineVal(prefix: String, lhs: Tree, pos: Position): ValDef = { + val sym = currOwner.newTermSymbol(name.fresh(prefix), pos, SYNTHETIC).setInfo(uncheckedBounds(lhs.tpe)) + changeOwner(lhs, currentOwner, sym) + ValDef(sym, changeOwner(lhs, currentOwner, sym)).setType(NoType).setPos(pos) + } + + private object anf { + def transformToList(tree: Tree): List[Tree] = { + mode = Anf; blockToList(transform(tree)) + } + + def _transformToList(tree: Tree): List[Tree] = trace(tree) { + val containsAwait = tree exists isAwait + if (!containsAwait) { + tree match { + case Block(stats, expr) => + // avoids nested block in `while(await(false)) ...`. + // TODO I think `containsAwait` really should return true if the code contains a label jump to an enclosing + // while/doWhile and there is an await *anywhere* inside that construct. + stats :+ expr + case _ => List(tree) + } + } else tree match { + case Select(qual, sel) => + val stats :+ expr = linearize.transformToList(qual) + stats :+ treeCopy.Select(tree, expr, sel) + + case Throw(expr) => + val stats :+ expr1 = linearize.transformToList(expr) + stats :+ treeCopy.Throw(tree, expr1) + + case Typed(expr, tpt) => + val stats :+ expr1 = linearize.transformToList(expr) + stats :+ treeCopy.Typed(tree, expr1, tpt) + + case treeInfo.Applied(fun, targs, argss) if argss.nonEmpty => + // we can assume that no await call appears in a by-name argument position, + // this has already been checked. + val funStats :+ simpleFun = linearize.transformToList(fun) + val (argStatss, argExprss): (List[List[List[Tree]]], List[List[Tree]]) = + mapArgumentss[List[Tree]](fun, argss) { + case Arg(expr, byName, _) if byName /*|| isPure(expr) TODO */ => (Nil, expr) + case Arg(expr, _, argName) => + linearize.transformToList(expr) match { + case stats :+ expr1 => + val valDef = defineVal(argName, expr1, expr1.pos) + require(valDef.tpe != null, valDef) + val stats1 = stats :+ valDef + (stats1, atPos(tree.pos.makeTransparent)(gen.stabilize(gen.mkAttributedIdent(valDef.symbol)))) + } + } + + def copyApplied(tree: Tree, depth: Int): Tree = { + tree match { + case TypeApply(_, targs) => treeCopy.TypeApply(tree, simpleFun, targs) + case _ if depth == 0 => simpleFun + case Apply(fun, args) => + val newTypedArgs = map2(args.map(_.pos), argExprss(depth - 1))((pos, arg) => localTyper.typedPos(pos)(arg)) + treeCopy.Apply(tree, copyApplied(fun, depth - 1), newTypedArgs) + } + } + + + /** The depth of the nested applies: e.g. Apply(Apply(Apply(_, _), _), _) + * has depth 3. Continues through type applications (without counting them.) + */ + def applyDepth: Int = { + def loop(tree: Tree): Int = tree match { + case Apply(fn, _) => 1 + loop(fn) + case TypeApply(fn, _) => loop(fn) + case AppliedTypeTree(fn, _) => loop(fn) + case _ => 0 + } + loop(tree) + } + + val typedNewApply = copyApplied(tree, applyDepth) + + funStats ++ argStatss.flatten.flatten :+ typedNewApply + + case Block(stats, expr) => + (stats :+ expr).flatMap(linearize.transformToList) + + case ValDef(mods, name, tpt, rhs) => + if (rhs exists isAwait) { + val stats :+ expr = atOwner(currOwner.owner)(linearize.transformToList(rhs)) + stats.foreach(changeOwner(_, currOwner, currOwner.owner)) + stats :+ treeCopy.ValDef(tree, mods, name, tpt, expr) + } else List(tree) + + case Assign(lhs, rhs) => + val stats :+ expr = linearize.transformToList(rhs) + stats :+ treeCopy.Assign(tree, lhs, expr) + + case If(cond, thenp, elsep) => + val condStats :+ condExpr = linearize.transformToList(cond) + val thenBlock = linearize.transformToBlock(thenp) + val elseBlock = linearize.transformToBlock(elsep) + condStats :+ treeCopy.If(tree, condExpr, thenBlock, elseBlock) + + case Match(scrut, cases) => + val scrutStats :+ scrutExpr = linearize.transformToList(scrut) + val caseDefs = cases map { + case CaseDef(pat, guard, body) => + // extract local variables for all names bound in `pat`, and rewrite `body` + // to refer to these. + // TODO we can move this into ExprBuilder once we get rid of `AsyncDefinitionUseAnalyzer`. + val block = linearize.transformToBlock(body) + val (valDefs, mappings) = (pat collect { + case b@Bind(name, _) => + val vd = defineVal(name.toTermName + AnfTransform.this.name.bindSuffix, gen.mkAttributedStableRef(b.symbol).setPos(b.pos), b.pos) + (vd, (b.symbol, vd.symbol)) + }).unzip + val (from, to) = mappings.unzip + val b@Block(stats1, expr1) = block.substituteSymbols(from, to).asInstanceOf[Block] + val newBlock = treeCopy.Block(b, valDefs ++ stats1, expr1) + treeCopy.CaseDef(tree, pat, guard, newBlock) + } + scrutStats :+ treeCopy.Match(tree, scrutExpr, caseDefs) + + case LabelDef(name, params, rhs) => + List(LabelDef(name, params, Block(linearize.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol)) + + case TypeApply(fun, targs) => + val funStats :+ simpleFun = linearize.transformToList(fun) + funStats :+ treeCopy.TypeApply(tree, simpleFun, targs) + + case _ => + List(tree) + } + } + } + } +} diff --git a/src/main/scala/scala/async/internal/AsyncAnalysis.scala b/src/main/scala/scala/async/internal/AsyncAnalysis.scala new file mode 100644 index 0000000..69e4c3c --- /dev/null +++ b/src/main/scala/scala/async/internal/AsyncAnalysis.scala @@ -0,0 +1,93 @@ +/* + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ + +package scala.async.internal + +import scala.reflect.macros.Context +import scala.collection.mutable + +trait AsyncAnalysis { + self: AsyncMacro => + + import global._ + + /** + * Analyze the contents of an `async` block in order to: + * - Report unsupported `await` calls under nested templates, functions, by-name arguments. + * + * Must be called on the original tree, not on the ANF transformed tree. + */ + def reportUnsupportedAwaits(tree: Tree): Unit = { + val analyzer = new UnsupportedAwaitAnalyzer + analyzer.traverse(tree) + analyzer.hasUnsupportedAwaits + } + + private class UnsupportedAwaitAnalyzer extends AsyncTraverser { + var hasUnsupportedAwaits = false + + override def nestedClass(classDef: ClassDef) { + val kind = if (classDef.symbol.isTrait) "trait" else "class" + reportUnsupportedAwait(classDef, s"nested ${kind}") + } + + override def nestedModule(module: ModuleDef) { + reportUnsupportedAwait(module, "nested object") + } + + override def nestedMethod(defDef: DefDef) { + reportUnsupportedAwait(defDef, "nested method") + } + + override def byNameArgument(arg: Tree) { + reportUnsupportedAwait(arg, "by-name argument") + } + + override def function(function: Function) { + reportUnsupportedAwait(function, "nested function") + } + + override def patMatFunction(tree: Match) { + reportUnsupportedAwait(tree, "nested function") + } + + override def traverse(tree: Tree) { + def containsAwait = tree exists isAwait + tree match { + case Try(_, _, _) if containsAwait => + reportUnsupportedAwait(tree, "try/catch") + super.traverse(tree) + case Return(_) => + abort(tree.pos, "return is illegal within a async block") + case ValDef(mods, _, _, _) if mods.hasFlag(Flag.LAZY) => + // TODO lift this restriction + abort(tree.pos, "lazy vals are illegal within an async block") + case CaseDef(_, guard, _) if guard exists isAwait => + // TODO lift this restriction + reportUnsupportedAwait(tree, "pattern guard") + case _ => + super.traverse(tree) + } + } + + /** + * @return true, if the tree contained an unsupported await. + */ + private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String): Boolean = { + val badAwaits: List[RefTree] = tree collect { + case rt: RefTree if isAwait(rt) => rt + } + badAwaits foreach { + tree => + reportError(tree.pos, s"await must not be used under a $whyUnsupported.") + } + badAwaits.nonEmpty + } + + private def reportError(pos: Position, msg: String) { + hasUnsupportedAwaits = true + abort(pos, msg) + } + } +} diff --git a/src/main/scala/scala/async/internal/AsyncBase.scala b/src/main/scala/scala/async/internal/AsyncBase.scala new file mode 100644 index 0000000..3bb3b99 --- /dev/null +++ b/src/main/scala/scala/async/internal/AsyncBase.scala @@ -0,0 +1,75 @@ +/* + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ + +package scala.async.internal + +import scala.reflect.internal.annotations.compileTimeOnly +import scala.reflect.macros.Context +import scala.reflect.api.Universe + +/** + * A base class for the `async` macro. Subclasses must provide: + * + * - Concrete types for a given future system + * - Tree manipulations to create and complete the equivalent of Future and Promise + * in that system. + * - The `async` macro declaration itself, and a forwarder for the macro implementation. + * (The latter is temporarily needed to workaround bug SI-6650 in the macro system) + * + * The default implementation, [[scala.async.Async]], binds the macro to `scala.concurrent._`. + */ +abstract class AsyncBase { + self => + + type FS <: FutureSystem + val futureSystem: FS + + /** + * A call to `await` must be nested in an enclosing `async` block. + * + * A call to `await` does not block the current thread, rather it is a delimiter + * used by the enclosing `async` macro. Code following the `await` + * call is executed asynchronously, when the argument of `await` has been completed. + * + * @param awaitable the future from which a value is awaited. + * @tparam T the type of that value. + * @return the value. + */ + @compileTimeOnly("`await` must be enclosed in an `async` block") + def await[T](awaitable: futureSystem.Fut[T]): T = ??? + + def asyncImpl[T: c.WeakTypeTag](c: Context) + (body: c.Expr[T]) + (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = { + import c.universe._ + val asyncMacro = AsyncMacro(c, self) + + val isPresentationCompiler = asyncMacro.global.forInteractive + + val code = asyncMacro.asyncTransform[T]( + body.tree.asInstanceOf[asyncMacro.global.Tree], + execContext.tree.asInstanceOf[asyncMacro.global.Tree] + )(implicitly[c.WeakTypeTag[T]].asInstanceOf[asyncMacro.global.WeakTypeTag[T]]).asInstanceOf[Tree] + + AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code}") + val result = if (isPresentationCompiler) { + asyncMacro.suppressExpansion() + c.macroApplication + } else { + // Mark range positions for synthetic code as transparent to allow some wiggle room for overlapping ranges + for (t <- code) + t.pos = t.pos.makeTransparent + code + } + c.Expr[futureSystem.Fut[T]](result) + } + + protected[async] def awaitMethod(u: Universe)(asyncMacroSymbol: u.Symbol): u.Symbol = { + import u._ + asyncMacroSymbol.owner.typeSignature.member(newTermName("await")) + } + + protected[async] def nullOut(u: Universe)(name: u.Expr[String], v: u.Expr[Any]): u.Expr[Unit] = + u.reify { () } +} diff --git a/src/main/scala/scala/async/internal/AsyncId.scala b/src/main/scala/scala/async/internal/AsyncId.scala new file mode 100644 index 0000000..5343a82 --- /dev/null +++ b/src/main/scala/scala/async/internal/AsyncId.scala @@ -0,0 +1,101 @@ +/* + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ + +package scala.async.internal + +import language.experimental.macros +import scala.reflect.macros.Context +import scala.reflect.api.Universe +import scala.reflect.internal.SymbolTable + +object AsyncId extends AsyncBase { + lazy val futureSystem = IdentityFutureSystem + type FS = IdentityFutureSystem.type + + def async[T](body: T) = macro asyncIdImpl[T] + + def asyncIdImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[T] = asyncImpl[T](c)(body)(c.literalUnit) +} + +object AsyncTestLV extends AsyncBase { + lazy val futureSystem = IdentityFutureSystem + type FS = IdentityFutureSystem.type + + def async[T](body: T) = macro asyncIdImpl[T] + + def asyncIdImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[T] = asyncImpl[T](c)(body)(c.literalUnit) + + var log: List[(String, Any)] = List() + def assertNulledOut(a: Any): Unit = assert(log.exists(_._2 == a), AsyncTestLV.log) + def assertNotNulledOut(a: Any): Unit = assert(!log.exists(_._2 == a), AsyncTestLV.log) + def clear() = log = Nil + + def apply(name: String, v: Any): Unit = + log ::= (name -> v) + + protected[async] override def nullOut(u: Universe)(name: u.Expr[String], v: u.Expr[Any]): u.Expr[Unit] = + u.reify { scala.async.internal.AsyncTestLV(name.splice, v.splice) } +} + +/** + * A trivial implementation of [[FutureSystem]] that performs computations + * on the current thread. Useful for testing. + */ +object IdentityFutureSystem extends FutureSystem { + + class Prom[A] { + var a: A = _ + } + + type Fut[A] = A + type ExecContext = Unit + type Tryy[A] = scala.util.Try[A] + + def mkOps(c: SymbolTable): Ops {val universe: c.type} = new Ops { + val universe: c.type = c + + import universe._ + + def execContext: Expr[ExecContext] = Expr[Unit](Literal(Constant(()))) + + def promType[A: WeakTypeTag]: Type = weakTypeOf[Prom[A]] + def tryType[A: WeakTypeTag]: Type = weakTypeOf[scala.util.Try[A]] + def execContextType: Type = weakTypeOf[Unit] + + def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify { + new Prom[A]() + } + + def promiseToFuture[A: WeakTypeTag](prom: Expr[Prom[A]]) = reify { + prom.splice.a + } + + def future[A: WeakTypeTag](t: Expr[A])(execContext: Expr[ExecContext]) = t + + def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[Tryy[A] => U], + execContext: Expr[ExecContext]): Expr[Unit] = reify { + fun.splice.apply(util.Success(future.splice)) + Expr[Unit](Literal(Constant(()))).splice + } + + def completeProm[A](prom: Expr[Prom[A]], value: Expr[Tryy[A]]): Expr[Unit] = reify { + prom.splice.a = value.splice.get + Expr[Unit](Literal(Constant(()))).splice + } + + def tryyIsFailure[A](tryy: Expr[Tryy[A]]): Expr[Boolean] = reify { + tryy.splice.isFailure + } + + def tryyGet[A](tryy: Expr[Tryy[A]]): Expr[A] = reify { + tryy.splice.get + } + def tryySuccess[A: WeakTypeTag](a: Expr[A]): Expr[Tryy[A]] = reify { + scala.util.Success[A](a.splice) + } + def tryyFailure[A: WeakTypeTag](a: Expr[Throwable]): Expr[Tryy[A]] = reify { + scala.util.Failure[A](a.splice) + } + } +} diff --git a/src/main/scala/scala/async/internal/AsyncMacro.scala b/src/main/scala/scala/async/internal/AsyncMacro.scala new file mode 100644 index 0000000..79f544c --- /dev/null +++ b/src/main/scala/scala/async/internal/AsyncMacro.scala @@ -0,0 +1,51 @@ +package scala.async.internal + +import scala.tools.nsc.Global +import scala.tools.nsc.transform.TypingTransformers + +object AsyncMacro { + def apply(c: reflect.macros.Context, base: AsyncBase): AsyncMacro = { + import language.reflectiveCalls + val powerContext = c.asInstanceOf[c.type { val universe: Global; val callsiteTyper: universe.analyzer.Typer }] + new AsyncMacro { + val global: powerContext.universe.type = powerContext.universe + val callSiteTyper: global.analyzer.Typer = powerContext.callsiteTyper + val macroApplication: global.Tree = c.macroApplication.asInstanceOf[global.Tree] + // This member is required by `AsyncTransform`: + val asyncBase: AsyncBase = base + // These members are required by `ExprBuilder`: + val futureSystem: FutureSystem = base.futureSystem + val futureSystemOps: futureSystem.Ops {val universe: global.type} = futureSystem.mkOps(global) + } + } +} + +private[async] trait AsyncMacro + extends TypingTransformers + with AnfTransform with TransformUtils with Lifter + with ExprBuilder with AsyncTransform with AsyncAnalysis with LiveVariables { + + val global: Global + val callSiteTyper: global.analyzer.Typer + val macroApplication: global.Tree + + lazy val macroPos = macroApplication.pos.makeTransparent + def atMacroPos(t: global.Tree) = global.atPos(macroPos)(t) + + def suppressExpansion() { + // Have your cake : Scala IDE sees original trees and hyperlinking, etc within async blocks "Just Works" + // Eat it too : (domain specific errors like "unsupported use of await" + // + // TODO roll this idea back into scala/scala + + import global.Tree + type Suppress = { def suppressMacroExpansion(a: Tree): Tree } + try { + global.asInstanceOf[Suppress].suppressMacroExpansion(macroApplication) + } catch { + case _: NoSuchMethodException => + global.analyzer.asInstanceOf[Suppress].suppressMacroExpansion(macroApplication) + } + } + +} diff --git a/src/main/scala/scala/async/internal/AsyncTransform.scala b/src/main/scala/scala/async/internal/AsyncTransform.scala new file mode 100644 index 0000000..352ef71 --- /dev/null +++ b/src/main/scala/scala/async/internal/AsyncTransform.scala @@ -0,0 +1,220 @@ +package scala.async.internal + +trait AsyncTransform { + self: AsyncMacro => + + import global._ + + val asyncBase: AsyncBase + + def asyncTransform[T](body: Tree, execContext: Tree) + (resultType: WeakTypeTag[T]): Tree = { + + // We annotate the type of the whole expression as `T @uncheckedBounds` so as not to introduce + // warnings about non-conformant LUBs. See SI-7694 + // This implicit propagates the annotated type in the type tag. + implicit val uncheckedBoundsResultTag: WeakTypeTag[T] = WeakTypeTag[T](rootMirror, FixedMirrorTypeCreator(rootMirror, uncheckedBounds(resultType.tpe))) + + reportUnsupportedAwaits(body) + + // Transform to A-normal form: + // - no await calls in qualifiers or arguments, + // - if/match only used in statement position. + val anfTree0: Block = anfTransform(body) + + val anfTree = futureSystemOps.postAnfTransform(anfTree0) + + val resumeFunTreeDummyBody = DefDef(Modifiers(), name.resume, Nil, List(Nil), Ident(definitions.UnitClass), Literal(Constant(()))) + + val applyDefDefDummyBody: DefDef = { + val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(futureSystemOps.tryType[Any]), EmptyTree))) + DefDef(NoMods, name.apply, Nil, applyVParamss, TypeTree(definitions.UnitTpe), Literal(Constant(()))) + } + + // Create `ClassDef` of state machine with empty method bodies for `resume` and `apply`. + val stateMachine: ClassDef = { + val body: List[Tree] = { + val stateVar = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0))) + val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T](uncheckedBoundsResultTag)), futureSystemOps.createProm[T](uncheckedBoundsResultTag).tree) + val execContextValDef = ValDef(NoMods, name.execContext, TypeTree(), execContext) + + val apply0DefDef: DefDef = { + // We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`. + // See SI-1247 for the the optimization that avoids creatio + DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.resume), Nil)) + } + val extraValDef: ValDef = { + // We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`. + // See SI-1247 for the the optimization that avoids creatio + ValDef(NoMods, newTermName("extra"), TypeTree(definitions.UnitTpe), Literal(Constant(()))) + } + List(emptyConstructor, stateVar, result, execContextValDef) ++ List(resumeFunTreeDummyBody, applyDefDefDummyBody, apply0DefDef, extraValDef) + } + + val template = Template(List(typeOf[(scala.util.Try[Any] => Unit)], typeOf[() => Unit]).map(TypeTree(_)), emptyValDef, body) + + val t = ClassDef(NoMods, name.stateMachineT, Nil, template) + callSiteTyper.typedPos(macroPos)(Block(t :: Nil, Literal(Constant(())))) + t + } + + val stateMachineClass = stateMachine.symbol + val asyncBlock: AsyncBlock = { + val symLookup = new SymLookup(stateMachineClass, applyDefDefDummyBody.vparamss.head.head.symbol) + buildAsyncBlock(anfTree, symLookup) + } + + logDiagnostics(anfTree, asyncBlock.asyncStates.map(_.toString)) + + val liftedFields: List[Tree] = liftables(asyncBlock.asyncStates) + + // live variables analysis + // the result map indicates in which states a given field should be nulled out + val assignsOf = fieldsToNullOut(asyncBlock.asyncStates, liftedFields) + + for ((state, flds) <- assignsOf) { + val assigns = flds.map { fld => + val fieldSym = fld.symbol + Block( + List( + asyncBase.nullOut(global)(Expr[String](Literal(Constant(fieldSym.name.toString))), Expr[Any](Ident(fieldSym))).tree + ), + Assign(gen.mkAttributedStableRef(fieldSym.owner.thisType, fieldSym), gen.mkZero(fieldSym.info)) + ) + } + val asyncState = asyncBlock.asyncStates.find(_.state == state).get + asyncState.stats = assigns ++ asyncState.stats + } + + def startStateMachine: Tree = { + val stateMachineSpliced: Tree = spliceMethodBodies( + liftedFields, + stateMachine, + atMacroPos(asyncBlock.onCompleteHandler[T]), + atMacroPos(asyncBlock.resumeFunTree[T].rhs) + ) + + def selectStateMachine(selection: TermName) = Select(Ident(name.stateMachine), selection) + + Block(List[Tree]( + stateMachineSpliced, + ValDef(NoMods, name.stateMachine, TypeTree(), Apply(Select(New(Ident(stateMachine.symbol)), nme.CONSTRUCTOR), Nil)), + futureSystemOps.spawn(Apply(selectStateMachine(name.apply), Nil), selectStateMachine(name.execContext)) + ), + futureSystemOps.promiseToFuture(Expr[futureSystem.Prom[T]](selectStateMachine(name.result))).tree) + } + + val isSimple = asyncBlock.asyncStates.size == 1 + if (isSimple) + futureSystemOps.spawn(body, execContext) // generate lean code for the simple case of `async { 1 + 1 }` + else + startStateMachine + } + + def logDiagnostics(anfTree: Tree, states: Seq[String]) { + def location = try { + macroPos.source.path + } catch { + case _: UnsupportedOperationException => + macroPos.toString + } + + AsyncUtils.vprintln(s"In file '$location':") + AsyncUtils.vprintln(s"${macroApplication}") + AsyncUtils.vprintln(s"ANF transform expands to:\n $anfTree") + states foreach (s => AsyncUtils.vprintln(s)) + } + + /** + * Build final `ClassDef` tree of state machine class. + * + * @param liftables trees of definitions that are lifted to fields of the state machine class + * @param tree `ClassDef` tree of the state machine class + * @param applyBody tree of onComplete handler (`apply` method) + * @param resumeBody RHS of definition tree of `resume` method + * @return transformed `ClassDef` tree of the state machine class + */ + def spliceMethodBodies(liftables: List[Tree], tree: ClassDef, applyBody: Tree, resumeBody: Tree): Tree = { + val liftedSyms = liftables.map(_.symbol).toSet + val stateMachineClass = tree.symbol + liftedSyms.foreach { + sym => + if (sym != null) { + sym.owner = stateMachineClass + if (sym.isModule) + sym.moduleClass.owner = stateMachineClass + } + } + // Replace the ValDefs in the splicee with Assigns to the corresponding lifted + // fields. Similarly, replace references to them with references to the field. + // + // This transform will only be run on the RHS of `def foo`. + class UseFields extends MacroTypingTransformer { + override def transform(tree: Tree): Tree = tree match { + case _ if currentOwner == stateMachineClass => + super.transform(tree) + case ValDef(_, _, _, rhs) if liftedSyms(tree.symbol) => + atOwner(currentOwner) { + val fieldSym = tree.symbol + val set = Assign(gen.mkAttributedStableRef(fieldSym.owner.thisType, fieldSym), transform(rhs)) + changeOwner(set, tree.symbol, currentOwner) + localTyper.typedPos(tree.pos)(set) + } + case _: DefTree if liftedSyms(tree.symbol) => + EmptyTree + case Ident(name) if liftedSyms(tree.symbol) => + val fieldSym = tree.symbol + atPos(tree.pos) { + gen.mkAttributedStableRef(fieldSym.owner.thisType, fieldSym).setType(tree.tpe) + } + case _ => + super.transform(tree) + } + } + + val liftablesUseFields = liftables.map { + case vd: ValDef => vd + case x => + val useField = new UseFields() + //.substituteSymbols(fromSyms, toSyms) + useField.atOwner(stateMachineClass)(useField.transform(x)) + } + + tree.children.foreach { + t => + new ChangeOwnerAndModuleClassTraverser(callSiteTyper.context.owner, tree.symbol).traverse(t) + } + val treeSubst = tree + + /* Fixes up DefDef: use lifted fields in `body` */ + def fixup(dd: DefDef, body: Tree, ctx: analyzer.Context): Tree = { + val spliceeAnfFixedOwnerSyms = body + val useField = new UseFields() + val newRhs = useField.atOwner(dd.symbol)(useField.transform(spliceeAnfFixedOwnerSyms)) + val typer = global.analyzer.newTyper(ctx.make(dd, dd.symbol)) + treeCopy.DefDef(dd, dd.mods, dd.name, dd.tparams, dd.vparamss, dd.tpt, typer.typed(newRhs)) + } + + liftablesUseFields.foreach(t => if (t.symbol != null) stateMachineClass.info.decls.enter(t.symbol)) + + val result0 = transformAt(treeSubst) { + case t@Template(parents, self, stats) => + (ctx: analyzer.Context) => { + treeCopy.Template(t, parents, self, liftablesUseFields ++ stats) + } + } + val result = transformAt(result0) { + case dd@DefDef(_, name.apply, _, List(List(_)), _, _) if dd.symbol.owner == stateMachineClass => + (ctx: analyzer.Context) => + val typedTree = fixup(dd, changeOwner(applyBody, callSiteTyper.context.owner, dd.symbol), ctx) + typedTree + + case dd@DefDef(_, name.resume, _, _, _, _) if dd.symbol.owner == stateMachineClass => + (ctx: analyzer.Context) => + val changed = changeOwner(resumeBody, callSiteTyper.context.owner, dd.symbol) + val res = fixup(dd, changed, ctx) + res + } + result + } +} diff --git a/src/main/scala/scala/async/AsyncUtils.scala b/src/main/scala/scala/async/internal/AsyncUtils.scala index 1ade5f0..8700bd6 100644 --- a/src/main/scala/scala/async/AsyncUtils.scala +++ b/src/main/scala/scala/async/internal/AsyncUtils.scala @@ -1,7 +1,7 @@ /* * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> */ -package scala.async +package scala.async.internal object AsyncUtils { diff --git a/src/main/scala/scala/async/internal/ExprBuilder.scala b/src/main/scala/scala/async/internal/ExprBuilder.scala new file mode 100644 index 0000000..b0cd914 --- /dev/null +++ b/src/main/scala/scala/async/internal/ExprBuilder.scala @@ -0,0 +1,415 @@ +/* + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ +package scala.async.internal + +import scala.reflect.macros.Context +import scala.collection.mutable.ListBuffer +import collection.mutable +import language.existentials +import scala.reflect.api.Universe +import scala.reflect.api + +trait ExprBuilder { + builder: AsyncMacro => + + import global._ + import defn._ + + val futureSystem: FutureSystem + val futureSystemOps: futureSystem.Ops { val universe: global.type } + + val stateAssigner = new StateAssigner + val labelDefStates = collection.mutable.Map[Symbol, Int]() + + trait AsyncState { + def state: Int + + def nextStates: List[Int] + + def mkHandlerCaseForState: CaseDef + + def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = None + + var stats: List[Tree] + + final def allStats: List[Tree] = this match { + case a: AsyncStateWithAwait => stats :+ a.awaitable.resultValDef + case _ => stats + } + + final def body: Tree = stats match { + case stat :: Nil => stat + case init :+ last => Block(init, last) + } + } + + /** A sequence of statements that concludes with a unconditional transition to `nextState` */ + final class SimpleAsyncState(var stats: List[Tree], val state: Int, nextState: Int, symLookup: SymLookup) + extends AsyncState { + + def nextStates: List[Int] = + List(nextState) + + def mkHandlerCaseForState: CaseDef = + mkHandlerCase(state, stats :+ mkStateTree(nextState, symLookup) :+ mkResumeApply(symLookup)) + + override val toString: String = + s"AsyncState #$state, next = $nextState" + } + + /** A sequence of statements with a conditional transition to the next state, which will represent + * a branch of an `if` or a `match`. + */ + final class AsyncStateWithoutAwait(var stats: List[Tree], val state: Int, val nextStates: List[Int]) extends AsyncState { + override def mkHandlerCaseForState: CaseDef = + mkHandlerCase(state, stats) + + override val toString: String = + s"AsyncStateWithoutAwait #$state, nextStates = $nextStates" + } + + /** A sequence of statements that concludes with an `await` call. The `onComplete` + * handler will unconditionally transition to `nextState`. + */ + final class AsyncStateWithAwait(var stats: List[Tree], val state: Int, nextState: Int, + val awaitable: Awaitable, symLookup: SymLookup) + extends AsyncState { + + def nextStates: List[Int] = + List(nextState) + + override def mkHandlerCaseForState: CaseDef = { + val callOnComplete = futureSystemOps.onComplete(Expr(awaitable.expr), + Expr(This(tpnme.EMPTY)), Expr(Ident(name.execContext))).tree + mkHandlerCase(state, stats :+ callOnComplete) + } + + override def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = { + val tryGetTree = + Assign( + Ident(awaitable.resultName), + TypeApply(Select(futureSystemOps.tryyGet[T](Expr[futureSystem.Tryy[T]](Ident(symLookup.applyTrParam))).tree, newTermName("asInstanceOf")), List(TypeTree(awaitable.resultType))) + ) + + /* if (tr.isFailure) + * result.complete(tr.asInstanceOf[Try[T]]) + * else { + * <resultName> = tr.get.asInstanceOf[<resultType>] + * <nextState> + * <mkResumeApply> + * } + */ + val ifIsFailureTree = + If(futureSystemOps.tryyIsFailure(Expr[futureSystem.Tryy[T]](Ident(symLookup.applyTrParam))).tree, + futureSystemOps.completeProm[T]( + Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), + Expr[futureSystem.Tryy[T]]( + TypeApply(Select(Ident(symLookup.applyTrParam), newTermName("asInstanceOf")), + List(TypeTree(futureSystemOps.tryType[T]))))).tree, + Block(List(tryGetTree, mkStateTree(nextState, symLookup)), mkResumeApply(symLookup)) + ) + + Some(mkHandlerCase(state, List(ifIsFailureTree))) + } + + override val toString: String = + s"AsyncStateWithAwait #$state, next = $nextState" + } + + /* + * Builder for a single state of an async expression. + */ + final class AsyncStateBuilder(state: Int, private val symLookup: SymLookup) { + /* Statements preceding an await call. */ + private val stats = ListBuffer[Tree]() + /** The state of the target of a LabelDef application (while loop jump) */ + private var nextJumpState: Option[Int] = None + + def +=(stat: Tree): this.type = { + stat match { + case Literal(Constant(())) => // This case occurs in do/while + case _ => + assert(nextJumpState.isEmpty, s"statement appeared after a label jump: $stat") + } + def addStat() = stats += stat + stat match { + case Apply(fun, Nil) => + // labelDefStates belongs to the current ExprBuilder + labelDefStates get fun.symbol match { + case opt @ Some(nextState) => nextJumpState = opt // re-use object + case None => addStat() + } + case _ => addStat() + } + this + } + + def resultWithAwait(awaitable: Awaitable, + nextState: Int): AsyncState = { + val effectiveNextState = nextJumpState.getOrElse(nextState) + new AsyncStateWithAwait(stats.toList, state, effectiveNextState, awaitable, symLookup) + } + + def resultSimple(nextState: Int): AsyncState = { + val effectiveNextState = nextJumpState.getOrElse(nextState) + new SimpleAsyncState(stats.toList, state, effectiveNextState, symLookup) + } + + def resultWithIf(condTree: Tree, thenState: Int, elseState: Int): AsyncState = { + def mkBranch(state: Int) = Block(mkStateTree(state, symLookup) :: Nil, mkResumeApply(symLookup)) + this += If(condTree, mkBranch(thenState), mkBranch(elseState)) + new AsyncStateWithoutAwait(stats.toList, state, List(thenState, elseState)) + } + + /** + * Build `AsyncState` ending with a match expression. + * + * The cases of the match simply resume at the state of their corresponding right-hand side. + * + * @param scrutTree tree of the scrutinee + * @param cases list of case definitions + * @param caseStates starting state of the right-hand side of the each case + * @return an `AsyncState` representing the match expression + */ + def resultWithMatch(scrutTree: Tree, cases: List[CaseDef], caseStates: List[Int], symLookup: SymLookup): AsyncState = { + // 1. build list of changed cases + val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match { + case CaseDef(pat, guard, rhs) => + val bindAssigns = rhs.children.takeWhile(isSyntheticBindVal) + CaseDef(pat, guard, Block(bindAssigns :+ mkStateTree(caseStates(num), symLookup), mkResumeApply(symLookup))) + } + // 2. insert changed match tree at the end of the current state + this += Match(scrutTree, newCases) + new AsyncStateWithoutAwait(stats.toList, state, caseStates) + } + + def resultWithLabel(startLabelState: Int, symLookup: SymLookup): AsyncState = { + this += Block(mkStateTree(startLabelState, symLookup) :: Nil, mkResumeApply(symLookup)) + new AsyncStateWithoutAwait(stats.toList, state, List(startLabelState)) + } + + override def toString: String = { + val statsBeforeAwait = stats.mkString("\n") + s"ASYNC STATE:\n$statsBeforeAwait" + } + } + + /** + * An `AsyncBlockBuilder` builds a `ListBuffer[AsyncState]` based on the expressions of a `Block(stats, expr)` (see `Async.asyncImpl`). + * + * @param stats a list of expressions + * @param expr the last expression of the block + * @param startState the start state + * @param endState the state to continue with + */ + final private class AsyncBlockBuilder(stats: List[Tree], expr: Tree, startState: Int, endState: Int, + private val symLookup: SymLookup) { + val asyncStates = ListBuffer[AsyncState]() + + var stateBuilder = new AsyncStateBuilder(startState, symLookup) + var currState = startState + + def checkForUnsupportedAwait(tree: Tree) = if (tree exists { + case Apply(fun, _) if isAwait(fun) => true + case _ => false + }) abort(tree.pos, "await must not be used in this position") + + def nestedBlockBuilder(nestedTree: Tree, startState: Int, endState: Int) = { + val (nestedStats, nestedExpr) = statsAndExpr(nestedTree) + new AsyncBlockBuilder(nestedStats, nestedExpr, startState, endState, symLookup) + } + + import stateAssigner.nextState + + // populate asyncStates + for (stat <- stats) stat match { + // the val name = await(..) pattern + case vd @ ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) => + val afterAwaitState = nextState() + val awaitable = Awaitable(arg, stat.symbol, tpt.tpe, vd) + asyncStates += stateBuilder.resultWithAwait(awaitable, afterAwaitState) // complete with await + currState = afterAwaitState + stateBuilder = new AsyncStateBuilder(currState, symLookup) + + case If(cond, thenp, elsep) if (stat exists isAwait) || containsForiegnLabelJump(stat) => + checkForUnsupportedAwait(cond) + + val thenStartState = nextState() + val elseStartState = nextState() + val afterIfState = nextState() + + asyncStates += + // the two Int arguments are the start state of the then branch and the else branch, respectively + stateBuilder.resultWithIf(cond, thenStartState, elseStartState) + + List((thenp, thenStartState), (elsep, elseStartState)) foreach { + case (branchTree, state) => + val builder = nestedBlockBuilder(branchTree, state, afterIfState) + asyncStates ++= builder.asyncStates + } + + currState = afterIfState + stateBuilder = new AsyncStateBuilder(currState, symLookup) + + case Match(scrutinee, cases) if stat exists isAwait => + checkForUnsupportedAwait(scrutinee) + + val caseStates = cases.map(_ => nextState()) + val afterMatchState = nextState() + + asyncStates += + stateBuilder.resultWithMatch(scrutinee, cases, caseStates, symLookup) + + for ((cas, num) <- cases.zipWithIndex) { + val (stats, expr) = statsAndExpr(cas.body) + val stats1 = stats.dropWhile(isSyntheticBindVal) + val builder = nestedBlockBuilder(Block(stats1, expr), caseStates(num), afterMatchState) + asyncStates ++= builder.asyncStates + } + + currState = afterMatchState + stateBuilder = new AsyncStateBuilder(currState, symLookup) + + case ld @ LabelDef(name, params, rhs) if rhs exists isAwait => + val startLabelState = nextState() + val afterLabelState = nextState() + asyncStates += stateBuilder.resultWithLabel(startLabelState, symLookup) + labelDefStates(ld.symbol) = startLabelState + val builder = nestedBlockBuilder(rhs, startLabelState, afterLabelState) + asyncStates ++= builder.asyncStates + + currState = afterLabelState + stateBuilder = new AsyncStateBuilder(currState, symLookup) + + case _ => + checkForUnsupportedAwait(stat) + stateBuilder += stat + } + // complete last state builder (representing the expressions after the last await) + stateBuilder += expr + val lastState = stateBuilder.resultSimple(endState) + asyncStates += lastState + } + + trait AsyncBlock { + def asyncStates: List[AsyncState] + + def onCompleteHandler[T: WeakTypeTag]: Tree + + def resumeFunTree[T: WeakTypeTag]: DefDef + } + + case class SymLookup(stateMachineClass: Symbol, applyTrParam: Symbol) { + def stateMachineMember(name: TermName): Symbol = + stateMachineClass.info.member(name) + def memberRef(name: TermName): Tree = + gen.mkAttributedRef(stateMachineMember(name)) + } + + /** + * Uses `AsyncBlockBuilder` to create an instance of `AsyncBlock`. + * + * @param block a `Block` tree in ANF + * @param symLookup helper for looking up members of the state machine class + * @return an `AsyncBlock` + */ + def buildAsyncBlock(block: Block, symLookup: SymLookup): AsyncBlock = { + val Block(stats, expr) = block + val startState = stateAssigner.nextState() + val endState = Int.MaxValue + + val blockBuilder = new AsyncBlockBuilder(stats, expr, startState, endState, symLookup) + + new AsyncBlock { + def asyncStates = blockBuilder.asyncStates.toList + + def mkCombinedHandlerCases[T: WeakTypeTag]: List[CaseDef] = { + val caseForLastState: CaseDef = { + val lastState = asyncStates.last + val lastStateBody = Expr[T](lastState.body) + val rhs = futureSystemOps.completeProm( + Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), futureSystemOps.tryySuccess[T](lastStateBody)) + mkHandlerCase(lastState.state, rhs.tree) + } + asyncStates.toList match { + case s :: Nil => + List(caseForLastState) + case _ => + val initCases = for (state <- asyncStates.toList.init) yield state.mkHandlerCaseForState + initCases :+ caseForLastState + } + } + + val initStates = asyncStates.init + + /** + * Builds the definition of the `resume` method. + * + * The resulting tree has the following shape: + * + * def resume(): Unit = { + * try { + * state match { + * case 0 => { + * f11 = exprReturningFuture + * f11.onComplete(onCompleteHandler)(context) + * } + * ... + * } + * } catch { + * case NonFatal(t) => result.failure(t) + * } + * } + */ + def resumeFunTree[T: WeakTypeTag]: DefDef = + DefDef(Modifiers(), name.resume, Nil, List(Nil), Ident(definitions.UnitClass), + Try( + Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T]), + List( + CaseDef( + Bind(name.t, Ident(nme.WILDCARD)), + Apply(Ident(defn.NonFatalClass), List(Ident(name.t))), { + val t = Expr[Throwable](Ident(name.t)) + futureSystemOps.completeProm[T]( + Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), futureSystemOps.tryyFailure[T](t)).tree + })), EmptyTree)) + + /** + * Builds a `match` expression used as an onComplete handler. + * + * Assumes `tr: Try[Any]` is in scope. The resulting tree has the following shape: + * + * state match { + * case 0 => + * x11 = tr.get.asInstanceOf[Double] + * state = 1 + * resume() + * } + */ + def onCompleteHandler[T: WeakTypeTag]: Tree = + Match(symLookup.memberRef(name.state), initStates.flatMap(_.mkOnCompleteHandler[T]).toList) + } + } + + private def isSyntheticBindVal(tree: Tree) = tree match { + case vd@ValDef(_, lname, _, Ident(rname)) => lname.toString.contains(name.bindSuffix) + case _ => false + } + + case class Awaitable(expr: Tree, resultName: Symbol, resultType: Type, resultValDef: ValDef) + + private def mkResumeApply(symLookup: SymLookup) = + Apply(symLookup.memberRef(name.resume), Nil) + + private def mkStateTree(nextState: Int, symLookup: SymLookup): Tree = + Assign(symLookup.memberRef(name.state), Literal(Constant(nextState))) + + private def mkHandlerCase(num: Int, rhs: List[Tree]): CaseDef = + mkHandlerCase(num, Block(rhs, literalUnit)) + + private def mkHandlerCase(num: Int, rhs: Tree): CaseDef = + CaseDef(Literal(Constant(num)), EmptyTree, rhs) + + private def literalUnit = Literal(Constant(())) +} diff --git a/src/main/scala/scala/async/internal/FutureSystem.scala b/src/main/scala/scala/async/internal/FutureSystem.scala new file mode 100644 index 0000000..166affe --- /dev/null +++ b/src/main/scala/scala/async/internal/FutureSystem.scala @@ -0,0 +1,125 @@ +/* + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ +package scala.async.internal + +import scala.language.higherKinds +import scala.reflect.internal.SymbolTable + +/** + * An abstraction over a future system. + * + * Used by the macro implementations in [[scala.async.AsyncBase]] to + * customize the code generation. + * + * The API mirrors that of `scala.concurrent.Future`, see the instance + * [[ScalaConcurrentFutureSystem]] for an example of how + * to implement this. + */ +trait FutureSystem { + /** A container to receive the final value of the computation */ + type Prom[A] + /** A (potentially in-progress) computation */ + type Fut[A] + /** An execution context, required to create or register an on completion callback on a Future. */ + type ExecContext + /** Any data type isomorphic to scala.util.Try. */ + type Tryy[T] + + trait Ops { + val universe: reflect.internal.SymbolTable + + import universe._ + def Expr[T: WeakTypeTag](tree: Tree): Expr[T] = universe.Expr[T](rootMirror, universe.FixedMirrorTreeCreator(rootMirror, tree)) + + def promType[A: WeakTypeTag]: Type + def tryType[A: WeakTypeTag]: Type + def execContextType: Type + + /** Create an empty promise */ + def createProm[A: WeakTypeTag]: Expr[Prom[A]] + + /** Extract a future from the given promise. */ + def promiseToFuture[A: WeakTypeTag](prom: Expr[Prom[A]]): Expr[Fut[A]] + + /** Construct a future to asynchronously compute the given expression */ + def future[A: WeakTypeTag](a: Expr[A])(execContext: Expr[ExecContext]): Expr[Fut[A]] + + /** Register an call back to run on completion of the given future */ + def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[Tryy[A] => U], + execContext: Expr[ExecContext]): Expr[Unit] + + /** Complete a promise with a value */ + def completeProm[A](prom: Expr[Prom[A]], value: Expr[Tryy[A]]): Expr[Unit] + + def spawn(tree: Tree, execContext: Tree): Tree = + future(Expr[Unit](tree))(Expr[ExecContext](execContext)).tree + + def tryyIsFailure[A](tryy: Expr[Tryy[A]]): Expr[Boolean] + + def tryyGet[A](tryy: Expr[Tryy[A]]): Expr[A] + def tryySuccess[A: WeakTypeTag](a: Expr[A]): Expr[Tryy[A]] + def tryyFailure[A: WeakTypeTag](a: Expr[Throwable]): Expr[Tryy[A]] + + /** A hook for custom macros to transform the tree post-ANF transform */ + def postAnfTransform(tree: Block): Block = tree + } + + def mkOps(c: SymbolTable): Ops { val universe: c.type } +} + +object ScalaConcurrentFutureSystem extends FutureSystem { + + import scala.concurrent._ + + type Prom[A] = Promise[A] + type Fut[A] = Future[A] + type ExecContext = ExecutionContext + type Tryy[A] = scala.util.Try[A] + + def mkOps(c: SymbolTable): Ops {val universe: c.type} = new Ops { + val universe: c.type = c + + import universe._ + + def promType[A: WeakTypeTag]: Type = weakTypeOf[Promise[A]] + def tryType[A: WeakTypeTag]: Type = weakTypeOf[scala.util.Try[A]] + def execContextType: Type = weakTypeOf[ExecutionContext] + + def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify { + Promise[A]() + } + + def promiseToFuture[A: WeakTypeTag](prom: Expr[Prom[A]]) = reify { + prom.splice.future + } + + def future[A: WeakTypeTag](a: Expr[A])(execContext: Expr[ExecContext]) = reify { + Future(a.splice)(execContext.splice) + } + + def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[scala.util.Try[A] => U], + execContext: Expr[ExecContext]): Expr[Unit] = reify { + future.splice.onComplete(fun.splice)(execContext.splice) + } + + def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] = reify { + prom.splice.complete(value.splice) + Expr[Unit](Literal(Constant(()))).splice + } + + def tryyIsFailure[A](tryy: Expr[scala.util.Try[A]]): Expr[Boolean] = reify { + tryy.splice.isFailure + } + + def tryyGet[A](tryy: Expr[Tryy[A]]): Expr[A] = reify { + tryy.splice.get + } + def tryySuccess[A: WeakTypeTag](a: Expr[A]): Expr[Tryy[A]] = reify { + scala.util.Success[A](a.splice) + } + def tryyFailure[A: WeakTypeTag](a: Expr[Throwable]): Expr[Tryy[A]] = reify { + scala.util.Failure[A](a.splice) + } + } +} diff --git a/src/main/scala/scala/async/internal/Lifter.scala b/src/main/scala/scala/async/internal/Lifter.scala new file mode 100644 index 0000000..7b102d1 --- /dev/null +++ b/src/main/scala/scala/async/internal/Lifter.scala @@ -0,0 +1,149 @@ +package scala.async.internal + +trait Lifter { + self: AsyncMacro => + import scala.reflect.internal.Flags._ + import global._ + + /** + * Identify which DefTrees are used (including transitively) which are declared + * in some state but used (including transitively) in another state. + * + * These will need to be lifted to class members of the state machine. + */ + def liftables(asyncStates: List[AsyncState]): List[Tree] = { + object companionship { + private val companions = collection.mutable.Map[Symbol, Symbol]() + private val companionsInverse = collection.mutable.Map[Symbol, Symbol]() + private def record(sym1: Symbol, sym2: Symbol) { + companions(sym1) = sym2 + companions(sym2) = sym1 + } + + def record(defs: List[Tree]) { + // Keep note of local companions so we rename them consistently + // when lifting. + val comps = for { + cd@ClassDef(_, _, _, _) <- defs + md@ModuleDef(_, _, _) <- defs + if (cd.name.toTermName == md.name) + } record(cd.symbol, md.symbol) + } + def companionOf(sym: Symbol): Symbol = { + companions.get(sym).orElse(companionsInverse.get(sym)).getOrElse(NoSymbol) + } + } + + + val defs: Map[Tree, Int] = { + /** Collect the DefTrees directly enclosed within `t` that have the same owner */ + def collectDirectlyEnclosedDefs(t: Tree): List[DefTree] = t match { + case dt: DefTree => dt :: Nil + case _: Function => Nil + case t => + val childDefs = t.children.flatMap(collectDirectlyEnclosedDefs(_)) + companionship.record(childDefs) + childDefs + } + asyncStates.flatMap { + asyncState => + val defs = collectDirectlyEnclosedDefs(Block(asyncState.allStats: _*)) + defs.map((_, asyncState.state)) + }.toMap + } + + // In which block are these symbols defined? + val symToDefiningState: Map[Symbol, Int] = defs.map { + case (k, v) => (k.symbol, v) + } + + // The definitions trees + val symToTree: Map[Symbol, Tree] = defs.map { + case (k, v) => (k.symbol, k) + } + + // The direct references of each definition tree + val defSymToReferenced: Map[Symbol, List[Symbol]] = defs.keys.map { + case tree => (tree.symbol, tree.collect { + case rt: RefTree if symToDefiningState.contains(rt.symbol) => rt.symbol + }) + }.toMap + + // The direct references of each block, excluding references of `DefTree`-s which + // are already accounted for. + val stateIdToDirectlyReferenced: Map[Int, List[Symbol]] = { + val refs: List[(Int, Symbol)] = asyncStates.flatMap( + asyncState => asyncState.stats.filterNot(_.isDef).flatMap(_.collect { + case rt: RefTree if symToDefiningState.contains(rt.symbol) => (asyncState.state, rt.symbol) + }) + ) + toMultiMap(refs) + } + + def liftableSyms: Set[Symbol] = { + val liftableMutableSet = collection.mutable.Set[Symbol]() + def markForLift(sym: Symbol) { + if (!liftableMutableSet(sym)) { + liftableMutableSet += sym + + // Only mark transitive references of defs, modules and classes. The RHS of lifted vals/vars + // stays in its original location, so things that it refers to need not be lifted. + if (!(sym.isVal || sym.isVar)) + defSymToReferenced(sym).foreach(sym2 => markForLift(sym2)) + } + } + // Start things with DefTrees directly referenced from statements from other states... + val liftableStatementRefs: List[Symbol] = stateIdToDirectlyReferenced.toList.flatMap { + case (i, syms) => syms.filter(sym => symToDefiningState(sym) != i) + } + // .. and likewise for DefTrees directly referenced by other DefTrees from other states + val liftableRefsOfDefTrees = defSymToReferenced.toList.flatMap { + case (referee, referents) => referents.filter(sym => symToDefiningState(sym) != symToDefiningState(referee)) + } + // Mark these for lifting, which will follow transitive references. + (liftableStatementRefs ++ liftableRefsOfDefTrees).foreach(markForLift) + liftableMutableSet.toSet + } + + val lifted = liftableSyms.map(symToTree).toList.map { + t => + val sym = t.symbol + val treeLifted = t match { + case vd@ValDef(_, _, tpt, rhs) => + sym.setFlag(MUTABLE | STABLE | PRIVATE | LOCAL) + sym.name = name.fresh(sym.name.toTermName) + sym.modifyInfo(_.deconst) + val zeroRhs = atPos(t.pos)(gen.mkZero(vd.symbol.info)) + treeCopy.ValDef(vd, Modifiers(sym.flags), sym.name, TypeTree(sym.tpe).setPos(t.pos), zeroRhs) + case dd@DefDef(_, _, tparams, vparamss, tpt, rhs) => + sym.name = this.name.fresh(sym.name.toTermName) + sym.setFlag(PRIVATE | LOCAL) + // Was `DefDef(sym, rhs)`, but this ran afoul of `ToughTypeSpec.nestedMethodWithInconsistencyTreeAndInfoParamSymbols` + // due to the handling of type parameter skolems in `thisMethodType` in `Namers` + treeCopy.DefDef(dd, Modifiers(sym.flags), sym.name, tparams, vparamss, tpt, rhs) + case cd@ClassDef(_, _, tparams, impl) => + sym.name = newTypeName(name.fresh(sym.name.toString).toString) + companionship.companionOf(cd.symbol) match { + case NoSymbol => + case moduleSymbol => + moduleSymbol.name = sym.name.toTermName + moduleSymbol.moduleClass.name = moduleSymbol.name.toTypeName + } + treeCopy.ClassDef(cd, Modifiers(sym.flags), sym.name, tparams, impl) + case md@ModuleDef(_, _, impl) => + companionship.companionOf(md.symbol) match { + case NoSymbol => + sym.name = name.fresh(sym.name.toTermName) + sym.moduleClass.name = sym.name.toTypeName + case classSymbol => // will be renamed by `case ClassDef` above. + } + treeCopy.ModuleDef(md, Modifiers(sym.flags), sym.name, impl) + case td@TypeDef(_, _, tparams, rhs) => + sym.name = newTypeName(name.fresh(sym.name.toString).toString) + treeCopy.TypeDef(td, Modifiers(sym.flags), sym.name, tparams, rhs) + } + atPos(t.pos)(treeLifted) + } + lifted + } +} diff --git a/src/main/scala/scala/async/internal/LiveVariables.scala b/src/main/scala/scala/async/internal/LiveVariables.scala new file mode 100644 index 0000000..8753b3d --- /dev/null +++ b/src/main/scala/scala/async/internal/LiveVariables.scala @@ -0,0 +1,241 @@ +package scala.async.internal + +import reflect.internal.Flags._ + +trait LiveVariables { + self: AsyncMacro => + import global._ + + /** + * Returns for a given state a list of fields (as trees) that should be nulled out + * upon resuming that state (at the beginning of `resume`). + * + * @param asyncStates the states of an `async` block + * @param liftables the lifted fields + * @return a map mapping a state to the fields that should be nulled out + * upon resuming that state + */ + def fieldsToNullOut(asyncStates: List[AsyncState], liftables: List[Tree]): Map[Int, List[Tree]] = { + // live variables analysis: + // the result map indicates in which states a given field should be nulled out + val liveVarsMap: Map[Tree, Set[Int]] = liveVars(asyncStates, liftables) + + var assignsOf = Map[Int, List[Tree]]() + + for ((fld, where) <- liveVarsMap; state <- where) + assignsOf get state match { + case None => + assignsOf += (state -> List(fld)) + case Some(trees) if !trees.exists(_.symbol == fld.symbol) => + assignsOf += (state -> (fld +: trees)) + case _ => + /* do nothing */ + } + + assignsOf + } + + /** + * Live variables data-flow analysis. + * + * The goal is to find, for each lifted field, the last state where the field is used. + * In all direct successor states which are not (indirect) predecessors of that last state + * (possible through loops), the corresponding field should be nulled out (at the beginning of + * `resume`). + * + * @param asyncStates the states of an `async` block + * @param liftables the lifted fields + * @return a map which indicates for a given field (the key) the states in which it should be nulled out + */ + def liveVars(asyncStates: List[AsyncState], liftables: List[Tree]): Map[Tree, Set[Int]] = { + val liftedSyms: Set[Symbol] = // include only vars + liftables.filter { + case ValDef(mods, _, _, _) => mods.hasFlag(MUTABLE) + case _ => false + }.map(_.symbol).toSet + + // determine which fields should be live also at the end (will not be nulled out) + val noNull: Set[Symbol] = liftedSyms.filter { sym => + sym.tpe.typeSymbol.isPrimitiveValueClass || liftables.exists { tree => + !liftedSyms.contains(tree.symbol) && tree.exists(_.symbol == sym) + } + } + AsyncUtils.vprintln(s"fields never zero-ed out: ${noNull.mkString(", ")}") + + /** + * Traverse statements of an `AsyncState`, collect `Ident`-s refering to lifted fields. + * + * @param as a state of an `async` expression + * @return a set of lifted fields that are used within state `as` + */ + def fieldsUsedIn(as: AsyncState): ReferencedFields = { + class FindUseTraverser extends AsyncTraverser { + var usedFields = Set[Symbol]() + var capturedFields = Set[Symbol]() + private def capturing[A](body: => A): A = { + val saved = capturing + try { + capturing = true + body + } finally capturing = saved + } + private def capturingCheck(tree: Tree) = capturing(tree foreach check) + private var capturing: Boolean = false + private def check(tree: Tree) { + tree match { + case Ident(_) if liftedSyms(tree.symbol) => + if (capturing) + capturedFields += tree.symbol + else + usedFields += tree.symbol + case _ => + } + } + override def traverse(tree: Tree) = { + check(tree) + super.traverse(tree) + } + + override def nestedClass(classDef: ClassDef): Unit = capturingCheck(classDef) + + override def nestedModule(module: ModuleDef): Unit = capturingCheck(module) + + override def nestedMethod(defdef: DefDef): Unit = capturingCheck(defdef) + + override def byNameArgument(arg: Tree): Unit = capturingCheck(arg) + + override def function(function: Function): Unit = capturingCheck(function) + + override def patMatFunction(tree: Match): Unit = capturingCheck(tree) + } + + val findUses = new FindUseTraverser + findUses.traverse(Block(as.stats: _*)) + ReferencedFields(findUses.usedFields, findUses.capturedFields) + } + case class ReferencedFields(used: Set[Symbol], captured: Set[Symbol]) { + override def toString = s"used: ${used.mkString(",")}\ncaptured: ${captured.mkString(",")}" + } + + /* Build the control-flow graph. + * + * A state `i` is contained in the list that is the value to which + * key `j` maps iff control can flow from state `j` to state `i`. + */ + val cfg: Map[Int, List[Int]] = asyncStates.map(as => (as.state -> as.nextStates)).toMap + + /** Tests if `state1` is a predecessor of `state2`. + */ + def isPred(state1: Int, state2: Int, seen: Set[Int] = Set()): Boolean = + if (seen(state1)) false // breaks cycles in the CFG + else cfg get state1 match { + case Some(nextStates) => + nextStates.contains(state2) || nextStates.exists(isPred(_, state2, seen + state1)) + case None => + false + } + + val finalState = asyncStates.find(as => !asyncStates.exists(other => isPred(as.state, other.state))).get + + for (as <- asyncStates) + AsyncUtils.vprintln(s"fields used in state #${as.state}: ${fieldsUsedIn(as)}") + + /* Backwards data-flow analysis. Computes live variables information at entry and exit + * of each async state. + * + * Compute using a simple fixed point iteration: + * + * 1. currStates = List(finalState) + * 2. for each cs \in currStates, compute LVentry(cs) from LVexit(cs) and used fields information for cs + * 3. record if LVentry(cs) has changed for some cs. + * 4. obtain predecessors pred of each cs \in currStates + * 5. for each p \in pred, compute LVexit(p) as union of the LVentry of its successors + * 6. currStates = pred + * 7. repeat if something has changed + */ + + var LVentry = Map[Int, Set[Symbol]]() withDefaultValue Set[Symbol]() + var LVexit = Map[Int, Set[Symbol]]() withDefaultValue Set[Symbol]() + + // All fields are declared to be dead at the exit of the final async state, except for the ones + // that cannot be nulled out at all (those in noNull), because they have been captured by a nested def. + LVexit = LVexit + (finalState.state -> noNull) + + var currStates = List(finalState) // start at final state + var pred = List[AsyncState]() // current predecessor states + var hasChanged = true // if something has changed we need to continue iterating + var captured: Set[Symbol] = Set() + + while (hasChanged) { + hasChanged = false + + for (cs <- currStates) { + val LVentryOld = LVentry(cs.state) + val referenced = fieldsUsedIn(cs) + captured ++= referenced.captured + val LVentryNew = LVexit(cs.state) ++ referenced.used + if (!LVentryNew.sameElements(LVentryOld)) { + LVentry = LVentry + (cs.state -> LVentryNew) + hasChanged = true + } + } + + pred = currStates.flatMap(cs => asyncStates.filter(_.nextStates.contains(cs.state))) + + for (p <- pred) { + val LVexitOld = LVexit(p.state) + val LVexitNew = p.nextStates.flatMap(succ => LVentry(succ)).toSet + if (!LVexitNew.sameElements(LVexitOld)) { + LVexit = LVexit + (p.state -> LVexitNew) + hasChanged = true + } + } + + currStates = pred + } + + for (as <- asyncStates) { + AsyncUtils.vprintln(s"LVentry at state #${as.state}: ${LVentry(as.state).mkString(", ")}") + AsyncUtils.vprintln(s"LVexit at state #${as.state}: ${LVexit(as.state).mkString(", ")}") + } + + def lastUsagesOf(field: Tree, at: AsyncState, avoid: Set[AsyncState]): Set[Int] = + if (avoid(at)) Set() + else if (captured(field.symbol)) { + Set() + } + else LVentry get at.state match { + case Some(fields) if fields.exists(_ == field.symbol) => + Set(at.state) + case _ => + val preds = asyncStates.filter(_.nextStates.contains(at.state)).toSet + preds.flatMap(p => lastUsagesOf(field, p, avoid + at)) + } + + val lastUsages: Map[Tree, Set[Int]] = + liftables.map(fld => (fld -> lastUsagesOf(fld, finalState, Set()))).toMap + + for ((fld, lastStates) <- lastUsages) + AsyncUtils.vprintln(s"field ${fld.symbol.name} is last used in states ${lastStates.mkString(", ")}") + + val nullOutAt: Map[Tree, Set[Int]] = + for ((fld, lastStates) <- lastUsages) yield { + val killAt = lastStates.flatMap { s => + if (s == finalState.state) Set() + else { + val lastAsyncState = asyncStates.find(_.state == s).get + val succNums = lastAsyncState.nextStates + // all successor states that are not indirect predecessors + // filter out successor states where the field is live at the entry + succNums.filter(num => !isPred(num, s)).filterNot(num => LVentry(num).exists(_ == fld.symbol)) + } + } + (fld, killAt) + } + + for ((fld, killAt) <- nullOutAt) + AsyncUtils.vprintln(s"field ${fld.symbol.name} should be nulled out in states ${killAt.mkString(", ")}") + + nullOutAt + } +} diff --git a/src/main/scala/scala/async/internal/ScalaConcurrentAsync.scala b/src/main/scala/scala/async/internal/ScalaConcurrentAsync.scala new file mode 100644 index 0000000..4e1a0af --- /dev/null +++ b/src/main/scala/scala/async/internal/ScalaConcurrentAsync.scala @@ -0,0 +1,18 @@ +package scala +package async +package internal + +import scala.language.experimental.macros +import scala.reflect.macros.Context +import scala.concurrent.Future + +object ScalaConcurrentAsync extends AsyncBase { + type FS = ScalaConcurrentFutureSystem.type + val futureSystem: FS = ScalaConcurrentFutureSystem + + override def asyncImpl[T: c.WeakTypeTag](c: Context) + (body: c.Expr[T]) + (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[Future[T]] = { + super.asyncImpl[T](c)(body)(execContext) + } +} diff --git a/src/main/scala/scala/async/StateAssigner.scala b/src/main/scala/scala/async/internal/StateAssigner.scala index bc60a6d..cdde7a4 100644 --- a/src/main/scala/scala/async/StateAssigner.scala +++ b/src/main/scala/scala/async/internal/StateAssigner.scala @@ -2,7 +2,7 @@ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> */ -package scala.async +package scala.async.internal private[async] final class StateAssigner { private var current = -1 @@ -11,4 +11,4 @@ private[async] final class StateAssigner { current += 1 current } -}
\ No newline at end of file +} diff --git a/src/main/scala/scala/async/internal/TransformUtils.scala b/src/main/scala/scala/async/internal/TransformUtils.scala new file mode 100644 index 0000000..e382c62 --- /dev/null +++ b/src/main/scala/scala/async/internal/TransformUtils.scala @@ -0,0 +1,275 @@ +/* + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ +package scala.async.internal + +import scala.reflect.macros.Context +import reflect.ClassTag +import scala.reflect.macros.runtime.AbortMacroException + +/** + * Utilities used in both `ExprBuilder` and `AnfTransform`. + */ +private[async] trait TransformUtils { + self: AsyncMacro => + + import global._ + + object name { + val resume = newTermName("resume") + val apply = newTermName("apply") + val matchRes = "matchres" + val ifRes = "ifres" + val await = "await" + val bindSuffix = "$bind" + + val state = newTermName("state") + val result = newTermName("result") + val execContext = newTermName("execContext") + val stateMachine = newTermName(fresh("stateMachine")) + val stateMachineT = stateMachine.toTypeName + val tr = newTermName("tr") + val t = newTermName("throwable") + + def fresh(name: TermName): TermName = newTermName(fresh(name.toString)) + + def fresh(name: String): String = currentUnit.freshTermName("" + name + "$").toString + } + + def isAwait(fun: Tree) = + fun.symbol == defn.Async_await + + private lazy val Boolean_ShortCircuits: Set[Symbol] = { + import definitions.BooleanClass + def BooleanTermMember(name: String) = BooleanClass.typeSignature.member(newTermName(name).encodedName) + val Boolean_&& = BooleanTermMember("&&") + val Boolean_|| = BooleanTermMember("||") + Set(Boolean_&&, Boolean_||) + } + + private def isByName(fun: Tree): ((Int, Int) => Boolean) = { + if (Boolean_ShortCircuits contains fun.symbol) (i, j) => true + else { + val paramss = fun.tpe.paramss + val byNamess = paramss.map(_.map(_.isByNameParam)) + (i, j) => util.Try(byNamess(i)(j)).getOrElse(false) + } + } + private def argName(fun: Tree): ((Int, Int) => String) = { + val paramss = fun.tpe.paramss + val namess = paramss.map(_.map(_.name.toString)) + (i, j) => util.Try(namess(i)(j)).getOrElse(s"arg_${i}_${j}") + } + + def Expr[A: WeakTypeTag](t: Tree) = global.Expr[A](rootMirror, new FixedMirrorTreeCreator(rootMirror, t)) + + object defn { + def mkList_apply[A](args: List[Expr[A]]): Expr[List[A]] = { + Expr(Apply(Ident(definitions.List_apply), args.map(_.tree))) + } + + def mkList_contains[A](self: Expr[List[A]])(elem: Expr[Any]) = reify { + self.splice.contains(elem.splice) + } + + def mkFunction_apply[A, B](self: Expr[Function1[A, B]])(arg: Expr[A]) = reify { + self.splice.apply(arg.splice) + } + + def mkAny_==(self: Expr[Any])(other: Expr[Any]) = reify { + self.splice == other.splice + } + + def mkTry_get[A](self: Expr[util.Try[A]]) = reify { + self.splice.get + } + + val TryClass = rootMirror.staticClass("scala.util.Try") + val Try_get = TryClass.typeSignature.member(newTermName("get")).ensuring(_ != NoSymbol) + val Try_isFailure = TryClass.typeSignature.member(newTermName("isFailure")).ensuring(_ != NoSymbol) + val TryAnyType = appliedType(TryClass.toType, List(definitions.AnyTpe)) + val NonFatalClass = rootMirror.staticModule("scala.util.control.NonFatal") + val Async_await = asyncBase.awaitMethod(global)(macroApplication.symbol).ensuring(_ != NoSymbol) + } + + def isSafeToInline(tree: Tree) = { + treeInfo.isExprSafeToInline(tree) + } + + // `while(await(x))` ... or `do { await(x); ... } while(...)` contain an `If` that loops; + // we must break that `If` into states so that it convert the label jump into a state machine + // transition + final def containsForiegnLabelJump(t: Tree): Boolean = { + val labelDefs = t.collect { + case ld: LabelDef => ld.symbol + }.toSet + t.exists { + case rt: RefTree => !(labelDefs contains rt.symbol) + case _ => false + } + } + + /** Map a list of arguments to: + * - A list of argument Trees + * - A list of auxillary results. + * + * The function unwraps and rewraps the `arg :_*` construct. + * + * @param args The original argument trees + * @param f A function from argument (with '_*' unwrapped) and argument index to argument. + * @tparam A The type of the auxillary result + */ + private def mapArguments[A](args: List[Tree])(f: (Tree, Int) => (A, Tree)): (List[A], List[Tree]) = { + args match { + case args :+ Typed(tree, Ident(tpnme.WILDCARD_STAR)) => + val (a, argExprs :+ lastArgExpr) = (args :+ tree).zipWithIndex.map(f.tupled).unzip + val exprs = argExprs :+ atPos(lastArgExpr.pos.makeTransparent)(Typed(lastArgExpr, Ident(tpnme.WILDCARD_STAR))) + (a, exprs) + case args => + args.zipWithIndex.map(f.tupled).unzip + } + } + + case class Arg(expr: Tree, isByName: Boolean, argName: String) + + /** + * Transform a list of argument lists, producing the transformed lists, and lists of auxillary + * results. + * + * The function `f` need not concern itself with varargs arguments e.g (`xs : _*`). It will + * receive `xs`, and it's result will be re-wrapped as `f(xs) : _*`. + * + * @param fun The function being applied + * @param argss The argument lists + * @return (auxillary results, mapped argument trees) + */ + def mapArgumentss[A](fun: Tree, argss: List[List[Tree]])(f: Arg => (A, Tree)): (List[List[A]], List[List[Tree]]) = { + val isByNamess: (Int, Int) => Boolean = isByName(fun) + val argNamess: (Int, Int) => String = argName(fun) + argss.zipWithIndex.map { case (args, i) => + mapArguments[A](args) { + (tree, j) => f(Arg(tree, isByNamess(i, j), argNamess(i, j))) + } + }.unzip + } + + + def statsAndExpr(tree: Tree): (List[Tree], Tree) = tree match { + case Block(stats, expr) => (stats, expr) + case _ => (List(tree), Literal(Constant(()))) + } + + def emptyConstructor: DefDef = { + val emptySuperCall = Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), Nil) + DefDef(NoMods, nme.CONSTRUCTOR, List(), List(List()), TypeTree(), Block(List(emptySuperCall), Literal(Constant(())))) + } + + def applied(className: String, types: List[Type]): AppliedTypeTree = + AppliedTypeTree(Ident(rootMirror.staticClass(className)), types.map(TypeTree(_))) + + /** Descends into the regions of the tree that are subject to the + * translation to a state machine by `async`. When a nested template, + * function, or by-name argument is encountered, the descent stops, + * and `nestedClass` etc are invoked. + */ + trait AsyncTraverser extends Traverser { + def nestedClass(classDef: ClassDef) { + } + + def nestedModule(module: ModuleDef) { + } + + def nestedMethod(defdef: DefDef) { + } + + def byNameArgument(arg: Tree) { + } + + def function(function: Function) { + } + + def patMatFunction(tree: Match) { + } + + override def traverse(tree: Tree) { + tree match { + case cd: ClassDef => nestedClass(cd) + case md: ModuleDef => nestedModule(md) + case dd: DefDef => nestedMethod(dd) + case fun: Function => function(fun) + case m@Match(EmptyTree, _) => patMatFunction(m) // Pattern matching anonymous function under -Xoldpatmat of after `restorePatternMatchingFunctions` + case treeInfo.Applied(fun, targs, argss) if argss.nonEmpty => + val isInByName = isByName(fun) + for ((args, i) <- argss.zipWithIndex) { + for ((arg, j) <- args.zipWithIndex) { + if (!isInByName(i, j)) traverse(arg) + else byNameArgument(arg) + } + } + traverse(fun) + case _ => super.traverse(tree) + } + } + } + + def abort(pos: Position, msg: String) = throw new AbortMacroException(pos, msg) + + abstract class MacroTypingTransformer extends TypingTransformer(callSiteTyper.context.unit) { + currentOwner = callSiteTyper.context.owner + curTree = EmptyTree + + def currOwner: Symbol = currentOwner + + localTyper = global.analyzer.newTyper(callSiteTyper.context.make(unit = callSiteTyper.context.unit)) + } + + def transformAt(tree: Tree)(f: PartialFunction[Tree, (analyzer.Context => Tree)]) = { + object trans extends MacroTypingTransformer { + override def transform(tree: Tree): Tree = { + if (f.isDefinedAt(tree)) { + f(tree)(localTyper.context) + } else super.transform(tree) + } + } + trans.transform(tree) + } + + def changeOwner(tree: Tree, oldOwner: Symbol, newOwner: Symbol): tree.type = { + new ChangeOwnerAndModuleClassTraverser(oldOwner, newOwner).traverse(tree) + tree + } + + class ChangeOwnerAndModuleClassTraverser(oldowner: Symbol, newowner: Symbol) + extends ChangeOwnerTraverser(oldowner, newowner) { + + override def traverse(tree: Tree) { + tree match { + case _: DefTree => change(tree.symbol.moduleClass) + case _ => + } + super.traverse(tree) + } + } + + def toMultiMap[A, B](as: Iterable[(A, B)]): Map[A, List[B]] = + as.toList.groupBy(_._1).mapValues(_.map(_._2).toList).toMap + + // Attributed version of `TreeGen#mkCastPreservingAnnotations` + def mkAttributedCastPreservingAnnotations(tree: Tree, tp: Type): Tree = { + atPos(tree.pos) { + val casted = gen.mkAttributedCast(tree, uncheckedBounds(tp.withoutAnnotations).dealias) + Typed(casted, TypeTree(tp)).setType(tp) + } + } + + // ===================================== + // Copy/Pasted from Scala 2.10.3. See SI-7694. + private lazy val UncheckedBoundsClass = { + global.rootMirror.getClassIfDefined("scala.reflect.internal.annotations.uncheckedBounds") + } + final def uncheckedBounds(tp: Type): Type = { + if (tp.typeArgs.isEmpty || UncheckedBoundsClass == NoSymbol) tp + else tp.withAnnotation(AnnotationInfo marker UncheckedBoundsClass.tpe) + } + // ===================================== +} diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala index deaee03..b42726b 100644 --- a/src/test/scala/scala/async/TreeInterrogation.scala +++ b/src/test/scala/scala/async/TreeInterrogation.scala @@ -4,20 +4,18 @@ package scala.async -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 import org.junit.Test +import scala.async.internal.AsyncId import AsyncId._ import tools.reflect.ToolBox -@RunWith(classOf[JUnit4]) class TreeInterrogation { @Test def `a minimal set of vals are lifted to vars`() { val cm = reflect.runtime.currentMirror - val tb = mkToolbox("-cp target/scala-2.10/classes") + val tb = mkToolbox(s"-cp ${toolboxClasspath}") val tree = tb.parse( - """| import _root_.scala.async.AsyncId._ + """| import _root_.scala.async.internal.AsyncId._ | async { | val x = await(1) | val y = x * 2 @@ -40,8 +38,7 @@ class TreeInterrogation { val varDefs = tree1.collect { case ValDef(mods, name, _, _) if mods.hasFlag(Flag.MUTABLE) => name } - varDefs.map(_.decoded.trim).toSet mustBe (Set("state$async", "await$1", "await$2")) - varDefs.map(_.decoded.trim).toSet mustBe (Set("state$async", "await$1", "await$2")) + varDefs.map(_.decoded.trim).toSet mustBe (Set("state", "await$1$1", "await$2$1")) val defDefs = tree1.collect { case t: Template => @@ -52,7 +49,7 @@ class TreeInterrogation { && !dd.symbol.asTerm.isAccessor && !dd.symbol.asTerm.isSetter => dd.name } }.flatten - defDefs.map(_.decoded.trim).toSet mustBe (Set("foo$1", "apply", "resume$async", "<init>")) + defDefs.map(_.decoded.trim).toSet mustBe (Set("foo$1", "apply", "resume", "<init>")) } } @@ -68,17 +65,19 @@ object TreeInterrogation extends App { withDebug { val cm = reflect.runtime.currentMirror - val tb = mkToolbox("-cp target/scala-2.10/classes -Xprint:flatten") - import scala.async.Async._ + val tb = mkToolbox("-cp ${toolboxClasspath} -Xprint:typer -uniqid") + import scala.async.internal.AsyncId._ val tree = tb.parse( - """ import _root_.scala.async.AsyncId.{async, await} - | def foo[T](a0: Int)(b0: Int*) = s"a0 = $a0, b0 = ${b0.head}" - | val res = async { - | var i = 0 - | def get = async {i += 1; i} - | foo[Int](await(get))(await(get) :: Nil : _*) + """ + | import scala.async.internal.AsyncId._ + | async { + | var b = true + | while(await(b)) { + | b = false + | } + | await(b) | } - | res + | | """.stripMargin) println(tree) val tree1 = tb.typeCheck(tree.duplicate) diff --git a/src/test/scala/scala/async/neg/LocalClasses0Spec.scala b/src/test/scala/scala/async/neg/LocalClasses0Spec.scala index 2569303..ae346a2 100644 --- a/src/test/scala/scala/async/neg/LocalClasses0Spec.scala +++ b/src/test/scala/scala/async/neg/LocalClasses0Spec.scala @@ -5,121 +5,30 @@ package scala.async package neg -/** - * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> - */ - -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 import org.junit.Test +import scala.async.internal.AsyncId -@RunWith(classOf[JUnit4]) class LocalClasses0Spec { - @Test - def `reject a local class`() { - expectError("Local case class Person illegal within `async` block") { - """ - | import scala.concurrent.ExecutionContext.Implicits.global - | import scala.async.Async._ - | - | async { - | case class Person(name: String) - | } - """.stripMargin - } + def localClassCrashIssue16() { + import AsyncId.{async, await} + async { + class B { def f = 1 } + await(new B()).f + } mustBe 1 } @Test - def `reject a local class 2`() { - expectError("Local case class Person illegal within `async` block") { - """ - | import scala.concurrent.{Future, ExecutionContext} - | import ExecutionContext.Implicits.global - | import scala.async.Async._ - | - | async { - | case class Person(name: String) - | val fut = Future { 5 } - | val x = await(fut) - | x - | } - """.stripMargin - } + def nestedCaseClassAndModuleAllowed() { + import AsyncId.{await, async} + async { + trait Base { def base = 0} + await(0) + case class Person(name: String) extends Base + val fut = async { "bob" } + val x = Person(await(fut)) + x.base + x.name + } mustBe "bob" } - - @Test - def `reject a local class 3`() { - expectError("Local case class Person illegal within `async` block") { - """ - | import scala.concurrent.{Future, ExecutionContext} - | import ExecutionContext.Implicits.global - | import scala.async.Async._ - | - | async { - | val fut = Future { 5 } - | val x = await(fut) - | case class Person(name: String) - | x - | } - """.stripMargin - } - } - - @Test - def `reject a local class with symbols in its name`() { - expectError("Local case class :: illegal within `async` block") { - """ - | import scala.concurrent.{Future, ExecutionContext} - | import ExecutionContext.Implicits.global - | import scala.async.Async._ - | - | async { - | val fut = Future { 5 } - | val x = await(fut) - | case class ::(name: String) - | x - | } - """.stripMargin - } - } - - @Test - def `reject a nested local class`() { - expectError("Local case class Person illegal within `async` block") { - """ - | import scala.concurrent.{Future, ExecutionContext} - | import ExecutionContext.Implicits.global - | import scala.async.Async._ - | - | async { - | val fut = Future { 5 } - | val x = 2 + 2 - | var y = 0 - | if (x > 0) { - | case class Person(name: String) - | y = await(fut) - | } else { - | y = x - | } - | y - | } - """.stripMargin - } - } - - @Test - def `reject a local singleton object`() { - expectError("Local object Person illegal within `async` block") { - """ - | import scala.concurrent.ExecutionContext.Implicits.global - | import scala.async.Async._ - | - | async { - | object Person { val name = "Joe" } - | } - """.stripMargin - } - } - } diff --git a/src/test/scala/scala/async/neg/NakedAwait.scala b/src/test/scala/scala/async/neg/NakedAwait.scala index b0d5fde..deac069 100644 --- a/src/test/scala/scala/async/neg/NakedAwait.scala +++ b/src/test/scala/scala/async/neg/NakedAwait.scala @@ -5,11 +5,8 @@ package scala.async package neg -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 import org.junit.Test -@RunWith(classOf[JUnit4]) class NakedAwait { @Test def `await only allowed in async neg`() { @@ -25,7 +22,7 @@ class NakedAwait { def `await not allowed in by-name argument`() { expectError("await must not be used under a by-name argument.") { """ - | import _root_.scala.async.AsyncId._ + | import _root_.scala.async.internal.AsyncId._ | def foo(a: Int)(b: => Int) = 0 | async { foo(0)(await(0)) } """.stripMargin @@ -36,7 +33,7 @@ class NakedAwait { def `await not allowed in boolean short circuit argument 1`() { expectError("await must not be used under a by-name argument.") { """ - | import _root_.scala.async.AsyncId._ + | import _root_.scala.async.internal.AsyncId._ | async { true && await(false) } """.stripMargin } @@ -46,7 +43,7 @@ class NakedAwait { def `await not allowed in boolean short circuit argument 2`() { expectError("await must not be used under a by-name argument.") { """ - | import _root_.scala.async.AsyncId._ + | import _root_.scala.async.internal.AsyncId._ | async { true || await(false) } """.stripMargin } @@ -56,7 +53,7 @@ class NakedAwait { def nestedObject() { expectError("await must not be used under a nested object.") { """ - | import _root_.scala.async.AsyncId._ + | import _root_.scala.async.internal.AsyncId._ | async { object Nested { await(false) } } """.stripMargin } @@ -66,7 +63,7 @@ class NakedAwait { def nestedTrait() { expectError("await must not be used under a nested trait.") { """ - | import _root_.scala.async.AsyncId._ + | import _root_.scala.async.internal.AsyncId._ | async { trait Nested { await(false) } } """.stripMargin } @@ -76,7 +73,7 @@ class NakedAwait { def nestedClass() { expectError("await must not be used under a nested class.") { """ - | import _root_.scala.async.AsyncId._ + | import _root_.scala.async.internal.AsyncId._ | async { class Nested { await(false) } } """.stripMargin } @@ -86,7 +83,7 @@ class NakedAwait { def nestedFunction() { expectError("await must not be used under a nested function.") { """ - | import _root_.scala.async.AsyncId._ + | import _root_.scala.async.internal.AsyncId._ | async { () => { await(false) } } """.stripMargin } @@ -96,7 +93,7 @@ class NakedAwait { def nestedPatMatFunction() { expectError("await must not be used under a nested class.") { // TODO more specific error message """ - | import _root_.scala.async.AsyncId._ + | import _root_.scala.async.internal.AsyncId._ | async { { case x => { await(false) } } : PartialFunction[Any, Any] } """.stripMargin } @@ -106,7 +103,7 @@ class NakedAwait { def tryBody() { expectError("await must not be used under a try/catch.") { """ - | import _root_.scala.async.AsyncId._ + | import _root_.scala.async.internal.AsyncId._ | async { try { await(false) } catch { case _ => } } """.stripMargin } @@ -116,7 +113,7 @@ class NakedAwait { def catchBody() { expectError("await must not be used under a try/catch.") { """ - | import _root_.scala.async.AsyncId._ + | import _root_.scala.async.internal.AsyncId._ | async { try { () } catch { case _ => await(false) } } """.stripMargin } @@ -126,17 +123,27 @@ class NakedAwait { def finallyBody() { expectError("await must not be used under a try/catch.") { """ - | import _root_.scala.async.AsyncId._ + | import _root_.scala.async.internal.AsyncId._ | async { try { () } finally { await(false) } } """.stripMargin } } @Test + def guard() { + expectError("await must not be used under a pattern guard.") { + """ + | import _root_.scala.async.internal.AsyncId._ + | async { 1 match { case _ if await(true) => } } + """.stripMargin + } + } + + @Test def nestedMethod() { expectError("await must not be used under a nested method.") { """ - | import _root_.scala.async.AsyncId._ + | import _root_.scala.async.internal.AsyncId._ | async { def foo = await(false) } """.stripMargin } @@ -146,7 +153,7 @@ class NakedAwait { def returnIllegal() { expectError("return is illegal") { """ - | import _root_.scala.async.AsyncId._ + | import _root_.scala.async.internal.AsyncId._ | def foo(): Any = async { return false } | () | @@ -158,7 +165,7 @@ class NakedAwait { def lazyValIllegal() { expectError("lazy vals are illegal") { """ - | import _root_.scala.async.AsyncId._ + | import _root_.scala.async.internal.AsyncId._ | def foo(): Any = async { val x = { lazy val y = 0; y } } | () | diff --git a/src/test/scala/scala/async/neg/SampleNegSpec.scala b/src/test/scala/scala/async/neg/SampleNegSpec.scala index 76f9c3e..e57dae9 100644 --- a/src/test/scala/scala/async/neg/SampleNegSpec.scala +++ b/src/test/scala/scala/async/neg/SampleNegSpec.scala @@ -5,11 +5,8 @@ package scala.async package neg -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 import org.junit.Test -@RunWith(classOf[JUnit4]) class SampleNegSpec { @Test def `missing symbol`() { diff --git a/src/test/scala/scala/async/package.scala b/src/test/scala/scala/async/package.scala index 4a7a958..98d2256 100644 --- a/src/test/scala/scala/async/package.scala +++ b/src/test/scala/scala/async/package.scala @@ -42,7 +42,24 @@ package object async { m.mkToolBox(options = compileOptions) } - def expectError(errorSnippet: String, compileOptions: String = "", baseCompileOptions: String = "-cp target/scala-2.10/classes")(code: String) { + def scalaBinaryVersion: String = { + val PreReleasePattern = """.*-(M|RC).*""".r + val Pattern = """(\d+\.\d+)\..*""".r + scala.util.Properties.versionNumberString match { + case s @ PreReleasePattern(_) => s + case Pattern(v) => v + case _ => "" + } + } + + def toolboxClasspath = { + val f = new java.io.File(s"target/scala-${scalaBinaryVersion}/classes") + if (!f.exists) sys.error(s"output directory ${f.getAbsolutePath} does not exist.") + f.getAbsolutePath + } + + def expectError(errorSnippet: String, compileOptions: String = "", + baseCompileOptions: String = s"-cp ${toolboxClasspath}")(code: String) { intercept[ToolBoxError] { eval(code, compileOptions + " " + baseCompileOptions) }.getMessage mustContain errorSnippet diff --git a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala index 7be6299..728f33b 100644 --- a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala +++ b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala @@ -11,8 +11,7 @@ import scala.concurrent.{Future, ExecutionContext, future, Await} import scala.concurrent.duration._ import scala.async.Async.{async, await} import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 +import scala.async.internal.AsyncId class AnfTestClass { @@ -69,7 +68,6 @@ object State { @volatile var result: Int = 0 } -@RunWith(classOf[JUnit4]) class AnfTransformSpec { @Test @@ -114,8 +112,6 @@ class AnfTransformSpec { @Test def `inlining block does not produce duplicate definition`() { - import scala.async.AsyncId - AsyncId.async { val f = 12 val x = AsyncId.await(f) @@ -132,8 +128,6 @@ class AnfTransformSpec { @Test def `inlining block in tail position does not produce duplicate definition`() { - import scala.async.AsyncId - AsyncId.async { val f = 12 val x = AsyncId.await(f) @@ -176,7 +170,7 @@ class AnfTransformSpec { @Test def nestedAwaitAsBareExpression() { import ExecutionContext.Implicits.global - import _root_.scala.async.AsyncId.{async, await} + import AsyncId.{async, await} val result = async { await(await("").isEmpty) } @@ -186,7 +180,7 @@ class AnfTransformSpec { @Test def nestedAwaitInBlock() { import ExecutionContext.Implicits.global - import _root_.scala.async.AsyncId.{async, await} + import AsyncId.{async, await} val result = async { () await(await("").isEmpty) @@ -197,7 +191,7 @@ class AnfTransformSpec { @Test def nestedAwaitInIf() { import ExecutionContext.Implicits.global - import _root_.scala.async.AsyncId.{async, await} + import AsyncId.{async, await} val result = async { if ("".isEmpty) await(await("").isEmpty) @@ -208,7 +202,7 @@ class AnfTransformSpec { @Test def byNameExpressionsArentLifted() { - import _root_.scala.async.AsyncId.{async, await} + import AsyncId.{async, await} def foo(ignored: => Any, b: Int) = b val result = async { foo(???, await(1)) @@ -218,7 +212,7 @@ class AnfTransformSpec { @Test def evaluationOrderRespected() { - import scala.async.AsyncId.{async, await} + import AsyncId.{async, await} def foo(a: Int, b: Int) = (a, b) val result = async { var i = 0 @@ -233,19 +227,19 @@ class AnfTransformSpec { @Test def awaitInNonPrimaryParamSection1() { - import _root_.scala.async.AsyncId.{async, await} + import AsyncId.{async, await} def foo(a0: Int)(b0: Int) = s"a0 = $a0, b0 = $b0" val res = async { var i = 0 def get = {i += 1; i} - foo(get)(get) + foo(get)(await(get)) } res mustBe "a0 = 1, b0 = 2" } @Test def awaitInNonPrimaryParamSection2() { - import _root_.scala.async.AsyncId.{async, await} + import AsyncId.{async, await} def foo[T](a0: Int)(b0: Int*) = s"a0 = $a0, b0 = ${b0.head}" val res = async { var i = 0 @@ -257,7 +251,7 @@ class AnfTransformSpec { @Test def awaitInNonPrimaryParamSectionWithLazy1() { - import _root_.scala.async.AsyncId.{async, await} + import AsyncId.{async, await} def foo[T](a: => Int)(b: Int) = b val res = async { def get = async {0} @@ -268,7 +262,7 @@ class AnfTransformSpec { @Test def awaitInNonPrimaryParamSectionWithLazy2() { - import _root_.scala.async.AsyncId.{async, await} + import AsyncId.{async, await} def foo[T](a: Int)(b: => Int) = a val res = async { def get = async {0} @@ -279,7 +273,7 @@ class AnfTransformSpec { @Test def awaitWithLazy() { - import _root_.scala.async.AsyncId.{async, await} + import AsyncId.{async, await} def foo[T](a: Int, b: => Int) = a val res = async { def get = async {0} @@ -290,7 +284,7 @@ class AnfTransformSpec { @Test def awaitOkInReciever() { - import scala.async.AsyncId.{async, await} + import AsyncId.{async, await} class Foo { def bar(a: Int)(b: Int) = a + b } async { await(async(new Foo)).bar(1)(2) @@ -299,7 +293,7 @@ class AnfTransformSpec { @Test def namedArgumentsRespectEvaluationOrder() { - import scala.async.AsyncId.{async, await} + import AsyncId.{async, await} def foo(a: Int, b: Int) = (a, b) val result = async { var i = 0 @@ -314,7 +308,7 @@ class AnfTransformSpec { @Test def namedAndDefaultArgumentsRespectEvaluationOrder() { - import scala.async.AsyncId.{async, await} + import AsyncId.{async, await} var i = 0 def next() = { i += 1; @@ -332,7 +326,7 @@ class AnfTransformSpec { @Test def repeatedParams1() { - import scala.async.AsyncId.{async, await} + import AsyncId.{async, await} var i = 0 def foo(a: Int, b: Int*) = b.toList def id(i: Int) = i @@ -343,7 +337,7 @@ class AnfTransformSpec { @Test def repeatedParams2() { - import scala.async.AsyncId.{async, await} + import AsyncId.{async, await} var i = 0 def foo(a: Int, b: Int*) = b.toList def id(i: Int) = i @@ -351,4 +345,64 @@ class AnfTransformSpec { foo(await(0), List(id(1), id(2), id(3)): _*) } mustBe (List(1, 2, 3)) } + + @Test + def awaitInThrow() { + import _root_.scala.async.internal.AsyncId.{async, await} + intercept[Exception]( + async { + throw new Exception("msg: " + await(0)) + } + ).getMessage mustBe "msg: 0" + } + + @Test + def awaitInTyped() { + import _root_.scala.async.internal.AsyncId.{async, await} + async { + (("msg: " + await(0)): String).toString + } mustBe "msg: 0" + } + + + @Test + def awaitInAssign() { + import _root_.scala.async.internal.AsyncId.{async, await} + async { + var x = 0 + x = await(1) + x + } mustBe 1 + } + + @Test + def caseBodyMustBeTypedAsUnit() { + import _root_.scala.async.internal.AsyncId.{async, await} + val Up = 1 + val Down = 2 + val sign = async { + await(1) match { + case Up => 1.0 + case Down => -1.0 + } + } + sign mustBe 1.0 + } + + @Test + def awaitInImplicitApply() { + val tb = mkToolbox(s"-cp ${toolboxClasspath}") + val tree = tb.typeCheck(tb.parse { + """ + | import language.implicitConversions + | import _root_.scala.async.internal.AsyncId.{async, await} + | implicit def view(a: Int): String = "" + | async { + | await(0).length + | } + """.stripMargin + }) + val applyImplicitView = tree.collect { case x if x.getClass.getName.endsWith("ApplyImplicitView") => x } + applyImplicitView.map(_.toString) mustBe List("view(a$1)") + } } diff --git a/src/test/scala/scala/async/run/await0/Await0Spec.scala b/src/test/scala/scala/async/run/await0/Await0Spec.scala index 111602a..2adaa09 100644 --- a/src/test/scala/scala/async/run/await0/Await0Spec.scala +++ b/src/test/scala/scala/async/run/await0/Await0Spec.scala @@ -15,8 +15,6 @@ import language.{reflectiveCalls, postfixOps} import scala.concurrent.{Future, ExecutionContext, future, Await} import scala.concurrent.duration._ import scala.async.Async.{async, await} -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 import org.junit.Test class Await0Class { @@ -63,7 +61,6 @@ class Await0Class { } } -@RunWith(classOf[JUnit4]) class Await0Spec { @Test diff --git a/src/test/scala/scala/async/run/block0/AsyncSpec.scala b/src/test/scala/scala/async/run/block0/AsyncSpec.scala index 5f38086..677cce8 100644 --- a/src/test/scala/scala/async/run/block0/AsyncSpec.scala +++ b/src/test/scala/scala/async/run/block0/AsyncSpec.scala @@ -11,8 +11,6 @@ import scala.concurrent.{Future, ExecutionContext, future, Await} import scala.concurrent.duration._ import scala.async.Async.{async, await} import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 class Test1Class { @@ -39,7 +37,6 @@ class Test1Class { } -@RunWith(classOf[JUnit4]) class AsyncSpec { @Test diff --git a/src/test/scala/scala/async/run/block1/block1.scala b/src/test/scala/scala/async/run/block1/block1.scala index bf9b56f..f42b073 100644 --- a/src/test/scala/scala/async/run/block1/block1.scala +++ b/src/test/scala/scala/async/run/block1/block1.scala @@ -11,8 +11,6 @@ import scala.concurrent.{Future, ExecutionContext, future, Await} import scala.concurrent.duration._ import scala.async.Async.{async, await} import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 class Test1Class { @@ -32,7 +30,6 @@ class Test1Class { } } -@RunWith(classOf[JUnit4]) class Block1Spec { @Test def `support a simple await`() { diff --git a/src/test/scala/scala/async/run/cps/CPSSpec.scala b/src/test/scala/scala/async/run/cps/CPSSpec.scala deleted file mode 100644 index b56c6ad..0000000 --- a/src/test/scala/scala/async/run/cps/CPSSpec.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> - */ - -package scala.async -package run -package cps - -import scala.concurrent.{Future, Promise, ExecutionContext, future, Await} -import scala.concurrent.duration._ -import scala.async.continuations.CPSBasedAsync._ -import scala.util.continuations._ - -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 -import org.junit.Test - -@RunWith(classOf[JUnit4]) -class CPSSpec { - - import ExecutionContext.Implicits.global - - def m1(y: Int): Future[Int] = async { - val f = future { y + 2 } - val f2 = future { y + 3 } - val x1 = await(f) - val x2 = await(f2) - x1 + x2 - } - - def m2(y: Int): Future[Int] = async { - val f = future { y + 2 } - val res = await(f) - if (y > 0) res + 2 - else res - 2 - } - - @Test - def testCPSFallback() { - val fut1 = m1(10) - val res1 = Await.result(fut1, 2.seconds) - assert(res1 == 25, s"expected 25, got $res1") - - val fut2 = m2(10) - val res2 = Await.result(fut2, 2.seconds) - assert(res2 == 14, s"expected 14, got $res2") - } - -} diff --git a/src/test/scala/scala/async/run/exceptions/ExceptionsSpec.scala b/src/test/scala/scala/async/run/exceptions/ExceptionsSpec.scala index 733ea01..b417dd6 100644 --- a/src/test/scala/scala/async/run/exceptions/ExceptionsSpec.scala +++ b/src/test/scala/scala/async/run/exceptions/ExceptionsSpec.scala @@ -14,10 +14,7 @@ import scala.concurrent.duration._ import scala.reflect.ClassTag import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 -@RunWith(classOf[JUnit4]) class ExceptionsSpec { @Test diff --git a/src/test/scala/scala/async/run/futures/FutureSpec.scala b/src/test/scala/scala/async/run/futures/FutureSpec.scala index 01c8620..491b43f 100644 --- a/src/test/scala/scala/async/run/futures/FutureSpec.scala +++ b/src/test/scala/scala/async/run/futures/FutureSpec.scala @@ -18,10 +18,7 @@ import scala.util.{Try,Success,Failure} import scala.async.Async.{async, await} import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 -@RunWith(classOf[JUnit4]) class FutureSpec { /* some utils */ diff --git a/src/test/scala/scala/async/run/hygiene/Hygiene.scala b/src/test/scala/scala/async/run/hygiene/Hygiene.scala index 9d1df21..f11d21e 100644 --- a/src/test/scala/scala/async/run/hygiene/Hygiene.scala +++ b/src/test/scala/scala/async/run/hygiene/Hygiene.scala @@ -7,13 +7,11 @@ package run package hygiene import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 +import scala.async.internal.AsyncId -@RunWith(classOf[JUnit4]) class HygieneSpec { - import scala.async.AsyncId.{async, await} + import AsyncId.{async, await} @Test def `is hygenic`() { diff --git a/src/test/scala/scala/async/run/ifelse0/IfElse0.scala b/src/test/scala/scala/async/run/ifelse0/IfElse0.scala index e2b1ca6..4a86d82 100644 --- a/src/test/scala/scala/async/run/ifelse0/IfElse0.scala +++ b/src/test/scala/scala/async/run/ifelse0/IfElse0.scala @@ -10,9 +10,8 @@ import language.{reflectiveCalls, postfixOps} import scala.concurrent.{Future, ExecutionContext, future, Await} import scala.concurrent.duration._ import scala.async.Async.{async, await} -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 import org.junit.Test +import scala.async.internal.AsyncId class TestIfElseClass { @@ -38,7 +37,6 @@ class TestIfElseClass { } -@RunWith(classOf[JUnit4]) class IfElseSpec { @Test def `support await in a simple if-else expression`() { diff --git a/src/test/scala/scala/async/run/ifelse0/WhileSpec.scala b/src/test/scala/scala/async/run/ifelse0/WhileSpec.scala index 1f1033a..76e6b1e 100644 --- a/src/test/scala/scala/async/run/ifelse0/WhileSpec.scala +++ b/src/test/scala/scala/async/run/ifelse0/WhileSpec.scala @@ -6,11 +6,9 @@ package scala.async package run package ifelse0 -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 import org.junit.Test +import scala.async.internal.AsyncId -@RunWith(classOf[JUnit4]) class WhileSpec { @Test @@ -64,4 +62,58 @@ class WhileSpec { } result mustBe (100) } -}
\ No newline at end of file + + @Test + def whileExpr() { + import AsyncId._ + + val result = async { + var cond = true + while (cond) { + cond = false + await { 22 } + } + } + result mustBe () + } + + @Test def doWhile() { + import AsyncId._ + val result = async { + var b = 0 + var x = "" + await(do { + x += "1" + x += await("2") + x += "3" + b += await(1) + } while (b < 2)) + await(x) + } + result mustBe "123123" + } + + @Test def whileAwaitCondition() { + import AsyncId._ + val result = async { + var b = true + while(await(b)) { + b = false + } + await(b) + } + result mustBe false + } + + @Test def doWhileAwaitCondition() { + import AsyncId._ + val result = async { + var b = true + do { + b = false + } while(await(b)) + b + } + result mustBe false + } +} diff --git a/src/test/scala/scala/async/run/ifelse1/IfElse1.scala b/src/test/scala/scala/async/run/ifelse1/IfElse1.scala index b567ee6..41b81a4 100644 --- a/src/test/scala/scala/async/run/ifelse1/IfElse1.scala +++ b/src/test/scala/scala/async/run/ifelse1/IfElse1.scala @@ -10,8 +10,6 @@ import language.{reflectiveCalls, postfixOps} import scala.concurrent.{Future, ExecutionContext, future, Await} import scala.concurrent.duration._ import scala.async.Async.{async, await} -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 import org.junit.Test @@ -91,7 +89,6 @@ class TestIfElse1Class { } } -@RunWith(classOf[JUnit4]) class IfElse1Spec { @Test diff --git a/src/test/scala/scala/async/run/ifelse2/ifelse2.scala b/src/test/scala/scala/async/run/ifelse2/ifelse2.scala index 92a76e4..3fc4d3b 100644 --- a/src/test/scala/scala/async/run/ifelse2/ifelse2.scala +++ b/src/test/scala/scala/async/run/ifelse2/ifelse2.scala @@ -10,8 +10,6 @@ import language.{reflectiveCalls, postfixOps} import scala.concurrent.{Future, ExecutionContext, future, Await} import scala.concurrent.duration._ import scala.async.Async.{async, await} -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 import org.junit.Test @@ -37,7 +35,6 @@ class TestIfElse2Class { } } -@RunWith(classOf[JUnit4]) class IfElse2Spec { @Test diff --git a/src/test/scala/scala/async/run/ifelse3/IfElse3.scala b/src/test/scala/scala/async/run/ifelse3/IfElse3.scala index 8a2ab13..8e6e1bb 100644 --- a/src/test/scala/scala/async/run/ifelse3/IfElse3.scala +++ b/src/test/scala/scala/async/run/ifelse3/IfElse3.scala @@ -10,8 +10,6 @@ import language.{reflectiveCalls, postfixOps} import scala.concurrent.{Future, ExecutionContext, future, Await} import scala.concurrent.duration._ import scala.async.Async.{async, await} -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 import org.junit.Test @@ -40,7 +38,6 @@ class TestIfElse3Class { } -@RunWith(classOf[JUnit4]) class IfElse3Spec { @Test diff --git a/src/test/scala/scala/async/run/live/LiveVariablesSpec.scala b/src/test/scala/scala/async/run/live/LiveVariablesSpec.scala new file mode 100644 index 0000000..7d62f80 --- /dev/null +++ b/src/test/scala/scala/async/run/live/LiveVariablesSpec.scala @@ -0,0 +1,266 @@ +/* + * Copyright (C) 2012-2013 Typesafe Inc. <http://www.typesafe.com> + */ + +package scala.async +package run +package live + +import org.junit.Test + +import internal.AsyncTestLV +import AsyncTestLV._ + +case class Cell[T](v: T) + +class Meter(val len: Long) extends AnyVal + +case class MCell[T](var v: T) + + +class LiveVariablesSpec { + AsyncTestLV.clear() + + @Test + def `zero out fields of reference type`() { + val f = async { Cell(1) } + + def m1(x: Cell[Int]): Cell[Int] = + async { Cell(x.v + 1) } + + def m2(x: Cell[Int]): String = + async { x.v.toString } + + def m3() = async { + val a: Cell[Int] = await(f) // await$1$1 + // a == Cell(1) + val b: Cell[Int] = await(m1(a)) // await$2$1 + // b == Cell(2) + assert(AsyncTestLV.log.exists(_ == ("await$1$1" -> Cell(1))), AsyncTestLV.log) + val res = await(m2(b)) // await$3$1 + assert(AsyncTestLV.log.exists(_ == ("await$2$1" -> Cell(2)))) + res + } + + assert(m3() == "2") + } + + @Test + def `zero out fields of type Any`() { + val f = async { Cell(1) } + + def m1(x: Cell[Int]): Cell[Int] = + async { Cell(x.v + 1) } + + def m2(x: Any): String = + async { x.toString } + + def m3() = async { + val a: Cell[Int] = await(f) // await$4$1 + // a == Cell(1) + val b: Any = await(m1(a)) // await$5$1 + // b == Cell(2) + assert(AsyncTestLV.log.exists(_ == ("await$4$1" -> Cell(1)))) + val res = await(m2(b)) // await$6$1 + assert(AsyncTestLV.log.exists(_ == ("await$5$1" -> Cell(2)))) + res + } + + assert(m3() == "Cell(2)") + } + + @Test + def `do not zero out fields of primitive type`() { + val f = async { 1 } + + def m1(x: Int): Cell[Int] = + async { Cell(x + 1) } + + def m2(x: Any): String = + async { x.toString } + + def m3() = async { + val a: Int = await(f) // await$7$1 + // a == 1 + val b: Any = await(m1(a)) // await$8$1 + // b == Cell(2) + assert(!AsyncTestLV.log.exists(p => p._1 == "await$7$1")) + val res = await(m2(b)) // await$9$1 + assert(AsyncTestLV.log.exists(_ == ("await$8$1" -> Cell(2)))) + res + } + + assert(m3() == "Cell(2)") + } + + @Test + def `zero out fields of value class type`() { + val f = async { Cell(1) } + + def m1(x: Cell[Int]): Meter = + async { new Meter(x.v + 1) } + + def m2(x: Any): String = + async { x.toString } + + def m3() = async { + val a: Cell[Int] = await(f) // await$10$1 + // a == Cell(1) + val b: Meter = await(m1(a)) // await$11$1 + // b == Meter(2) + assert(AsyncTestLV.log.exists(_ == ("await$10$1" -> Cell(1)))) + val res = await(m2(b.len)) // await$12$1 + assert(AsyncTestLV.log.exists(entry => entry._1 == "await$11$1" && entry._2.asInstanceOf[Meter].len == 2L)) + res + } + + assert(m3() == "2") + } + + @Test + def `zero out fields after use in loop`() { + val f = async { MCell(1) } + + def m1(x: MCell[Int], y: Int): Int = + async { x.v + y } + + def m3() = async { + // state #1 + val a: MCell[Int] = await(f) // await$13$1 + // state #2 + var y = MCell(0) + + while (a.v < 10) { + // state #4 + a.v = a.v + 1 + y = MCell(await(a).v + 1) // await$14$1 + // state #7 + } + + // state #3 + assert(AsyncTestLV.log.exists(entry => entry._1 == "await$14$1")) + + val b = await(m1(a, y.v)) // await$15$1 + // state #8 + assert(AsyncTestLV.log.exists(_ == ("a$1" -> MCell(10))), AsyncTestLV.log) + assert(AsyncTestLV.log.exists(_ == ("y$1" -> MCell(11)))) + b + } + + assert(m3() == 21, m3()) + } + + @Test + def `don't zero captured fields captured lambda`() { + val f = async { + val x = "x" + val y = "y" + await(0) + y.reverse + val f = () => assert(x != null) + await(0) + f + } + AsyncTestLV.assertNotNulledOut("x") + AsyncTestLV.assertNulledOut("y") + f() + } + + @Test + def `don't zero captured fields captured by-name`() { + def func0[A](a: => A): () => A = () => a + val f = async { + val x = "x" + val y = "y" + await(0) + y.reverse + val f = func0(assert(x != null)) + await(0) + f + } + AsyncTestLV.assertNotNulledOut("x") + AsyncTestLV.assertNulledOut("y") + f() + } + + @Test + def `don't zero captured fields nested class`() { + def func0[A](a: => A): () => A = () => a + val f = async { + val x = "x" + val y = "y" + await(0) + y.reverse + val f = new Function0[Unit] { + def apply = assert(x != null) + } + await(0) + f + } + AsyncTestLV.assertNotNulledOut("x") + AsyncTestLV.assertNulledOut("y") + f() + } + + @Test + def `don't zero captured fields nested object`() { + def func0[A](a: => A): () => A = () => a + val f = async { + val x = "x" + val y = "y" + await(0) + y.reverse + object f extends Function0[Unit] { + def apply = assert(x != null) + } + await(0) + f + } + AsyncTestLV.assertNotNulledOut("x") + AsyncTestLV.assertNulledOut("y") + f() + } + + @Test + def `don't zero captured fields nested def`() { + val f = async { + val x = "x" + val y = "y" + await(0) + y.reverse + def xx = x + val f = xx _ + await(0) + f + } + AsyncTestLV.assertNotNulledOut("x") + AsyncTestLV.assertNulledOut("y") + f() + } + + @Test + def `capture bug`() { + sealed trait Base + case class B1() extends Base + case class B2() extends Base + val outer = List[(Base, Int)]((B1(), 8)) + + def getMore(b: Base) = 4 + + def baz = async { + outer.head match { + case (a @ B1(), r) => { + val ents = await(getMore(a)) + + { () => + println(a) + assert(a ne null) + } + } + case (b @ B2(), x) => + () => ??? + } + } + baz() + } +} diff --git a/src/test/scala/scala/async/run/match0/Match0.scala b/src/test/scala/scala/async/run/match0/Match0.scala index 7624838..418275e 100644 --- a/src/test/scala/scala/async/run/match0/Match0.scala +++ b/src/test/scala/scala/async/run/match0/Match0.scala @@ -10,9 +10,8 @@ import language.{reflectiveCalls, postfixOps} import scala.concurrent.{Future, ExecutionContext, future, Await} import scala.concurrent.duration._ import scala.async.Async.{async, await} -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 import org.junit.Test +import scala.async.internal.AsyncId class TestMatchClass { @@ -53,7 +52,6 @@ class TestMatchClass { } -@RunWith(classOf[JUnit4]) class MatchSpec { @Test def `support await in a simple match expression`() { @@ -111,4 +109,38 @@ class MatchSpec { } result mustBe (3) } + + @Test def duplicateBindName() { + import AsyncId.{async, await} + def m4(m: Any) = async { + m match { + case buf: String => + await(0) + case buf: Double => + await(2) + } + } + m4("") mustBe 0 + } + + @Test def bugCastBoxedUnitToStringMatch() { + import scala.async.internal.AsyncId.{async, await} + def foo = async { + val p2 = await(5) + "foo" match { + case p3: String => + p2.toString + } + } + foo mustBe "5" + } + + @Test def bugCastBoxedUnitToStringIf() { + import scala.async.internal.AsyncId.{async, await} + def foo = async { + val p2 = await(5) + if (true) p2.toString else p2.toString + } + foo mustBe "5" + } } diff --git a/src/test/scala/scala/async/run/nesteddef/NestedDef.scala b/src/test/scala/scala/async/run/nesteddef/NestedDef.scala index ee0a78e..69e741d 100644 --- a/src/test/scala/scala/async/run/nesteddef/NestedDef.scala +++ b/src/test/scala/scala/async/run/nesteddef/NestedDef.scala @@ -2,11 +2,9 @@ package scala.async package run package nesteddef -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 import org.junit.Test +import scala.async.internal.AsyncId -@RunWith(classOf[JUnit4]) class NestedDef { @Test @@ -37,4 +35,60 @@ class NestedDef { } result mustBe ((0d, 44d, 2)) } + + // We must lift `foo` and `bar` in the next two tests. + @Test + def nestedDefTransitive1() { + import AsyncId._ + val result = async { + val a = 0 + val x = await(a) - 1 + def bar = a + def foo = bar + foo + } + result mustBe 0 + } + + @Test + def nestedDefTransitive2() { + import AsyncId._ + val result = async { + val a = 0 + val x = await(a) - 1 + def bar = a + def foo = bar + 0 + } + result mustBe 0 + } + + + // checking that our use/definition analysis doesn't cycle. + @Test + def mutuallyRecursive1() { + import AsyncId._ + val result = async { + val a = 0 + val x = await(a) - 1 + def foo: Int = if (true) 0 else bar + def bar: Int = if (true) 0 else foo + bar + } + result mustBe 0 + } + + // checking that our use/definition analysis doesn't cycle. + @Test + def mutuallyRecursive2() { + import AsyncId._ + val result = async { + val a = 0 + def foo: Int = if (true) 0 else bar + def bar: Int = if (true) 0 else foo + val x = await(a) - 1 + bar + } + result mustBe 0 + } } diff --git a/src/test/scala/scala/async/run/noawait/NoAwaitSpec.scala b/src/test/scala/scala/async/run/noawait/NoAwaitSpec.scala index e2c69d0..0adb506 100644 --- a/src/test/scala/scala/async/run/noawait/NoAwaitSpec.scala +++ b/src/test/scala/scala/async/run/noawait/NoAwaitSpec.scala @@ -6,12 +6,10 @@ package scala.async package run package noawait +import scala.async.internal.AsyncId import AsyncId._ import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 -@RunWith(classOf[JUnit4]) class NoAwaitSpec { @Test def `async block without await`() { diff --git a/src/test/scala/scala/async/run/toughtype/ToughType.scala b/src/test/scala/scala/async/run/toughtype/ToughType.scala index 83f5a2d..1551856 100644 --- a/src/test/scala/scala/async/run/toughtype/ToughType.scala +++ b/src/test/scala/scala/async/run/toughtype/ToughType.scala @@ -10,9 +10,8 @@ import language.{reflectiveCalls, postfixOps} import scala.concurrent._ import scala.concurrent.duration._ import scala.async.Async._ -import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 +import org.junit.{Assert, Test} +import scala.async.internal.AsyncId object ToughTypeObject { @@ -28,7 +27,6 @@ object ToughTypeObject { } } -@RunWith(classOf[JUnit4]) class ToughTypeSpec { @Test def `propogates tough types`() { @@ -67,4 +65,131 @@ class ToughTypeSpec { await(f(2)) } mustBe 3 } + + @Test def existentialBindIssue19() { + import AsyncId.{await, async} + def m7(a: Any) = async { + a match { + case s: Seq[_] => + val x = s.size + var ss = s + ss = s + await(x) + } + } + m7(Nil) mustBe 0 + } + + @Test def existentialBind2Issue19() { + import scala.async.Async._, scala.concurrent.ExecutionContext.Implicits.global + def conjure[T]: T = null.asInstanceOf[T] + + def m3 = async { + val p: List[Option[_]] = conjure[List[Option[_]]] + await(future(1)) + } + + def m4 = async { + await(future[List[_]](Nil)) + } + } + + @Test def singletonTypeIssue17() { + import AsyncId.{async, await} + class A { class B } + async { + val a = new A + def foo(b: a.B) = 0 + await(foo(new a.B)) + } + } + + @Test def existentialMatch() { + import AsyncId.{async, await} + trait Container[+A] + case class ContainerImpl[A](value: A) extends Container[A] + def foo: Container[_] = async { + val a: Any = List(1) + a match { + case buf: Seq[_] => + val foo = await(5) + val e0 = buf(0) + ContainerImpl(e0) + } + } + foo + } + + @Test def existentialIfElse0() { + import AsyncId.{async, await} + trait Container[+A] + case class ContainerImpl[A](value: A) extends Container[A] + def foo: Container[_] = async { + val a: Any = List(1) + if (true) { + val buf: Seq[_] = List(1) + val foo = await(5) + val e0 = buf(0) + ContainerImpl(e0) + } else ??? + } + foo + } + + // This test was failing when lifting `def r` with: + // symbol value m#10864 does not exist in r$1 + // + // We generated: + // + // private[this] def r$1#5727[A#5728 >: Nothing#157 <: Any#156](m#5731: Foo#2349[A#5728]): Unit#208 = Bippy#2352.this.bar#5532({ + // m#5730; + // () + // }); + // + // Notice the incorrect reference to `m`. + // + // We compensated in `Lifter` by copying `ValDef` parameter symbols directly across. + // + // Turns out the behaviour stems from `thisMethodType` in `Namers`, which treats type parameter skolem symbols. + @Test def nestedMethodWithInconsistencyTreeAndInfoParamSymbols() { + import language.{reflectiveCalls, postfixOps} + import scala.concurrent.{Future, ExecutionContext, future, Await} + import scala.concurrent.duration._ + import scala.async.Async.{async, await} + import scala.async.internal.AsyncId + + class Foo[A] + + object Bippy { + + import ExecutionContext.Implicits.global + + def bar(f: => Unit): Unit = f + + def quux: Future[String] = ??? + + def foo = async { + def r[A](m: Foo[A])(n: A) = { + bar { + locally(m) + locally(n) + identity[A] _ + } + } + + await(quux) + + r(new Foo[String])("") + } + } + Bippy + } +} + +trait A + +trait B + +trait L[A2, B2 <: A2] { + def bar(a: Any, b: Any) = 0 } diff --git a/src/test/scala/scala/async/run/uncheckedBounds/UncheckedBoundsSpec.scala b/src/test/scala/scala/async/run/uncheckedBounds/UncheckedBoundsSpec.scala new file mode 100644 index 0000000..5eb1f32 --- /dev/null +++ b/src/test/scala/scala/async/run/uncheckedBounds/UncheckedBoundsSpec.scala @@ -0,0 +1,57 @@ +package scala.async +package run +package uncheckedBounds + +import org.junit.{Test, Assert} +import scala.async.TreeInterrogation + +class UncheckedBoundsSpec { + @Test def insufficientLub_SI_7694() { + suppressingFailureBefore2_10_3 { + eval( s""" + object Test { + import _root_.scala.async.run.toughtype._ + import _root_.scala.async.internal.AsyncId.{async, await} + async { + (if (true) await(null: L[A, A]) else await(null: L[B, B])) + } + } + """, compileOptions = s"-cp ${toolboxClasspath} ") + } + } + + @Test def insufficientLub_SI_7694_ScalaConcurrent() { + suppressingFailureBefore2_10_3 { + eval( s""" + object Test { + import _root_.scala.async.run.toughtype._ + import _root_.scala.async.Async.{async, await} + import scala.concurrent._ + import scala.concurrent.ExecutionContext.Implicits.global + async { + (if (true) await(null: Future[L[A, A]]) else await(null: Future[L[B, B]])) + } + } + """, compileOptions = s"-cp ${toolboxClasspath} ") + } + } + + private def suppressingFailureBefore2_10_3(body: => Any) { + try { + body + } catch { + case x: Throwable => + // @uncheckedBounds was only introduced in 2.10.3/ 2.11.0-M5, so avoid reporting this test failure in those cases. + scala.util.Properties.versionNumberString match { + case "2.10.0" | "2.10.1" | "2.10.2" | "2.11.0-M4" => // ignore, the @uncheckedBounds doesn't exist yet + case _ => + val annotationExists = + reflect.runtime.currentMirror.staticClass("scala.reflect.internal.annotations.uncheckedBounds") == reflect.runtime.universe.NoSymbol + if (annotationExists) + Assert.fail("@uncheckedBounds not found in scala-reflect.jar") + else + Assert.fail(s"@uncheckedBounds exists, but it didn't prevent this failure: $x") + } + } + } +} |