Why You Should Use Tail Recursion in Scala

In this tutorial I would like to talk about an interesting functional programming concept, tail recursion, and in particular about tail recursion in Scala.

What is recursion in general? By definition, recursion is the process of defining an object in terms of itself. For instance, a mathematical expression can contain numeric operands, arithmetic operators and other mathematical expressions.
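As a small illustration of a definition in terms of itself, here is a sketch in Scala, the language this tutorial is about. The names Expressions, Expr, Num, Add, Mul and eval are our own hypothetical choices. The type Expr refers to itself, and eval follows exactly the same recursive shape:

```scala
object Expressions {
  // An expression is either a number or an operator applied to two
  // smaller expressions: the type is defined in terms of itself.
  sealed trait Expr
  final case class Num(value: Int) extends Expr
  final case class Add(left: Expr, right: Expr) extends Expr
  final case class Mul(left: Expr, right: Expr) extends Expr

  // Evaluation mirrors the definition: one case per kind of expression,
  // recursing into the sub-expressions.
  def eval(e: Expr): Int = e match {
    case Num(v)    => v
    case Add(l, r) => eval(l) + eval(r)
    case Mul(l, r) => eval(l) * eval(r)
  }
}
```

For example, eval(Add(Num(1), Mul(Num(2), Num(3)))) evaluates 1 + 2 * 3 and returns 7.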

Recursion

In our work we deal with recursive functions. To compute the value of such a function for a given input, it is necessary to evaluate the same function for other inputs. In other words, it is a function that calls itself. The best-known examples of recursion are factorial (as most programming books will tell you) and Fibonacci numbers. Let's forget about factorial and see what a Fibonacci implementation looks like in Java:

    public static int fib(int n) {
        if (n > 1) return fib(n - 1) + fib(n - 2);
        else return n;
    }

As the example shows, before the function can return a value, it calls itself twice with different arguments and then adds the two results.
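To get a feel for what those two calls cost, here is a minimal Scala sketch that counts how many invocations the naive definition makes for a given n. The object NaiveFibCalls and its method fibCalls are hypothetical helpers, not part of the original code:

```scala
object NaiveFibCalls {
  // Number of invocations the naive fib(n) performs, including the outermost one.
  // It mirrors fib: for n > 1, this call plus the two recursive calls.
  def fibCalls(n: Int): Long =
    if (n > 1) 1 + fibCalls(n - 1) + fibCalls(n - 2)
    else 1
}
```

For n = 10 it returns 177, and for n = 20 already 21891: the count grows as fast as the Fibonacci numbers themselves (it equals 2 * F(n + 1) - 1).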

Tail recursion

Now let's get back to tail recursion. A recursive call is a tail call when it is the last operation the function performs before returning, so its result is returned as is. A function whose recursive calls are all tail calls is tail-recursive. Let's rewrite the previous example using tail recursion:

    public static int fibWithTailRec(int n) {
        if (n > 1) return fibIter(1, 1, n - 2);
        else return n;
    }

    private static int fibIter(int prev, int current, int n) {
        if (n == 0) return current;
        else return fibIter(current, prev + current, n - 1);
    }

The API of the class hasn't changed: we still have a method that returns the n-th Fibonacci number. We also have a new private method, fibIter, which is tail-recursive. Tail recursion is very often implemented with a helper method like this one: it takes the state of the computation (here, the two most recent Fibonacci numbers and the number of remaining steps) as parameters and passes the updated state on to the next call. Because nothing remains to be done after that call, a tail-recursive function can be rewritten as a loop:

    public static int fibWithLoop(int n) {
        if (n <= 1) return n;

        int i = 2;
        int prev = 1;
        int current = 1;
        while (i < n) {
            int next = prev + current;
            prev = current;
            current = next;
            i += 1;
        }
        return current;
    }

This code looks awful, and there are many places where you can make a mistake (the initial values and the loop bounds, for example). That is why recursive solutions are often preferred: they are shorter and easier to read.

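Let's put our functions under pressure and compute the 10000th element of the sequence. Unfortunately, both methods, fib and fibWithTailRec, throw java.lang.StackOverflowError: neither the Java compiler nor the JVM eliminates tail calls, so every recursive call adds a new stack frame, even when it is in tail position. (The naive fib would be hopeless anyway, since the number of calls it makes grows exponentially with n.) So recursion in Java should be used carefully. And don't forget that you can always use a loop :)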

Scala rescue ranger

What about recursion in other JVM languages? Let's see how Scala handles recursive functions. We will rewrite both recursive Fibonacci implementations in Scala, keeping them as close as possible to the Java versions:

  object Fibonacci {
    def fib(n: Int): Int =
      if (n > 1) fib(n - 1) + fib(n - 2)
      else n

    def fibWithTailRec(n: Int): Int =
      if (n > 1) fibIter(1, 1, n - 2)
      else n

    // fibIter is deliberately not nested inside fibWithTailRec, for illustrative purposes.
    private def fibIter(prev: Int, current: Int, n: Int): Int =
      if (n == 0) current
      else fibIter(current, prev + current, n - 1)
  }
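The comment in fibIter hints that in idiomatic Scala the helper is usually nested inside the public method. Here is a sketch of that variant, using only the standard library; the names FibonacciNested and loop are our own. It also uses the @tailrec annotation, discussed at the end of this tutorial, to ask the compiler to confirm that the call is optimized:

```scala
import scala.annotation.tailrec

object FibonacciNested {
  // Same algorithm as fibWithTailRec/fibIter, but the helper is local,
  // so callers only see the one-argument method.
  def fibWithTailRec(n: Int): Int = {
    // loop carries the state: the two most recent Fibonacci numbers
    // and how many steps are still to go.
    @tailrec
    def loop(prev: Int, current: Int, remaining: Int): Int =
      if (remaining == 0) current
      else loop(current, prev + current, remaining - 1)

    if (n > 1) loop(1, 1, n - 2) else n
  }
}
```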

Now let's calculate the 10000th Fibonacci number once again, this time in Scala. It turns out that the tail-recursive version copes with the task without a StackOverflowError (although the Int result has long since overflowed: the 47th Fibonacci number, 2971215073, already exceeds Int.MaxValue). Why does almost the same code behave differently in Java and Scala? Let's look at the bytecode of these methods (javap -c) and check what the compilers actually do:

Java version:
  public static int fibIter(int, int, int);
    Code:
       0: iload_2
       1: ifne          6   // check the end condition
       4: iload_1
       5: ireturn
       6: iload_1           // push current to stack
       7: iload_0
       8: iload_1
       9: iadd              // push prev + current to stack
      10: iload_2
      11: iconst_1
      12: isub              // push n - 1 to stack
      13: invokestatic  #2  // call method  fibIter:(III)I
      16: ireturn           // return the result of call
Scala version:
  public int fibIter(int, int, int);
    Code:
       0: iload_3
       1: iconst_0
       2: if_icmpne     7   // check the end condition
       5: iload_2
       6: ireturn
       7: iload_2           // push current to stack
       8: iload_1
       9: iload_2
      10: iadd              // push prev + current to stack
      11: iload_3
      12: iconst_1
      13: isub              // push n - 1 to stack
      14: istore_3
      15: istore_2
      16: istore_1          // save everything in the local variables
      17: goto          0   // go to the top of the iteration
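To make the listing easier to follow, here is a hand-written Scala approximation of what the generated bytecode does. FibIterAsLoop is a hypothetical name, and this is not decompiler output. The parameters become mutable local variables, and the recursive call becomes a jump back to the end-condition check:

```scala
object FibIterAsLoop {
  // Approximation of the compiled fibIter: no recursive call, just a loop.
  def fibIter(prev0: Int, current0: Int, n0: Int): Int = {
    var prev = prev0
    var current = current0
    var n = n0
    while (n != 0) {            // the end-condition check at offset 0
      val next = prev + current // prev + current (iadd)
      prev = current            // the old current becomes prev
      current = next            // the sum becomes current
      n = n - 1                 // one step fewer to go
    }                           // looping back plays the role of goto 0
    current
  }
}
```

For example, fibIter(1, 1, 8) returns 55, the same value as fibWithTailRec(10).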

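The Scala compiler recognized the tail call and turned it into a loop: instead of invoking fibIter again, the bytecode stores the new arguments in the method's local variables and jumps back to the beginning of the method (goto 0). This is essentially the loop that, as we saw earlier, is not that easy to write correctly by hand.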
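An Int accumulator overflows long before the 10000th number. Because the loop no longer consumes stack, the same scheme can just as well carry an arbitrary-precision accumulator. Here is a minimal sketch that swaps Int for the standard library's BigInt; the name FibonacciBig is hypothetical:

```scala
import scala.annotation.tailrec

object FibonacciBig {
  // Same tail-recursive scheme as fibIter, but with BigInt accumulators,
  // so large results are exact instead of wrapping around.
  def fib(n: Int): BigInt = {
    @tailrec
    def loop(prev: BigInt, current: BigInt, remaining: Int): BigInt =
      if (remaining == 0) current
      else loop(current, prev + current, remaining - 1)

    if (n > 1) loop(BigInt(1), BigInt(1), n - 2) else BigInt(n)
  }
}
```

With this version, FibonacciBig.fib(10000) returns the exact value, a number with 2090 digits.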

The world will never be the same again

So, we have grasped the benefits of tail recursion in Scala. It's time to put this knowledge to use! In functional programming, hardly any program gets by without lists. Let's implement fold, one of the basic list operations (Scala's collections offer two versions of it, foldLeft and foldRight):

  object Folds {
    def foldLeft[A, B](list: List[A], z: B)(f: (B, A) => B): B =
      list match {
        case Nil => z
        case x :: xs => foldLeft(xs, f(z, x))(f)
      }

    def foldRight[A, B](list: List[A], z: B)(f: (A, B) => B): B =
      list match {
        case Nil => z
        case x :: xs => f(x, foldRight(xs, z)(f))
      }
  }

As we can see, both functions are recursive. But, alas, foldRight is not tail-recursive: f can only be applied after the recursive call returns, so every element of the list adds a stack frame, and a long enough list causes a stack overflow. foldLeft, by contrast, is tail-recursive, and since it is defined in an object it cannot be overridden, so the compiler can turn it into a loop. There is a fairly simple solution that takes some extra work (one more pass over the list to reverse it) but keeps the algorithm's complexity at O(n): we write foldRight in terms of foldLeft:

  // Also a member of object Folds.
  def foldRightUsingFoldLeft[A, B](list: List[A], z: B)(f: (A, B) => B): B =
    foldLeft(list.reverse, z)((b, a) => f(a, b))
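Here is a self-contained sketch to try the idea. FoldDemo is a hypothetical name. The definitions follow the ones above, restricted to List, with @tailrec (discussed below) added to foldLeft:

```scala
import scala.annotation.tailrec

object FoldDemo {
  // Tail-recursive left fold: the recursive call is the last operation.
  @tailrec
  def foldLeft[A, B](list: List[A], z: B)(f: (B, A) => B): B =
    list match {
      case Nil     => z
      case x :: xs => foldLeft(xs, f(z, x))(f)
    }

  // Right fold built on the left fold: reverse once, then fold from the
  // left with the arguments of f swapped.
  def foldRightUsingFoldLeft[A, B](list: List[A], z: B)(f: (A, B) => B): B =
    foldLeft(list.reverse, z)((b, a) => f(a, b))
}
```

For example, FoldDemo.foldRightUsingFoldLeft(List(1, 2, 3), 0)(_ - _) computes 1 - (2 - (3 - 0)) = 2, as a right fold should, whereas FoldDemo.foldLeft(List(1, 2, 3), 0)(_ - _) computes ((0 - 1) - 2) - 3 = -6. Summing a list of a million elements with foldRightUsingFoldLeft works as well, while the non-tail-recursive foldRight may overflow the stack on a list that long, depending on the JVM stack size.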

To check whether a function is optimized, you can use the @tailrec annotation (scala.annotation.tailrec). Some developers mistakenly believe that the annotation instructs the compiler to apply tail-call optimization to a method, but that is not accurate: the compiler optimizes eligible tail-recursive methods with or without it, and the optimization is not always possible. Instead, the annotation makes the compiler verify that the annotated method can be optimized, and if it cannot, compilation fails with an error. The most common reasons are:

  • the method can be overridden in a subclass, because it is neither final nor private (methods of an object cannot be overridden);
  • the method does not call itself at all;
  • the recursive call is not in tail position, as in foldRight above.
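A minimal sketch of these rules follows; TailrecExamples and its methods are hypothetical names. Both annotated methods compile because neither can be overridden. The commented-out variants illustrate the first and the third case, which the compiler would reject:

```scala
import scala.annotation.tailrec

class TailrecExamples {
  // final: a subclass cannot override it, so @tailrec is accepted.
  @tailrec
  final def countDown(n: Int, steps: Int): Int =
    if (n <= 0) steps else countDown(n - 1, steps + 1)

  // private methods cannot be overridden either.
  @tailrec
  private def sumTo(n: Int, acc: Long): Long =
    if (n == 0) acc else sumTo(n - 1, acc + n)

  def sumUpTo(n: Int): Long = sumTo(n, 0L)

  // Rejected with @tailrec: public and not final, so it could be overridden.
  // @tailrec def countDownOpen(n: Int): Int =
  //   if (n <= 0) 0 else countDownOpen(n - 1)

  // Rejected with @tailrec: the addition runs after the recursive call returns.
  // @tailrec final def sumNotTail(n: Int): Long =
  //   if (n == 0) 0L else n + sumNotTail(n - 1)
}
```

For example, new TailrecExamples().sumUpTo(1000000) returns 500000500000 without growing the stack.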

Keep these rules in mind, annotate your tail-recursive methods with @tailrec so that the compiler reports any method it cannot turn into a loop, and enjoy programming in Scala.