diff options
author | Jason Zaugg <jzaugg@gmail.com> | 2012-12-19 04:23:15 -0800 |
---|---|---|
committer | Jason Zaugg <jzaugg@gmail.com> | 2012-12-19 04:23:15 -0800 |
commit | a86730d545d7afd6836755f90f5eb9d2443e70d2 (patch) | |
tree | 883d7c5e3678064a95a2a2c282152b2f00aed04d /src | |
parent | 933657d7541bc54b9f03d86d215d274f83694b31 (diff) | |
parent | 1359c3cd1750a95260bcc9d0dcbe56ef10c35d68 (diff) | |
download | scala-async-a86730d545d7afd6836755f90f5eb9d2443e70d2.tar.gz scala-async-a86730d545d7afd6836755f90f5eb9d2443e70d2.tar.bz2 scala-async-a86730d545d7afd6836755f90f5eb9d2443e70d2.zip |
Merge pull request #49 from phaller/topic/patmat-partial-function
Topic/patmat partial function
Diffstat (limited to 'src')
-rw-r--r-- | src/main/scala/scala/async/Async.scala | 3 | ||||
-rw-r--r-- | src/main/scala/scala/async/AsyncAnalysis.scala | 33 | ||||
-rw-r--r-- | src/main/scala/scala/async/TransformUtils.scala | 56 | ||||
-rw-r--r-- | src/test/scala/scala/async/TreeInterrogation.scala | 20 | ||||
-rw-r--r-- | src/test/scala/scala/async/neg/NakedAwait.scala | 22 | ||||
-rw-r--r-- | src/test/scala/scala/async/run/toughtype/ToughType.scala | 31 |
6 files changed, 132 insertions, 33 deletions
diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala index fd28fb8..6ad1441 100644 --- a/src/main/scala/scala/async/Async.scala +++ b/src/main/scala/scala/async/Async.scala @@ -75,7 +75,8 @@ abstract class AsyncBase { // - if/match only used in statement position. val anfTree: Block = { val anf = AnfTransform[c.type](c) - val stats1 :+ expr1 = anf(body.tree) + val restored = utils.restorePatternMatchingFunctions(body.tree) + val stats1 :+ expr1 = anf(restored) val block = Block(stats1, expr1) c.typeCheck(block).asInstanceOf[Block] } diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala index 8bb5bcd..7aca269 100644 --- a/src/main/scala/scala/async/AsyncAnalysis.scala +++ b/src/main/scala/scala/async/AsyncAnalysis.scala @@ -8,6 +8,7 @@ import scala.reflect.macros.Context import scala.collection.mutable private[async] final case class AsyncAnalysis[C <: Context](c: C) { + import c.universe._ val utils = TransformUtils[c.type](c) @@ -67,15 +68,21 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) { reportUnsupportedAwait(function, "nested function") } + override def patMatFunction(tree: Match) { + reportUnsupportedAwait(tree, "nested function") + } + override def traverse(tree: Tree) { def containsAwait = tree exists isAwait tree match { - case Try(_, _, _) if containsAwait => + case Try(_, _, _) if containsAwait => reportUnsupportedAwait(tree, "try/catch") super.traverse(tree) - case Return(_) => + case Return(_) => c.abort(tree.pos, "return is illegal within a async block") - case _ => + case ValDef(mods, _, _, _) if mods.hasFlag(Flag.LAZY) => + c.abort(tree.pos, "lazy vals are illegal within an async block") + case _ => super.traverse(tree) } } @@ -107,19 +114,19 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) { override def nestedMethod(defDef: DefDef) { nestedMethodsToLift += defDef - defDef.rhs foreach { - case rt: RefTree => - valDefChunkId.get(rt.symbol) match { - case Some((vd, defChunkId)) => - valDefsToLift += vd // lift all vals referred to by nested methods. - case _ => - } - case _ => - } + markReferencedVals(defDef) } override def function(function: Function) { - function foreach { + markReferencedVals(function) + } + + override def patMatFunction(tree: Match) { + markReferencedVals(tree) + } + + private def markReferencedVals(tree: Tree) { + tree foreach { case rt: RefTree => valDefChunkId.get(rt.symbol) match { case Some((vd, defChunkId)) => diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala index f780799..5a4984f 100644 --- a/src/main/scala/scala/async/TransformUtils.scala +++ b/src/main/scala/scala/async/TransformUtils.scala @@ -25,12 +25,14 @@ private[async] final case class TransformUtils[C <: Context](c: C) { val stateMachine = newTermName(fresh("stateMachine")) val stateMachineT = stateMachine.toTypeName val apply = newTermName("apply") + val applyOrElse = newTermName("applyOrElse") val tr = newTermName("tr") val matchRes = "matchres" val ifRes = "ifres" val await = "await" val bindSuffix = "$bind" - def arg(i: Int) = "arg" + i + + def arg(i: Int) = "arg" + i def fresh(name: TermName): TermName = newTermName(fresh(name.toString)) @@ -64,7 +66,7 @@ private[async] final case class TransformUtils[C <: Context](c: C) { /** Descends into the regions of the tree that are subject to the * translation to a state machine by `async`. When a nested template, - * function, or by-name argument is encountered, the descend stops, + * function, or by-name argument is encountered, the descent stops, * and `nestedClass` etc are invoked. */ trait AsyncTraverser extends Traverser { @@ -83,20 +85,24 @@ private[async] final case class TransformUtils[C <: Context](c: C) { def function(function: Function) { } + def patMatFunction(tree: Match) { + } + override def traverse(tree: Tree) { tree match { - case cd: ClassDef => nestedClass(cd) - case md: ModuleDef => nestedModule(md) - case dd: DefDef => nestedMethod(dd) - case fun: Function => function(fun) - case Apply(fun, args) => + case cd: ClassDef => nestedClass(cd) + case md: ModuleDef => nestedModule(md) + case dd: DefDef => nestedMethod(dd) + case fun: Function => function(fun) + case m@Match(EmptyTree, _) => patMatFunction(m) // Pattern matching anonymous function under -Xoldpatmat of after `restorePatternMatchingFunctions` + case Apply(fun, args) => val isInByName = isByName(fun) for ((arg, index) <- args.zipWithIndex) { if (!isInByName(index)) traverse(arg) else byNameArgument(arg) } traverse(fun) - case _ => super.traverse(tree) + case _ => super.traverse(tree) } } } @@ -245,6 +251,40 @@ private[async] final case class TransformUtils[C <: Context](c: C) { } } + /** + * Replaces expressions of the form `{ new $anon extends PartialFunction[A, B] { ... ; def applyOrElse[..](...) = ... match <cases> }` + * with `Match(EmptyTree, cases`. + * + * This reverses the transformation performed in `Typers`, and works around non-idempotency of typechecking such trees. + */ + // TODO Reference JIRA issue. + final def restorePatternMatchingFunctions(tree: Tree) = + RestorePatternMatchingFunctions transform tree + + private object RestorePatternMatchingFunctions extends Transformer { + + import language.existentials + + override def transform(tree: Tree): Tree = { + val SYNTHETIC = (1 << 21).toLong.asInstanceOf[FlagSet] + def isSynthetic(cd: ClassDef) = cd.mods hasFlag SYNTHETIC + + tree match { + case Block( + (cd@ClassDef(_, _, _, Template(_, _, body))) :: Nil, + Apply(Select(New(a), nme.CONSTRUCTOR), Nil)) if isSynthetic(cd) => + val restored = (body collectFirst { + case DefDef(_, /*name.apply | */ name.applyOrElse, _, _, _, Match(_, cases)) => + val transformedCases = super.transformStats(cases, currentOwner).asInstanceOf[List[CaseDef]] + Match(EmptyTree, transformedCases) + }).getOrElse(c.abort(tree.pos, s"Internal Error: Unable to find original pattern matching cases in: $body")) + restored + case t => super.transform(t) + } + } + } + + def isSafeToInline(tree: Tree) = { val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable] object treeInfo extends { diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala index b22faa9..93cfdf5 100644 --- a/src/test/scala/scala/async/TreeInterrogation.scala +++ b/src/test/scala/scala/async/TreeInterrogation.scala @@ -62,23 +62,21 @@ object TreeInterrogation extends App { val levels = Seq("trace", "debug") def setAll(value: Boolean) = levels.foreach(set(_, value)) - setAll(true) - try t finally setAll(false) + setAll(value = true) + try t finally setAll(value = false) } withDebug { val cm = reflect.runtime.currentMirror - val tb = mkToolbox("-cp target/scala-2.10/classes -Xprint:all") + val tb = mkToolbox("-cp target/scala-2.10/classes -Xprint:flatten") val tree = tb.parse( """ import scala.async.AsyncId.{async, await} - | def foo(a: Int, b: Int) = (a, b) - | val result = async { - | var i = 0 - | def next() = { - | i += 1; - | i - | } - | foo(next(), await(next())) + | async { + | await(1) + | val neg1 = -1 + | val a = await(1) + | val f = { case x => ({case x => neg1 * x}: PartialFunction[Int, Int])(x + a) }: PartialFunction[Int, Int] + | await(f(2)) | } | () | """.stripMargin) diff --git a/src/test/scala/scala/async/neg/NakedAwait.scala b/src/test/scala/scala/async/neg/NakedAwait.scala index ecc84f9..c3537ec 100644 --- a/src/test/scala/scala/async/neg/NakedAwait.scala +++ b/src/test/scala/scala/async/neg/NakedAwait.scala @@ -93,6 +93,16 @@ class NakedAwait { } @Test + def nestedPatMatFunction() { + expectError("await must not be used under a nested class.") { // TODO more specific error message + """ + | import _root_.scala.async.AsyncId._ + | async { { case x => { await(false) } } : PartialFunction[Any, Any] } + """.stripMargin + } + } + + @Test def tryBody() { expectError("await must not be used under a try/catch.") { """ @@ -143,4 +153,16 @@ class NakedAwait { |""".stripMargin } } + + @Test + def lazyValIllegal() { + expectError("lazy vals are illegal") { + """ + | import _root_.scala.async.AsyncId._ + | def foo(): Any = async { val x = { lazy val y = 0; y } } + | () + | + |""".stripMargin + } + } } diff --git a/src/test/scala/scala/async/run/toughtype/ToughType.scala b/src/test/scala/scala/async/run/toughtype/ToughType.scala index 9cfc1ca..83f5a2d 100644 --- a/src/test/scala/scala/async/run/toughtype/ToughType.scala +++ b/src/test/scala/scala/async/run/toughtype/ToughType.scala @@ -36,4 +36,35 @@ class ToughTypeSpec { val res: (List[_], scala.async.run.toughtype.ToughTypeObject.Inner) = Await.result(fut, 2 seconds) res._1 mustBe (Nil) } + + @Test def patternMatchingPartialFunction() { + import AsyncId.{await, async} + async { + await(1) + val a = await(1) + val f = { case x => x + a }: PartialFunction[Int, Int] + await(f(2)) + } mustBe 3 + } + + @Test def patternMatchingPartialFunctionNested() { + import AsyncId.{await, async} + async { + await(1) + val neg1 = -1 + val a = await(1) + val f = { case x => ({case x => neg1 * x}: PartialFunction[Int, Int])(x + a) }: PartialFunction[Int, Int] + await(f(2)) + } mustBe -3 + } + + @Test def patternMatchingFunction() { + import AsyncId.{await, async} + async { + await(1) + val a = await(1) + val f = { case x => x + a }: Function[Int, Int] + await(f(2)) + } mustBe 3 + } } |