In this tutorial I would like to talk about an interesting concept in functional programming: tail recursion, and in particular tail recursion in Scala.
What is recursion in general? By definition, recursion is the process of defining an object in terms of itself. For instance, a mathematical expression can contain numeric operands, arithmetic functions, and other mathematical expressions.
Recursion
In programming we deal with recursive functions. To compute the value of such a function for one input, it is necessary to evaluate the same function for other inputs. In other words, it is a function that calls itself. The best-known examples of recursion are the factorial (as most programming books will tell you) and the Fibonacci numbers. Let's forget about the factorial and see how a Fibonacci implementation looks in Java:
public static int fib(int n) {
    if (n > 1) return fib(n - 1) + fib(n - 2);
    else return n;
}
As the example shows, before the function returns a value it calls itself twice, each time with a different argument.
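To make the cost of those two calls visible, here is a small sketch that counts invocations. The class name FibCallCounter, the calls counter, and the callsFor helper are illustrative additions and not part of the original example.

```java
final class FibCallCounter {
    // Number of times fib has been entered since the last reset (illustrative only).
    static long calls = 0;

    // Same logic as fib above, plus the counter increment.
    static int fib(int n) {
        calls++;
        if (n > 1) return fib(n - 1) + fib(n - 2);
        else return n;
    }

    // Resets the counter, evaluates fib(n), and reports how many calls it took.
    static long callsFor(int n) {
        calls = 0;
        fib(n);
        return calls;
    }

    public static void main(String[] args) {
        for (int n : new int[] {5, 10, 20}) {
            System.out.println("fib(" + n + ") needs " + callsFor(n) + " calls");
        }
    }
}
```

Every call with n > 1 spawns two more, so the number of calls grows roughly as fast as the Fibonacci numbers themselves.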
Tail recursion
Now, back to tail recursion. A recursive call is a tail call if it is the last operation the function performs before returning: its result is returned as is, with no further computation on it. Let's try to rewrite the previous example using tail recursion:
public static int fibWithTailRec(int n) {
    if (n > 1) return fibIter(1, 1, n - 2);
    else return n;
}
private static int fibIter(int prev, int current, int n) {
    if (n == 0) return current;
    else return fibIter(current, prev + current, n - 1);
}
The API of the class hasn't changed: we still have a method that returns the n-th Fibonacci number. We also have a new private method, fibIter, which contains the tail recursion. Tail recursion is very often implemented with such a helper method, which receives the state of the computation as its arguments before each next iteration. Because of this, a tail-recursive function can be rewritten as a loop:
public static int fibWithLoop(int n) {
    if (n <= 1) return n;
    int i = 2;
    int prev = 1;
    int current = 1;
    while (i < n) {
        int next = prev + current;
        prev = current;
        current = next;
        i += 1;
    }
    return current;
}
This code looks awful, and there are many places where you can make a mistake. That's why recursive solutions are often preferable: they are shorter and easier to read.
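Since the loop is easy to get wrong, one way to gain some confidence is to compare the three versions on small inputs. The sketch below copies the three methods from above into a hypothetical class FibAgreementCheck; the allAgree helper is an illustrative addition.

```java
final class FibAgreementCheck {
    static int fib(int n) {
        if (n > 1) return fib(n - 1) + fib(n - 2);
        else return n;
    }

    static int fibWithTailRec(int n) {
        if (n > 1) return fibIter(1, 1, n - 2);
        else return n;
    }

    private static int fibIter(int prev, int current, int n) {
        if (n == 0) return current;
        else return fibIter(current, prev + current, n - 1);
    }

    static int fibWithLoop(int n) {
        if (n <= 1) return n;
        int i = 2;
        int prev = 1;
        int current = 1;
        while (i < n) {
            int next = prev + current;
            prev = current;
            current = next;
            i += 1;
        }
        return current;
    }

    // Returns true if all three versions return the same value for every n in 0..maxN.
    // maxN should stay small because the naive fib is slow for large n.
    static boolean allAgree(int maxN) {
        for (int n = 0; n <= maxN; n++) {
            int expected = fib(n);
            if (fibWithTailRec(n) != expected || fibWithLoop(n) != expected) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        System.out.println("all versions agree up to n = 25: " + allAgree(25));
    }
}
```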
Let's put our functions under pressure and compute the 10000th element of the sequence. Unfortunately, both methods, fib and fibWithTailRec, will throw java.lang.StackOverflowError. This means that recursion in Java should be used carefully. And don't forget that you can always use a loop :)
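The effect can be reproduced in isolation with the sketch below. The class name DeepRecursionDemo, the overflows helper, and the depth of 50,000,000 are assumptions made for illustration: the depth is deliberately far larger than 10000 so that the outcome should not depend on the JVM's stack size settings. The int arithmetic silently wraps around for large n, so only the control flow is meaningful here, not the returned number.

```java
final class DeepRecursionDemo {
    // Same shape as fibIter above. Java keeps one stack frame per call,
    // even though the recursive call is in tail position.
    static int fibIter(int prev, int current, int n) {
        if (n == 0) return current;
        return fibIter(current, prev + current, n - 1);
    }

    // The same state transition written as a loop; it needs a single frame.
    static int fibIterLoop(int prev, int current, int n) {
        while (n != 0) {
            int next = prev + current;
            prev = current;
            current = next;
            n -= 1;
        }
        return current;
    }

    // Returns true if the recursive version runs out of stack at the given depth.
    static boolean overflows(int depth) {
        try {
            fibIter(1, 1, depth);
            return false;
        } catch (StackOverflowError e) {
            return true;
        }
    }

    public static void main(String[] args) {
        int depth = 50_000_000; // hypothetical depth, chosen only to force the overflow
        System.out.println("recursive version overflows: " + overflows(depth));
        // The value below has wrapped around many times; it only shows that the loop finishes.
        System.out.println("loop version returned: " + fibIterLoop(1, 1, depth));
    }
}
```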
Scala rescue ranger
What about recursion in other JVM languages? Let's see how Scala handles recursive functions. First, we will rewrite both recursive Fibonacci implementations in Scala, keeping them as close to the Java versions as possible:
def fib(n: Int): Int =
  if (n > 1) fib(n - 1) + fib(n - 2)
  else n
def fibWithTailRec(n: Int): Int =
  if (n > 1) fibIter(1, 1, n - 2)
  else n
// fibIter is deliberately not a nested function, for illustrative purposes.
private def fibIter(prev: Int, current: Int, n: Int): Int =
  if (n == 0) current
  else fibIter(current, prev + current, n - 1)
Now we will calculate the 10000th Fibonacci number once again, this time in Scala. It turns out that the tail-recursive version copes with the task successfully. What is the difference between almost "the same" tail-recursive code in Java and Scala? Let's look at the bytecode of these methods (javap -c) and check what the compilers actually do:
Java version:
public static int fibIter(int, int, int);
Code:
0: iload_2
1: ifne 6 // check the end condition
4: iload_1
5: ireturn
6: iload_1 // push current to stack
7: iload_0
8: iload_1
9: iadd // push prev + current to stack
10: iload_2
11: iconst_1
12: isub // push n - 1 to stack
13: invokestatic #2 // call method fibIter:(III)I
16: ireturn // return the result of call
Scala version:
public int fibIter(int, int, int);
Code:
0: iload_3
1: iconst_0
2: if_icmpne 7 // check the end condition
5: iload_2
6: ireturn
7: iload_2 // push current to stack
8: iload_1
9: iload_2
10: iadd // push prev + current to stack
11: iload_3
12: iconst_1
13: isub // push n - 1 to stack
14: istore_3
15: istore_2
16: istore_1 // save everything in the local variables
17: goto 0 // go to the top of the iteration
The Scala compiler recognized the tail recursion and transformed the recursive call into a loop, which, as we saw earlier, is not that easy to write by hand.
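Read back into source form, the Scala bytecode above corresponds roughly to the following hand-written Java. This is only a sketch of the transformation, not actual compiler output; the class name FibIterAsLoop and the next* variable names are illustrative.

```java
final class FibIterAsLoop {
    // Loop equivalent of the tail-recursive fibIter: the parameters are
    // overwritten with the arguments of the "next call" and control jumps back.
    static int fibIter(int prev, int current, int n) {
        while (true) {
            if (n == 0) return current;        // end condition, as at the top of the bytecode
            // All new values are computed from the old ones first
            // (the bytecode pushes them onto the operand stack) ...
            int nextPrev = current;
            int nextCurrent = prev + current;
            int nextN = n - 1;
            // ... and only then stored into the parameter slots (the istore instructions).
            n = nextN;
            current = nextCurrent;
            prev = nextPrev;
            // Falling through to the top of the loop plays the role of "goto 0".
        }
    }

    static int fibWithTailRec(int n) {
        if (n > 1) return fibIter(1, 1, n - 2);
        else return n;
    }

    public static void main(String[] args) {
        System.out.println(fibWithTailRec(10));
        // Completes without a stack overflow; the int value wraps around for large n.
        System.out.println(fibWithTailRec(10000));
    }
}
```

Computing the new values before storing any of them matters: overwriting prev first would corrupt the computation of prev + current.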
The world will never be the same again
So, we have seen the benefits of tail recursion in Scala. It's time to put this knowledge to use! In the world of functional programming, no program can exist without lists. Let's try to implement fold, one of the basic list functions (Scala offers us two versions, left and right):
def foldLeft[A, B](seq: Seq[A], z: B)(f: (B, A) => B): B =
  seq match {
    case Seq() => z
    // +: matches any non-empty Seq, whereas :: would match only a List
    case x +: xs => foldLeft(xs, f(z, x))(f)
  }
def foldRight[A, B](seq: Seq[A], z: B)(f: (A, B) => B): B =
  seq match {
    case Seq() => z
    // +: matches any non-empty Seq, whereas :: would match only a List
    case x +: xs => f(x, foldRight(xs, z)(f))
  }
As we can see, both functions are recursive. But what a pity: foldRight is not tail-recursive, because f is applied only after the recursive call returns, and therefore it can cause a stack overflow on long sequences. There is a quite simple solution, which requires a little extra work (reversing the sequence) but doesn't raise the asymptotic complexity of the algorithm. We will write the foldRight function via foldLeft:
def foldRightUsingFoldLeft[A, B](seq: Seq[A], z: B)(f: (A, B) => B): B =
  foldLeft(seq.reverse, z)((b, a) => f(a, b))
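For readers coming from Java, the same reverse-then-fold-left trick can be sketched as follows. The class FoldSketch and its method names are illustrative and are not a standard Java API. Here foldLeft is written as a loop, because Java does not eliminate tail calls, and foldRight is the plain recursive version, kept for comparison on small inputs.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.function.BiFunction;

final class FoldSketch {
    // Left fold as a loop: ((z op x0) op x1) op ...
    static <A, B> B foldLeft(List<A> xs, B z, BiFunction<B, A, B> f) {
        B acc = z;
        for (A x : xs) {
            acc = f.apply(acc, x);
        }
        return acc;
    }

    // Plain recursive right fold: x0 op (x1 op (... op z)). Not stack-safe for long lists.
    static <A, B> B foldRight(List<A> xs, B z, BiFunction<A, B, B> f) {
        return foldRightFrom(xs, 0, z, f);
    }

    private static <A, B> B foldRightFrom(List<A> xs, int i, B z, BiFunction<A, B, B> f) {
        if (i == xs.size()) return z;
        // f is applied after the recursive call returns, so this is not a tail call.
        return f.apply(xs.get(i), foldRightFrom(xs, i + 1, z, f));
    }

    // Right fold expressed through the left fold: reverse a copy, then flip the arguments of f.
    static <A, B> B foldRightUsingFoldLeft(List<A> xs, B z, BiFunction<A, B, B> f) {
        List<A> reversed = new ArrayList<>(xs);
        Collections.reverse(reversed);
        return foldLeft(reversed, z, (b, a) -> f.apply(a, b));
    }

    public static void main(String[] args) {
        List<Integer> xs = List.of(1, 2, 3);
        // Subtraction is not associative, so it exposes any mistake in argument order.
        int recursive = foldRight(xs, 0, (a, b) -> a - b);
        int viaFoldLeft = foldRightUsingFoldLeft(xs, 0, (a, b) -> a - b);
        System.out.println(recursive + " " + viaFoldLeft);
    }
}
```

Using a non-commutative operation such as subtraction in the check is a deliberate choice: with addition, a version that forgot to reverse the list or to flip the arguments would still appear to work.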
To check whether a function is optimized, the @tailrec annotation can be used. Some developers mistakenly believe that this annotation instructs the compiler to turn a method into a tail-recursive one, but this is not accurate (and not always possible): as the fibIter example showed, the compiler optimizes a tail call without being asked. Instead, the annotation verifies that the annotated function is tail-recursive, and if it is not, a compilation error is raised. Compilation fails in several cases, including:
- if the method can be overridden in subclasses because it is not a final method,
- if the method does not call itself,
- or if the recursive call is not in tail position.
By following these guidelines, developers can optimize their code and enjoy programming in Scala.