diff options
author | George Leontiev <folone@gmail.com> | 2013-08-19 21:57:50 +0200 |
---|---|---|
committer | George Leontiev <folone@gmail.com> | 2013-08-24 09:57:35 +0200 |
commit | fdc543760d619e287f980d5cadb183993ff6004c (patch) | |
tree | 0e0624bef4521aa3645fe8f165b5876e73b034b3 | |
parent | 168d27020b4b2b3c78365137e99ca1e85c90b01a (diff) | |
download | scala-fdc543760d619e287f980d5cadb183993ff6004c.tar.gz scala-fdc543760d619e287f980d5cadb183993ff6004c.tar.bz2 scala-fdc543760d619e287f980d5cadb183993ff6004c.zip |
Alter TailRec to have map and flatMap
As described in the "Stackless Scala with Free Monads" paper
scala> import scala.util.control.TailCalls._
import scala.util.control.TailCalls._
scala> :paste
// Entering paste mode (ctrl-D to finish)
def isEven(xs: List[Int]): TailRec[Boolean] =
if (xs.isEmpty) done(true) else tailcall(isOdd(xs.tail))
def isOdd(xs: List[Int]): TailRec[Boolean] =
if (xs.isEmpty) done(false) else tailcall(isEven(xs.tail))
// Exiting paste mode, now interpreting.
isEven: (xs: List[Int])util.control.TailCalls.TailRec[Boolean]
isOdd: (xs: List[Int])util.control.TailCalls.TailRec[Boolean]
scala> isEven((1 to 100000).toList).result
res0: Boolean = true
scala> def fib(n: Int): TailRec[Int] =
| if (n < 2) done(n) else for {
| x <- tailcall(fib(n - 1))
| y <- tailcall(fib(n - 2))
| } yield (x + y)
fib: (n: Int)util.control.TailCalls.TailRec[Int]
scala> fib(40).result
res1: Int = 102334155
-rw-r--r-- | src/library/scala/util/control/TailCalls.scala | 49 | ||||
-rw-r--r-- | test/files/run/tailcalls.scala | 13 |
2 files changed, 55 insertions, 7 deletions
diff --git a/src/library/scala/util/control/TailCalls.scala b/src/library/scala/util/control/TailCalls.scala index ba3044c718..3c87b95a53 100644 --- a/src/library/scala/util/control/TailCalls.scala +++ b/src/library/scala/util/control/TailCalls.scala @@ -9,11 +9,17 @@ package scala package util.control +import collection.mutable.ArrayStack + /** Methods exported by this object implement tail calls via trampolining. * Tail calling methods have to return their result using `done` or call the * next method using `tailcall`. Both return a `TailRec` object. The result * of evaluating a tailcalling function can be retrieved from a `Tailrec` - * value using method `result`. Here's a usage example: + * value using method `result`. + * Implemented as described in "Stackless Scala with Free Monads" + * http://blog.higher-order.com/assets/trampolines.pdf + * + * Here's a usage example: * {{{ * import scala.util.control.TailCalls._ * @@ -24,6 +30,14 @@ package util.control * if (xs.isEmpty) done(false) else tailcall(isEven(xs.tail)) * * isEven((1 to 100000).toList).result + * + * def fib(n: Int): TailRec[Int] = + * if (n < 2) done(n) else for { + * x <- tailcall(fib(n - 1)) + * y <- tailcall(fib(n - 2)) + * } yield (x + y) + * + * fib(40).result * }}} */ object TailCalls { @@ -31,14 +45,31 @@ object TailCalls { /** This class represents a tailcalling computation */ abstract class TailRec[+A] { + def map[B](f: A => B): TailRec[B] = + flatMap(a => Call(() => Done(f(a)))) + def flatMap[B](f: A => TailRec[B]): TailRec[B] = + Cont(this, f) /** Returns the result of the tailcalling computation. */ def result: A = { - def loop(body: TailRec[A]): A = body match { - case Call(rest) => loop(rest()) - case Done(result) => result + var cur: TailRec[_] = this + val stack: ArrayStack[Any => TailRec[A]] = new ArrayStack() + var result: A = null.asInstanceOf[A] + while (result == null) { + cur match { + case Done(a) => + if(!stack.isEmpty) { + val fun = stack.pop + cur = fun(a) + } else result = a.asInstanceOf[A] + case Call(t) => cur = t() + case Cont(a, f) => { + cur = a + stack.push(f.asInstanceOf[Any => TailRec[A]]) + } + } } - loop(this) + result } } @@ -49,17 +80,21 @@ object TailCalls { * computation */ protected case class Done[A](override val result: A) extends TailRec[A] + /** Internal class representing a continuation with function A => TailRec[B]. + * It is needed for the flatMap to be implemented. */ + protected case class Cont[A, B](a: TailRec[A], f: A => TailRec[B]) extends TailRec[B] + /** Performs a tailcall * @param rest the expression to be evaluated in the tailcall * @return a `TailRec` object representing the expression `rest` */ - def tailcall[A](rest: => TailRec[A]): TailRec[A] = new Call(() => rest) + def tailcall[A](rest: => TailRec[A]): TailRec[A] = Call(() => rest) /** Used to return final result from tailcalling computation * @param `result` the result value * @return a `TailRec` object representing a computation which immediately * returns `result` */ - def done[A](result: A): TailRec[A] = new Done(result) + def done[A](result: A): TailRec[A] = Done(result) } diff --git a/test/files/run/tailcalls.scala b/test/files/run/tailcalls.scala index 1d4124e138..e5d8891cc7 100644 --- a/test/files/run/tailcalls.scala +++ b/test/files/run/tailcalls.scala @@ -391,7 +391,20 @@ object Test { def isOdd(xs: List[Int]): TailRec[Boolean] = if (xs.isEmpty) done(false) else tailcall(isEven(xs.tail)) + def fib(n: Int): TailRec[Int] = + if (n < 2) done(n) else for { + x <- tailcall(fib(n - 1)) + y <- tailcall(fib(n - 2)) + } yield (x + y) + + def rec(n: Int): TailRec[Int] = + if (n == 1) done(n) else for { + x <- tailcall(rec(n - 1)) + } yield x + assert(isEven((1 to 100000).toList).result) + //assert(fib(40).result == 102334155) // Commented out, as it takes a long time + assert(rec(100000).result == 1) } |