From 4ddff662995a0abaf5b1eb872c1233b60e9eb3c0 Mon Sep 17 00:00:00 2001 From: Runar Bjarnason Date: Fri, 23 Aug 2013 16:23:56 -0400 Subject: Stackless implementation of TailRec in constant memory. --- src/library/scala/util/control/TailCalls.scala | 58 +++++++++++++++----------- 1 file changed, 34 insertions(+), 24 deletions(-) diff --git a/src/library/scala/util/control/TailCalls.scala b/src/library/scala/util/control/TailCalls.scala index 3c87b95a53..c3e7d98073 100644 --- a/src/library/scala/util/control/TailCalls.scala +++ b/src/library/scala/util/control/TailCalls.scala @@ -9,8 +9,6 @@ package scala package util.control -import collection.mutable.ArrayStack - /** Methods exported by this object implement tail calls via trampolining. * Tail calling methods have to return their result using `done` or call the * next method using `tailcall`. Both return a `TailRec` object. The result @@ -45,31 +43,43 @@ object TailCalls { /** This class represents a tailcalling computation */ abstract class TailRec[+A] { - def map[B](f: A => B): TailRec[B] = + + /** Continue the computation with `f`. */ + final def map[B](f: A => B): TailRec[B] = flatMap(a => Call(() => Done(f(a)))) - def flatMap[B](f: A => TailRec[B]): TailRec[B] = - Cont(this, f) + + /** Continue the computation with `f` and merge the trampolining + * of this computation with that of `f`. */ + final def flatMap[B](f: A => TailRec[B]): TailRec[B] = + this match { + case Done(a) => Call(() => f(a)) + case c@Call(_) => Cont(c, f) + // Take advantage of the monad associative law to optimize the size of the required stack + case Cont(s, g) => Cont(s, (x:Any) => g(x).flatMap(f)) + } + + /** Returns either the next step of the tailcalling computation, + * or the result if there are no more steps. */ + @annotation.tailrec final def resume: Either[() => TailRec[A], A] = this match { + case Done(a) => Right(a) + case Call(k) => Left(k) + case Cont(a, f) => a match { + case Done(v) => f(v).resume + case Call(k) => Left(() => k().flatMap(f)) + case Cont(b, g) => b.flatMap(x => g(x) flatMap f).resume + } + } + /** Returns the result of the tailcalling computation. */ - def result: A = { - var cur: TailRec[_] = this - val stack: ArrayStack[Any => TailRec[A]] = new ArrayStack() - var result: A = null.asInstanceOf[A] - while (result == null) { - cur match { - case Done(a) => - if(!stack.isEmpty) { - val fun = stack.pop - cur = fun(a) - } else result = a.asInstanceOf[A] - case Call(t) => cur = t() - case Cont(a, f) => { - cur = a - stack.push(f.asInstanceOf[Any => TailRec[A]]) - } - } + @annotation.tailrec final def result: A = this match { + case Done(a) => a + case Call(t) => t().result + case Cont(a, f) => a match { + case Done(v) => f(v).result + case Call(t) => t().flatMap(f).result + case Cont(b, g) => b.flatMap(x => g(x) flatMap f).result } - result } } @@ -78,7 +88,7 @@ object TailCalls { /** Internal class representing the final result returned from a tailcalling * computation */ - protected case class Done[A](override val result: A) extends TailRec[A] + protected case class Done[A](value: A) extends TailRec[A] /** Internal class representing a continuation with function A => TailRec[B]. * It is needed for the flatMap to be implemented. */ -- cgit v1.2.3