diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/library/scala/util/control/TailCalls.scala | 63 |
1 files changed, 54 insertions, 9 deletions
diff --git a/src/library/scala/util/control/TailCalls.scala b/src/library/scala/util/control/TailCalls.scala index ba3044c718..c3e7d98073 100644 --- a/src/library/scala/util/control/TailCalls.scala +++ b/src/library/scala/util/control/TailCalls.scala @@ -13,7 +13,11 @@ package util.control * Tail calling methods have to return their result using `done` or call the * next method using `tailcall`. Both return a `TailRec` object. The result * of evaluating a tailcalling function can be retrieved from a `Tailrec` - * value using method `result`. Here's a usage example: + * value using method `result`. + * Implemented as described in "Stackless Scala with Free Monads" + * http://blog.higher-order.com/assets/trampolines.pdf + * + * Here's a usage example: * {{{ * import scala.util.control.TailCalls._ * @@ -24,6 +28,14 @@ package util.control * if (xs.isEmpty) done(false) else tailcall(isEven(xs.tail)) * * isEven((1 to 100000).toList).result + * + * def fib(n: Int): TailRec[Int] = + * if (n < 2) done(n) else for { + * x <- tailcall(fib(n - 1)) + * y <- tailcall(fib(n - 2)) + * } yield (x + y) + * + * fib(40).result * }}} */ object TailCalls { @@ -31,14 +43,43 @@ object TailCalls { /** This class represents a tailcalling computation */ abstract class TailRec[+A] { + + /** Continue the computation with `f`. */ + final def map[B](f: A => B): TailRec[B] = + flatMap(a => Call(() => Done(f(a)))) + + /** Continue the computation with `f` and merge the trampolining + * of this computation with that of `f`. */ + final def flatMap[B](f: A => TailRec[B]): TailRec[B] = + this match { + case Done(a) => Call(() => f(a)) + case c@Call(_) => Cont(c, f) + // Take advantage of the monad associative law to optimize the size of the required stack + case Cont(s, g) => Cont(s, (x:Any) => g(x).flatMap(f)) + } + + /** Returns either the next step of the tailcalling computation, + * or the result if there are no more steps. */ + @annotation.tailrec final def resume: Either[() => TailRec[A], A] = this match { + case Done(a) => Right(a) + case Call(k) => Left(k) + case Cont(a, f) => a match { + case Done(v) => f(v).resume + case Call(k) => Left(() => k().flatMap(f)) + case Cont(b, g) => b.flatMap(x => g(x) flatMap f).resume + } + } + /** Returns the result of the tailcalling computation. */ - def result: A = { - def loop(body: TailRec[A]): A = body match { - case Call(rest) => loop(rest()) - case Done(result) => result + @annotation.tailrec final def result: A = this match { + case Done(a) => a + case Call(t) => t().result + case Cont(a, f) => a match { + case Done(v) => f(v).result + case Call(t) => t().flatMap(f).result + case Cont(b, g) => b.flatMap(x => g(x) flatMap f).result } - loop(this) } } @@ -47,19 +88,23 @@ object TailCalls { /** Internal class representing the final result returned from a tailcalling * computation */ - protected case class Done[A](override val result: A) extends TailRec[A] + protected case class Done[A](value: A) extends TailRec[A] + + /** Internal class representing a continuation with function A => TailRec[B]. + * It is needed for the flatMap to be implemented. */ + protected case class Cont[A, B](a: TailRec[A], f: A => TailRec[B]) extends TailRec[B] /** Performs a tailcall * @param rest the expression to be evaluated in the tailcall * @return a `TailRec` object representing the expression `rest` */ - def tailcall[A](rest: => TailRec[A]): TailRec[A] = new Call(() => rest) + def tailcall[A](rest: => TailRec[A]): TailRec[A] = Call(() => rest) /** Used to return final result from tailcalling computation * @param `result` the result value * @return a `TailRec` object representing a computation which immediately * returns `result` */ - def done[A](result: A): TailRec[A] = new Done(result) + def done[A](result: A): TailRec[A] = Done(result) } |