summaryrefslogtreecommitdiff
path: root/src/library
diff options
context:
space:
mode:
authorGrzegorz Kossakowski <grzegorz.kossakowski@gmail.com>2013-09-07 21:25:54 -0700
committerGrzegorz Kossakowski <grzegorz.kossakowski@gmail.com>2013-09-07 21:25:54 -0700
commit45a7da1426586e9fe2abe19b41a296e76ef1d78a (patch)
treee981446eb8e2ccabb2cae315972e773802a3664d /src/library
parentc55e32d71d6c74c9280fb19c336969412c9857ec (diff)
parent4ddff662995a0abaf5b1eb872c1233b60e9eb3c0 (diff)
downloadscala-45a7da1426586e9fe2abe19b41a296e76ef1d78a.tar.gz
scala-45a7da1426586e9fe2abe19b41a296e76ef1d78a.tar.bz2
scala-45a7da1426586e9fe2abe19b41a296e76ef1d78a.zip
Merge pull request #2865 from folone/trampolines
Alter TailRec to have map and flatMap
Diffstat (limited to 'src/library')
-rw-r--r--src/library/scala/util/control/TailCalls.scala63
1 files changed, 54 insertions, 9 deletions
diff --git a/src/library/scala/util/control/TailCalls.scala b/src/library/scala/util/control/TailCalls.scala
index ba3044c718..c3e7d98073 100644
--- a/src/library/scala/util/control/TailCalls.scala
+++ b/src/library/scala/util/control/TailCalls.scala
@@ -13,7 +13,11 @@ package util.control
* Tail calling methods have to return their result using `done` or call the
* next method using `tailcall`. Both return a `TailRec` object. The result
* of evaluating a tailcalling function can be retrieved from a `Tailrec`
- * value using method `result`. Here's a usage example:
+ * value using method `result`.
+ * Implemented as described in "Stackless Scala with Free Monads"
+ * http://blog.higher-order.com/assets/trampolines.pdf
+ *
+ * Here's a usage example:
* {{{
* import scala.util.control.TailCalls._
*
@@ -24,6 +28,14 @@ package util.control
* if (xs.isEmpty) done(false) else tailcall(isEven(xs.tail))
*
* isEven((1 to 100000).toList).result
+ *
+ * def fib(n: Int): TailRec[Int] =
+ * if (n < 2) done(n) else for {
+ * x <- tailcall(fib(n - 1))
+ * y <- tailcall(fib(n - 2))
+ * } yield (x + y)
+ *
+ * fib(40).result
* }}}
*/
object TailCalls {
@@ -31,14 +43,43 @@ object TailCalls {
/** This class represents a tailcalling computation
*/
abstract class TailRec[+A] {
+
+ /** Continue the computation with `f`. */
+ final def map[B](f: A => B): TailRec[B] =
+ flatMap(a => Call(() => Done(f(a))))
+
+ /** Continue the computation with `f` and merge the trampolining
+ * of this computation with that of `f`. */
+ final def flatMap[B](f: A => TailRec[B]): TailRec[B] =
+ this match {
+ case Done(a) => Call(() => f(a))
+ case c@Call(_) => Cont(c, f)
+ // Take advantage of the monad associative law to optimize the size of the required stack
+ case Cont(s, g) => Cont(s, (x:Any) => g(x).flatMap(f))
+ }
+
+ /** Returns either the next step of the tailcalling computation,
+ * or the result if there are no more steps. */
+ @annotation.tailrec final def resume: Either[() => TailRec[A], A] = this match {
+ case Done(a) => Right(a)
+ case Call(k) => Left(k)
+ case Cont(a, f) => a match {
+ case Done(v) => f(v).resume
+ case Call(k) => Left(() => k().flatMap(f))
+ case Cont(b, g) => b.flatMap(x => g(x) flatMap f).resume
+ }
+ }
+
/** Returns the result of the tailcalling computation.
*/
- def result: A = {
- def loop(body: TailRec[A]): A = body match {
- case Call(rest) => loop(rest())
- case Done(result) => result
+ @annotation.tailrec final def result: A = this match {
+ case Done(a) => a
+ case Call(t) => t().result
+ case Cont(a, f) => a match {
+ case Done(v) => f(v).result
+ case Call(t) => t().flatMap(f).result
+ case Cont(b, g) => b.flatMap(x => g(x) flatMap f).result
}
- loop(this)
}
}
@@ -47,19 +88,23 @@ object TailCalls {
/** Internal class representing the final result returned from a tailcalling
* computation */
- protected case class Done[A](override val result: A) extends TailRec[A]
+ protected case class Done[A](value: A) extends TailRec[A]
+
+ /** Internal class representing a continuation with function A => TailRec[B].
+ * It is needed for the flatMap to be implemented. */
+ protected case class Cont[A, B](a: TailRec[A], f: A => TailRec[B]) extends TailRec[B]
/** Performs a tailcall
* @param rest the expression to be evaluated in the tailcall
* @return a `TailRec` object representing the expression `rest`
*/
- def tailcall[A](rest: => TailRec[A]): TailRec[A] = new Call(() => rest)
+ def tailcall[A](rest: => TailRec[A]): TailRec[A] = Call(() => rest)
/** Used to return final result from tailcalling computation
* @param `result` the result value
* @return a `TailRec` object representing a computation which immediately
* returns `result`
*/
- def done[A](result: A): TailRec[A] = new Done(result)
+ def done[A](result: A): TailRec[A] = Done(result)
}