summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGeorge Leontiev <folone@gmail.com>2013-08-19 21:57:50 +0200
committerGeorge Leontiev <folone@gmail.com>2013-08-24 09:57:35 +0200
commitfdc543760d619e287f980d5cadb183993ff6004c (patch)
tree0e0624bef4521aa3645fe8f165b5876e73b034b3
parent168d27020b4b2b3c78365137e99ca1e85c90b01a (diff)
downloadscala-fdc543760d619e287f980d5cadb183993ff6004c.tar.gz
scala-fdc543760d619e287f980d5cadb183993ff6004c.tar.bz2
scala-fdc543760d619e287f980d5cadb183993ff6004c.zip
Alter TailRec to have map and flatMap
As described in the "Stackless Scala with Free Monads" paper scala> import scala.util.control.TailCalls._ import scala.util.control.TailCalls._ scala> :paste // Entering paste mode (ctrl-D to finish) def isEven(xs: List[Int]): TailRec[Boolean] = if (xs.isEmpty) done(true) else tailcall(isOdd(xs.tail)) def isOdd(xs: List[Int]): TailRec[Boolean] = if (xs.isEmpty) done(false) else tailcall(isEven(xs.tail)) // Exiting paste mode, now interpreting. isEven: (xs: List[Int])util.control.TailCalls.TailRec[Boolean] isOdd: (xs: List[Int])util.control.TailCalls.TailRec[Boolean] scala> isEven((1 to 100000).toList).result res0: Boolean = true scala> def fib(n: Int): TailRec[Int] = | if (n < 2) done(n) else for { | x <- tailcall(fib(n - 1)) | y <- tailcall(fib(n - 2)) | } yield (x + y) fib: (n: Int)util.control.TailCalls.TailRec[Int] scala> fib(40).result res1: Int = 102334155
-rw-r--r--src/library/scala/util/control/TailCalls.scala49
-rw-r--r--test/files/run/tailcalls.scala13
2 files changed, 55 insertions, 7 deletions
diff --git a/src/library/scala/util/control/TailCalls.scala b/src/library/scala/util/control/TailCalls.scala
index ba3044c718..3c87b95a53 100644
--- a/src/library/scala/util/control/TailCalls.scala
+++ b/src/library/scala/util/control/TailCalls.scala
@@ -9,11 +9,17 @@
package scala
package util.control
+import collection.mutable.ArrayStack
+
/** Methods exported by this object implement tail calls via trampolining.
* Tail calling methods have to return their result using `done` or call the
* next method using `tailcall`. Both return a `TailRec` object. The result
* of evaluating a tailcalling function can be retrieved from a `Tailrec`
- * value using method `result`. Here's a usage example:
+ * value using method `result`.
+ * Implemented as described in "Stackless Scala with Free Monads"
+ * http://blog.higher-order.com/assets/trampolines.pdf
+ *
+ * Here's a usage example:
* {{{
* import scala.util.control.TailCalls._
*
@@ -24,6 +30,14 @@ package util.control
* if (xs.isEmpty) done(false) else tailcall(isEven(xs.tail))
*
* isEven((1 to 100000).toList).result
+ *
+ * def fib(n: Int): TailRec[Int] =
+ * if (n < 2) done(n) else for {
+ * x <- tailcall(fib(n - 1))
+ * y <- tailcall(fib(n - 2))
+ * } yield (x + y)
+ *
+ * fib(40).result
* }}}
*/
object TailCalls {
@@ -31,14 +45,31 @@ object TailCalls {
/** This class represents a tailcalling computation
*/
abstract class TailRec[+A] {
+ def map[B](f: A => B): TailRec[B] =
+ flatMap(a => Call(() => Done(f(a))))
+ def flatMap[B](f: A => TailRec[B]): TailRec[B] =
+ Cont(this, f)
/** Returns the result of the tailcalling computation.
*/
def result: A = {
- def loop(body: TailRec[A]): A = body match {
- case Call(rest) => loop(rest())
- case Done(result) => result
+ var cur: TailRec[_] = this
+ val stack: ArrayStack[Any => TailRec[A]] = new ArrayStack()
+ var result: A = null.asInstanceOf[A]
+ while (result == null) {
+ cur match {
+ case Done(a) =>
+ if(!stack.isEmpty) {
+ val fun = stack.pop
+ cur = fun(a)
+ } else result = a.asInstanceOf[A]
+ case Call(t) => cur = t()
+ case Cont(a, f) => {
+ cur = a
+ stack.push(f.asInstanceOf[Any => TailRec[A]])
+ }
+ }
}
- loop(this)
+ result
}
}
@@ -49,17 +80,21 @@ object TailCalls {
* computation */
protected case class Done[A](override val result: A) extends TailRec[A]
+ /** Internal class representing a continuation with function A => TailRec[B].
+ * It is needed for the flatMap to be implemented. */
+ protected case class Cont[A, B](a: TailRec[A], f: A => TailRec[B]) extends TailRec[B]
+
/** Performs a tailcall
* @param rest the expression to be evaluated in the tailcall
* @return a `TailRec` object representing the expression `rest`
*/
- def tailcall[A](rest: => TailRec[A]): TailRec[A] = new Call(() => rest)
+ def tailcall[A](rest: => TailRec[A]): TailRec[A] = Call(() => rest)
/** Used to return final result from tailcalling computation
* @param `result` the result value
* @return a `TailRec` object representing a computation which immediately
* returns `result`
*/
- def done[A](result: A): TailRec[A] = new Done(result)
+ def done[A](result: A): TailRec[A] = Done(result)
}
diff --git a/test/files/run/tailcalls.scala b/test/files/run/tailcalls.scala
index 1d4124e138..e5d8891cc7 100644
--- a/test/files/run/tailcalls.scala
+++ b/test/files/run/tailcalls.scala
@@ -391,7 +391,20 @@ object Test {
def isOdd(xs: List[Int]): TailRec[Boolean] =
if (xs.isEmpty) done(false) else tailcall(isEven(xs.tail))
+ def fib(n: Int): TailRec[Int] =
+ if (n < 2) done(n) else for {
+ x <- tailcall(fib(n - 1))
+ y <- tailcall(fib(n - 2))
+ } yield (x + y)
+
+ def rec(n: Int): TailRec[Int] =
+ if (n == 1) done(n) else for {
+ x <- tailcall(rec(n - 1))
+ } yield x
+
assert(isEven((1 to 100000).toList).result)
+ //assert(fib(40).result == 102334155) // Commented out, as it takes a long time
+ assert(rec(100000).result == 1)
}