From fdc543760d619e287f980d5cadb183993ff6004c Mon Sep 17 00:00:00 2001 From: George Leontiev Date: Mon, 19 Aug 2013 21:57:50 +0200 Subject: Alter TailRec to have map and flatMap As described in the "Stackless Scala with Free Monads" paper scala> import scala.util.control.TailCalls._ import scala.util.control.TailCalls._ scala> :paste // Entering paste mode (ctrl-D to finish) def isEven(xs: List[Int]): TailRec[Boolean] = if (xs.isEmpty) done(true) else tailcall(isOdd(xs.tail)) def isOdd(xs: List[Int]): TailRec[Boolean] = if (xs.isEmpty) done(false) else tailcall(isEven(xs.tail)) // Exiting paste mode, now interpreting. isEven: (xs: List[Int])util.control.TailCalls.TailRec[Boolean] isOdd: (xs: List[Int])util.control.TailCalls.TailRec[Boolean] scala> isEven((1 to 100000).toList).result res0: Boolean = true scala> def fib(n: Int): TailRec[Int] = | if (n < 2) done(n) else for { | x <- tailcall(fib(n - 1)) | y <- tailcall(fib(n - 2)) | } yield (x + y) fib: (n: Int)util.control.TailCalls.TailRec[Int] scala> fib(40).result res1: Int = 102334155 --- src/library/scala/util/control/TailCalls.scala | 49 ++++++++++++++++++++++---- test/files/run/tailcalls.scala | 13 +++++++ 2 files changed, 55 insertions(+), 7 deletions(-) diff --git a/src/library/scala/util/control/TailCalls.scala b/src/library/scala/util/control/TailCalls.scala index ba3044c718..3c87b95a53 100644 --- a/src/library/scala/util/control/TailCalls.scala +++ b/src/library/scala/util/control/TailCalls.scala @@ -9,11 +9,17 @@ package scala package util.control +import collection.mutable.ArrayStack + /** Methods exported by this object implement tail calls via trampolining. * Tail calling methods have to return their result using `done` or call the * next method using `tailcall`. Both return a `TailRec` object. The result * of evaluating a tailcalling function can be retrieved from a `Tailrec` - * value using method `result`. Here's a usage example: + * value using method `result`. + * Implemented as described in "Stackless Scala with Free Monads" + * http://blog.higher-order.com/assets/trampolines.pdf + * + * Here's a usage example: * {{{ * import scala.util.control.TailCalls._ * @@ -24,6 +30,14 @@ package util.control * if (xs.isEmpty) done(false) else tailcall(isEven(xs.tail)) * * isEven((1 to 100000).toList).result + * + * def fib(n: Int): TailRec[Int] = + * if (n < 2) done(n) else for { + * x <- tailcall(fib(n - 1)) + * y <- tailcall(fib(n - 2)) + * } yield (x + y) + * + * fib(40).result * }}} */ object TailCalls { @@ -31,14 +45,31 @@ object TailCalls { /** This class represents a tailcalling computation */ abstract class TailRec[+A] { + def map[B](f: A => B): TailRec[B] = + flatMap(a => Call(() => Done(f(a)))) + def flatMap[B](f: A => TailRec[B]): TailRec[B] = + Cont(this, f) /** Returns the result of the tailcalling computation. */ def result: A = { - def loop(body: TailRec[A]): A = body match { - case Call(rest) => loop(rest()) - case Done(result) => result + var cur: TailRec[_] = this + val stack: ArrayStack[Any => TailRec[A]] = new ArrayStack() + var result: A = null.asInstanceOf[A] + while (result == null) { + cur match { + case Done(a) => + if(!stack.isEmpty) { + val fun = stack.pop + cur = fun(a) + } else result = a.asInstanceOf[A] + case Call(t) => cur = t() + case Cont(a, f) => { + cur = a + stack.push(f.asInstanceOf[Any => TailRec[A]]) + } + } } - loop(this) + result } } @@ -49,17 +80,21 @@ object TailCalls { * computation */ protected case class Done[A](override val result: A) extends TailRec[A] + /** Internal class representing a continuation with function A => TailRec[B]. + * It is needed for the flatMap to be implemented. */ + protected case class Cont[A, B](a: TailRec[A], f: A => TailRec[B]) extends TailRec[B] + /** Performs a tailcall * @param rest the expression to be evaluated in the tailcall * @return a `TailRec` object representing the expression `rest` */ - def tailcall[A](rest: => TailRec[A]): TailRec[A] = new Call(() => rest) + def tailcall[A](rest: => TailRec[A]): TailRec[A] = Call(() => rest) /** Used to return final result from tailcalling computation * @param `result` the result value * @return a `TailRec` object representing a computation which immediately * returns `result` */ - def done[A](result: A): TailRec[A] = new Done(result) + def done[A](result: A): TailRec[A] = Done(result) } diff --git a/test/files/run/tailcalls.scala b/test/files/run/tailcalls.scala index 1d4124e138..e5d8891cc7 100644 --- a/test/files/run/tailcalls.scala +++ b/test/files/run/tailcalls.scala @@ -391,7 +391,20 @@ object Test { def isOdd(xs: List[Int]): TailRec[Boolean] = if (xs.isEmpty) done(false) else tailcall(isEven(xs.tail)) + def fib(n: Int): TailRec[Int] = + if (n < 2) done(n) else for { + x <- tailcall(fib(n - 1)) + y <- tailcall(fib(n - 2)) + } yield (x + y) + + def rec(n: Int): TailRec[Int] = + if (n == 1) done(n) else for { + x <- tailcall(rec(n - 1)) + } yield x + assert(isEven((1 to 100000).toList).result) + //assert(fib(40).result == 102334155) // Commented out, as it takes a long time + assert(rec(100000).result == 1) } -- cgit v1.2.3 From 4ddff662995a0abaf5b1eb872c1233b60e9eb3c0 Mon Sep 17 00:00:00 2001 From: Runar Bjarnason Date: Fri, 23 Aug 2013 16:23:56 -0400 Subject: Stackless implementation of TailRec in constant memory. --- src/library/scala/util/control/TailCalls.scala | 58 +++++++++++++++----------- 1 file changed, 34 insertions(+), 24 deletions(-) diff --git a/src/library/scala/util/control/TailCalls.scala b/src/library/scala/util/control/TailCalls.scala index 3c87b95a53..c3e7d98073 100644 --- a/src/library/scala/util/control/TailCalls.scala +++ b/src/library/scala/util/control/TailCalls.scala @@ -9,8 +9,6 @@ package scala package util.control -import collection.mutable.ArrayStack - /** Methods exported by this object implement tail calls via trampolining. * Tail calling methods have to return their result using `done` or call the * next method using `tailcall`. Both return a `TailRec` object. The result @@ -45,31 +43,43 @@ object TailCalls { /** This class represents a tailcalling computation */ abstract class TailRec[+A] { - def map[B](f: A => B): TailRec[B] = + + /** Continue the computation with `f`. */ + final def map[B](f: A => B): TailRec[B] = flatMap(a => Call(() => Done(f(a)))) - def flatMap[B](f: A => TailRec[B]): TailRec[B] = - Cont(this, f) + + /** Continue the computation with `f` and merge the trampolining + * of this computation with that of `f`. */ + final def flatMap[B](f: A => TailRec[B]): TailRec[B] = + this match { + case Done(a) => Call(() => f(a)) + case c@Call(_) => Cont(c, f) + // Take advantage of the monad associative law to optimize the size of the required stack + case Cont(s, g) => Cont(s, (x:Any) => g(x).flatMap(f)) + } + + /** Returns either the next step of the tailcalling computation, + * or the result if there are no more steps. */ + @annotation.tailrec final def resume: Either[() => TailRec[A], A] = this match { + case Done(a) => Right(a) + case Call(k) => Left(k) + case Cont(a, f) => a match { + case Done(v) => f(v).resume + case Call(k) => Left(() => k().flatMap(f)) + case Cont(b, g) => b.flatMap(x => g(x) flatMap f).resume + } + } + /** Returns the result of the tailcalling computation. */ - def result: A = { - var cur: TailRec[_] = this - val stack: ArrayStack[Any => TailRec[A]] = new ArrayStack() - var result: A = null.asInstanceOf[A] - while (result == null) { - cur match { - case Done(a) => - if(!stack.isEmpty) { - val fun = stack.pop - cur = fun(a) - } else result = a.asInstanceOf[A] - case Call(t) => cur = t() - case Cont(a, f) => { - cur = a - stack.push(f.asInstanceOf[Any => TailRec[A]]) - } - } + @annotation.tailrec final def result: A = this match { + case Done(a) => a + case Call(t) => t().result + case Cont(a, f) => a match { + case Done(v) => f(v).result + case Call(t) => t().flatMap(f).result + case Cont(b, g) => b.flatMap(x => g(x) flatMap f).result } - result } } @@ -78,7 +88,7 @@ object TailCalls { /** Internal class representing the final result returned from a tailcalling * computation */ - protected case class Done[A](override val result: A) extends TailRec[A] + protected case class Done[A](value: A) extends TailRec[A] /** Internal class representing a continuation with function A => TailRec[B]. * It is needed for the flatMap to be implemented. */ -- cgit v1.2.3