Before we introduce continuations we need to build some infrastructure.
Below is a trampoline that operates on Iteration objects. An iteration is a computation that can either Yield a new value or be Done.

sealed trait Iteration[+R]
case class Yield[+R](result: R, next: () => Iteration[R]) extends Iteration[R]
case object Done extends Iteration[Nothing]

def trampoline[R](body: => Iteration[R]): Iterator[R] = {
  // Run one step, emit its result, and defer the rest of the steps lazily.
  def loop(thunk: () => Iteration[R]): Stream[R] = {
    thunk() match {
      case Yield(result, next) => Stream.cons(result, loop(next))
      case Done => Stream.empty
    }
  }
  loop(() => body).iterator
}

The trampoline uses an internal loop that turns the sequence of Iteration objects into a Stream. We then get an Iterator by calling iterator on the resulting Stream. Because a Stream is lazy, we don’t evaluate our next iteration until it is needed.
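
One way to see the laziness is to feed the trampoline a computation that never finishes. The sketch below repeats the definitions above so that it stands alone; `from` is a hypothetical helper, not part of the original code. Note that Stream is deprecated from Scala 2.13 onward (LazyList replaces it), so expect a warning there.

```scala
sealed trait Iteration[+R]
case class Yield[+R](result: R, next: () => Iteration[R]) extends Iteration[R]
case object Done extends Iteration[Nothing]

// Same trampoline as in the text.
def trampoline[R](body: => Iteration[R]): Iterator[R] = {
  def loop(thunk: () => Iteration[R]): Stream[R] =
    thunk() match {
      case Yield(result, next) => Stream.cons(result, loop(next))
      case Done => Stream.empty
    }
  loop(() => body).iterator
}

// Hypothetical helper: an Iteration that never reaches Done.
def from(n: Int): Iteration[Int] = Yield(n, () => from(n + 1))

// Only the steps that are actually demanded get run, so this terminates
// even though the underlying sequence is infinite.
val firstThree = trampoline(from(0)).take(3).toList
println(firstThree) // List(0, 1, 2)
```

If the trampoline evaluated every step eagerly, the call above would never return.
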
The trampoline can be used to build an iterator directly.

val itr1 = trampoline {
  Yield(1, () => Yield(2, () => Yield(3, () => Done)))
}
for (i <- itr1) { println(i) }

That’s pretty horrible to write, so let’s use delimited continuations to create our Iteration objects automatically. We use the shift and reset operators to break the computation up into Iterations, then use trampoline to turn the Iterations into an Iterator.
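
// These imports match the early continuations plugin. Released versions use
// scala.util.continuations._ and spell the two-parameter annotation @cpsParam.
import scala.continuations._
import scala.continuations.ControlContext.{shift,reset}

// Delimit the body with reset, end it with Done, and hand the steps to trampoline.
def iterator[R](body: => Unit @cps[Iteration[R],Iteration[R]]): Iterator[R] =
  trampoline {
    reset[Iteration[R],Iteration[R]] { body ; Done }
  }

// Capture the rest of the computation as k and package it as the next step.
def yld[R](result: R): Unit @cps[Iteration[R],Iteration[R]] =
  shift((k: Unit => Iteration[R]) => Yield(result, () => k(())))
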
Now we can rewrite our example.

val itr2 = iterator[Int] {
  yld(1)
  yld(2)
  yld(3)
}
for (i <- itr2) { println(i) }

Much better!
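
To get a feel for what shift is doing on our behalf, it can help to write the continuations out by hand. The following is a rough, plugin-free sketch rather than what the compiler plugin literally generates: `yldK` and `drain` are hypothetical names, and the continuation that shift would capture automatically is passed in explicitly as `k`.

```scala
import scala.annotation.tailrec

sealed trait Iteration[+R]
case class Yield[+R](result: R, next: () => Iteration[R]) extends Iteration[R]
case object Done extends Iteration[Nothing]

// Hypothetical stand-in for yld: k is "the rest of the block", which
// shift would normally capture for us.
def yldK[R](result: R)(k: Unit => Iteration[R]): Iteration[R] =
  Yield(result, () => k(()))

// Roughly what a block that yields 1, 2 and 3 and then ends with Done
// looks like once every continuation is an explicit function.
val steps: Iteration[Int] =
  yldK(1) { _ => yldK(2) { _ => yldK(3) { _ => Done } } }

// Hypothetical helper that walks the steps and collects the results.
@tailrec
def drain[R](it: Iteration[R], acc: List[R] = Nil): List[R] = it match {
  case Yield(result, next) => drain(next(), result :: acc)
  case Done => acc.reverse
}

println(drain(steps)) // List(1, 2, 3)
```

The nested functions have the same shape as the hand-written Yield chain from earlier; the continuations plugin just saves us from writing the nesting ourselves.
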
Now here’s an example from the C# reference page for yield that shows some more advanced usage.
The types can be a bit tricky to get used to, but it all works.

def power(number: Int, exponent: Int): Iterator[Int] = iterator[Int] {
  def loop(result: Int, counter: Int): Unit @cps[Iteration[Int],Iteration[Int]] = {
    if (counter < exponent) {
      // Assumed step, following the C# example: emit the current power before recursing.
      yld(result)
      loop(result * number, counter + 1)
    }
  }
  loop(number, 0)
}
for (i <- power(2, 8)) { println(i) }
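
For comparison, here is a sketch of the same computation with the Iteration objects built by hand, which makes the values easier to check without the continuations plugin. It assumes, as in the C# example, that each pass through the loop yields the current result; `powerSteps` and `drain` are hypothetical names introduced only for this illustration.

```scala
import scala.annotation.tailrec

sealed trait Iteration[+R]
case class Yield[+R](result: R, next: () => Iteration[R]) extends Iteration[R]
case object Done extends Iteration[Nothing]

// Hand-written equivalent of the loop: yield the current power, then
// continue with the next one until `exponent` values have been produced.
def powerSteps(number: Int, exponent: Int): Iteration[Int] = {
  def loop(result: Int, counter: Int): Iteration[Int] =
    if (counter < exponent) Yield(result, () => loop(result * number, counter + 1))
    else Done
  loop(number, 0)
}

// Hypothetical helper that walks the steps and collects the results.
@tailrec
def drain[R](it: Iteration[R], acc: List[R] = Nil): List[R] = it match {
  case Yield(result, next) => drain(next(), result :: acc)
  case Done => acc.reverse
}

println(drain(powerSteps(2, 8))) // List(2, 4, 8, 16, 32, 64, 128, 256)
```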