Tuesday, August 18, 2009

Scala Tail Recursion

If you are into Scala, you have probably heard about tail recursion, and especially about tail call optimization. This post sheds some light on what tail recursion is and on the Scala compiler's ability to optimize tail recursive functions.

Tail recursive functions
In Scala it is common to write in a functional style, and recursion is a technique frequently used in functional programming. A recursive function is a function that calls itself (directly or indirectly). Let's take for example the following factorial function:

def rfak(n: Int): Int = {
  if (n == 1) n
  else n * rfak(n - 1)
}

This function calls itself recursively until its parameter n is down to one. The recursive calculation runs like this: rfak(3) = 3*(rfak(2)) = 3*(2*rfak(1)) = 3*(2*(1)) = 6.

Now, the function as shown above is a recursive function, but it is not tail recursive. In a tail recursive function the recursive call is the last action in the execution of the function's body. In the expression n * rfak(n-1), though, first the recursive call is made, and then (as the last action) the result of that call is multiplied by n. So the function is not tail recursive, because the recursive call is not the last action. But we can transform the implementation above into a tail recursive function:

def ifak(n: Int, acc: Int): Int = {
  if (n == 1) acc
  else ifak(n - 1, n * acc)
}

We introduced a second parameter acc. n is like before, but acc (the accumulator) holds intermediate results of the calculation. In this case, in the else part, the recursive call is the last action in the execution of the body, so the function is tail recursive. The calculation process goes like this: ifak(3,1) = ifak(2,3) = ifak(1,6) = 6.
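Callers of ifak have to remember to pass 1 as the initial accumulator. A common way to hide that detail is a wrapper with a local helper function. The following is only a minimal sketch and not part of the original example: the names FakWrapper, fak and loop are hypothetical, BigInt is an assumption made here to avoid Int overflow for larger n, and the @tailrec annotation (scala.annotation.tailrec, available since Scala 2.8) asks the compiler to reject the helper if it is not tail recursive.

```scala
import scala.annotation.tailrec

// Hypothetical wrapper object; not part of the original post.
object FakWrapper {
  // Public entry point: callers pass only n, the accumulator starts at 1.
  def fak(n: Int): BigInt = {
    require(n >= 0, "n must be non-negative")

    // Local helper carrying the accumulator. @tailrec makes compilation
    // fail if the recursive call is not in tail position.
    @tailrec
    def loop(i: Int, acc: BigInt): BigInt =
      if (i <= 1) acc
      else loop(i - 1, acc * BigInt(i))

    loop(n, BigInt(1))
  }
}
```

With this sketch, FakWrapper.fak(3) goes through the same steps as ifak(3,1) above: loop(3,1), loop(2,3), loop(1,6).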

The benefit of tail recursion
So, why would one care if a recursive function is even tail recursive?

If you look again at how the calculation proceeds, you will notice that in the non tail recursive version the process has to "remember" intermediate results: every pending multiplication waits for the result of a deeper call, so the recursive calls have to be "stacked" until the last one has been made. Only then can the final result be calculated: rfak(3) = 3*(rfak(2)) = 3*(2*rfak(1)) = 3*(2*(1)) = 6.

This is not the case with tail recursive functions, as the second calculation shows: ifak(3,1) = ifak(2,3) = ifak(1,6) = 6. There are no pending intermediate results (everything needed is carried in the accumulator argument), so the recursive calls do not have to be stacked. This is why the process of a tail recursive function is also called iterative. As a consequence, tail recursive functions can be far less expensive in terms of memory consumption, at least in theory.

Tail recursive functions in Java
Unfortunately, the JVM does not optimize tail calls itself. For example, if we had implemented the tail recursive factorial function in Java,

public static int ifak(int n, int acc) {
    if (n == 1) return acc;
    else return ifak(n - 1, acc * n);
}

then each recursive call builds a new stack frame, which slows down the calculation and eventually leads to a StackOverflowError for sufficiently large values of n. You can observe this behavior in a debugger, or by interrupting the calculation, for example by throwing an exception:

public static int ifak(int n, int acc) {
    if (n == 1) return acc;
    if (n == 10) throw new RuntimeException();
    else return ifak(n - 1, acc * n);
}

The stack trace then looks like this, with one frame per recursive call, showing that the function is invoked again and again:

Exception in thread "main" java.lang.RuntimeException
at FakTestJava.ifak(FakTestJava.java:32)
at FakTestJava.ifak(FakTestJava.java:33)
at FakTestJava.ifak(FakTestJava.java:33)
at FakTestJava.ifak(FakTestJava.java:33)
at FakTestJava.ifak(FakTestJava.java:33)
at FakTestJava.ifak(FakTestJava.java:33)
at FakTestJava.ifak(FakTestJava.java:33)
at FakTestJava.ifak(FakTestJava.java:33)
at FakTestJava.ifak(FakTestJava.java:33)
at FakTestJava.ifak(FakTestJava.java:33)
at FakTestJava.ifak(FakTestJava.java:33)
at FakTestJava.ifak(FakTestJava.java:27)
at FakTestJava.main(FakTestJava.java:15)

The disassembled bytecode instructions generated by javac prove this:

public static int ifak(int, int);
Code:
0: iload_0
1: iconst_1
2: if_icmpne 7
5: iload_1
6: ireturn
7: iload_0
8: iconst_1
9: isub
10: iload_1
11: iload_0
12: imul
13: invokestatic #6; //Method ifak:(II)I
16: ireturn

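The recursive call (invokestatic) is the second-to-last instruction. It is a real method invocation, followed only by the ireturn that passes its result on.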

Tail call optimization in Scala
Now let's have a look at the bytecode the Scala compiler generates for the Scala version of the tail recursive function. Here's the code again:

def ifak(n: Int, acc: Int): Int = {
  if (n == 1) acc
  else ifak(n - 1, n * acc)
}

And here are the disassembled bytecode instructions:

public int ifak(int, int);
Code:
0: iload_1
1: iconst_1
2: if_icmpne 7
5: iload_2
6: ireturn
7: iload_1
8: iconst_1
9: isub
10: iload_1
11: iload_2
12: imul
13: istore_2
14: istore_1
15: goto 0

As you can see, there are no recursive calls anymore. Instead, the new argument values are stored back into the parameter slots (istore_2, istore_1), and the last instruction is a goto, which makes the process simply jump back to the beginning of the function. So in this case the execution really runs as an iterative process, which is far more efficient than recursively calling the function again and building a new stack frame each time.
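The practical effect of the goto can be made visible with a deep recursion. The following is a minimal sketch under a few assumptions: the names StackDepthDemo, sumRec, sumIter and overflows are hypothetical, a summation is used instead of the factorial so that the result fits into a Long, and @tailrec (scala.annotation.tailrec, Scala 2.8 or later) is added so that the compiler verifies the optimization actually applies.

```scala
import scala.annotation.tailrec

// Hypothetical demo object; not part of the original post.
object StackDepthDemo {
  // Not tail recursive: the addition happens after the recursive call
  // returns, so every call needs its own stack frame.
  def sumRec(n: Long): Long =
    if (n == 0) 0L else n + sumRec(n - 1)

  // Tail recursive: the compiler turns the self-call into a jump,
  // so the stack depth stays constant regardless of n.
  @tailrec
  def sumIter(n: Long, acc: Long): Long =
    if (n == 0) acc else sumIter(n - 1, acc + n)

  // Reports whether sumRec exhausts the stack for the given n.
  def overflows(n: Long): Boolean =
    try { sumRec(n); false }
    catch { case _: StackOverflowError => true }
}
```

StackDepthDemo.sumIter(1000000L, 0L) completes without growing the stack. Whether StackDepthDemo.overflows(1000000L) returns true depends on the configured thread stack size, although with typical JVM defaults it is likely to.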

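In fact, this is a feature of the Scala compiler called tail call optimization: it optimizes away the recursive call. This feature works only in simple cases like the one above, though, where a method calls itself directly in tail position and cannot be overridden (for example because it is final, private, or defined in an object). If the recursion is indirect, for example, Scala cannot optimize the tail calls, because of the limited JVM instruction set.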
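For indirect recursion, one known workaround is a trampoline: instead of calling the other function directly, each step returns a description of the next call, and a loop runs these steps one after another on the heap. The sketch below assumes a Scala version (2.8 or later) whose standard library ships scala.util.control.TailCalls. The names EvenOdd, isEven and isOdd are purely illustrative, and the functions assume n >= 0.

```scala
import scala.util.control.TailCalls._

// Hypothetical example of mutual (indirect) recursion; assumes n >= 0.
object EvenOdd {
  // done(...) wraps a final result; tailcall(...) defers the next step
  // instead of invoking it on the current stack frame.
  def isEven(n: Int): TailRec[Boolean] =
    if (n == 0) done(true) else tailcall(isOdd(n - 1))

  def isOdd(n: Int): TailRec[Boolean] =
    if (n == 0) done(false) else tailcall(isEven(n - 1))
}
```

Calling EvenOdd.isEven(100000).result runs the deferred steps in a loop, so the stack does not grow with n. The price is an allocation per step, so this is a trade-off rather than a free replacement for the compiler's goto.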


That's essentially it, but if you are interested, let's fiddle with it some more. We can do the same as in the Java example above to show that there are no additional recursive calls. For example, if you throw a runtime exception in the process like this:

def ifak(n: Int, acc: Int): Int = {
  if (n == 1) acc
  else if (n == 2) throw new RuntimeException()
  else ifak(n - 1, n * acc)
}

and call ifak(10,1), the stack trace will also tell you that no additional stack frames are built:

Exception in thread "main" java.lang.RuntimeException
at scalaapplication1.FakTest$.ifak(FakTest.scala:32)
at scalaapplication1.FakTest$.main(FakTest.scala:16)
at scalaapplication1.FakTest.main(FakTest.scala)

You can also turn off tail call optimization by passing the -g:notailcalls flag to the Scala compiler. If you run the example again, the stack trace looks like this:

java.lang.RuntimeException
at scalaapplication1.FakTest$.ifak(FakTest.scala:32)
at scalaapplication1.FakTest$.ifak(FakTest.scala:33)
at scalaapplication1.FakTest$.ifak(FakTest.scala:33)
at scalaapplication1.FakTest$.ifak(FakTest.scala:33)
at scalaapplication1.FakTest$.ifak(FakTest.scala:33)
at scalaapplication1.FakTest$.ifak(FakTest.scala:33)
at scalaapplication1.FakTest$.ifak(FakTest.scala:33)
at scalaapplication1.FakTest$.ifak(FakTest.scala:33)
at scalaapplication1.FakTest$.ifak(FakTest.scala:33)
at scalaapplication1.FakTest$.main(FakTest.scala:16)
at scalaapplication1.FakTest.main(FakTest.scala)

Finally, let's have a look at what the code looks like when we decompile the class file using a Java decompiler like jd (http://java.decompiler.free.fr/). Here is the code once again:

def ifak(n: Int, acc: Int): Int = {
  if (n == 1) acc
  else ifak(n - 1, n * acc)
}

And here is what jd decompiles from the class file:

public int ifak(int n, int acc) {
    while (true) {
        if (n == 1) return acc;
        acc = n * acc;
        n -= 1;
    }
}

The tail recursive calls are transformed into an equivalent while loop!

After all, this should not be too surprising.

Resources
[1] Tail calls in the VM



5 comments:

Henry Ho said...

Wow.. one would have thought Java JVM would do tail recursive optimization by now...

Nick Wiedenbrueck said...

Added a link to a proposal for tail call optimization on the JVM. See Resources section.

Leonel said...

Does it really make sense to expect tail call optimizations from the JVM ?

After all, the JVM, although virtual, is a *machine*. Tail call optimizations are usually handled at a higher level by the compiler, like the scala compiler does, not by a CPU.

It's not only more usual, it's also easier for the compiler to output optimized bytecode than for the JVM to make this optimization on the fly.

RD1 said...

@Leonel: The JVM bytecode is at a relatively high level with a sophisticated security model, so compilers can't do this "low-level" optimization. There are various workarounds that work well for certain kinds of things, but general tail calls would require changes to the JVM bytecode.

aunndroid said...

Nick,

Thanks for the blog. Here is my own recursion blog. http://blog.aunndroid.com/2011/11/learning-scala-recursion-and-tco-1.html