Functional fun with fibonacci and friends

Arguably the quintessential recursive function is the fibonacci function.

fib(0) = 1
fib(1) = 1
fib(n) = fib(n-1) + fib(n-2)

In Scala it looks like this.

def fib(n: BigInt): BigInt = if (n < 2) 1 else fib(n-2) + fib(n-1)

We assume, of course, that the input n is always a non-negative integer. From a mathematical perspective this is a beautiful implementation, concise and elegant. However, something is very wrong with it. Can you spot it?

Exactly: it’s inefficient. Think about it for a while: even without a formal proof, we can intuitively guess the complexity of this algorithm. Every call that doesn’t hit a base case spawns two more calls, so the amount of work grows in proportion to the very number it computes!

T(n) = O(fib(n)) = O(φ^n), where φ = (1 + √5) / 2 ≈ 1.618
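To back up that intuition without a stopwatch, we can count calls instead of milliseconds. Every call either hits a base case or makes two more calls, so the number of calls C(n) satisfies C(0) = C(1) = 1 and C(n) = 1 + C(n-1) + C(n-2), which works out to C(n) = 2·fib(n) − 1. The sketch below is a hypothetical instrumented copy of fib (FibCallCount, the calls counter and countCalls are illustration names, not part of the original code) that lets us check this.

```scala
object FibCallCount {
  // Number of times fib has been invoked since the last reset (instrumentation only).
  var calls = 0L

  // Same definition as above (fib(0) = fib(1) = 1), plus a call counter.
  def fib(n: BigInt): BigInt = {
    calls += 1
    if (n < 2) 1 else fib(n - 2) + fib(n - 1)
  }

  // Resets the counter, computes fib(n) and returns (result, number of calls).
  def countCalls(n: Int): (BigInt, Long) = {
    calls = 0
    val r = fib(BigInt(n))
    (r, calls)
  }

  def main(args: Array[String]): Unit =
    for (n <- 0 to 25 by 5) {
      val (r, c) = countCalls(n)
      println(s"fib($n) = $r needed $c calls")
    }
}
```

The call count tracks fib(n) itself, which is exactly why the running time explodes.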

Simply measure some executions and see for yourself.

import java.lang.System._

def main(args: Array[String]): Unit =
  for (n <- 1 to 38) {
    val t1 = currentTimeMillis
    fib(n) // the call we are timing
    val t2 = currentTimeMillis
    println(s"fib($n): ${t2 - t1} ms")
  }

Lovely to see how the execution times follow one another in true fibonacci style.

The problem is that intermediate values are computed over and over again in every recursive branch of the algorithm. It would be nice if, for a given number x, we could cache fib(x) and reuse that precalculated result whenever we need it while calculating fib(y) for some y > x. This technique is called function memoization.

Function memoization boils down to a higher-order function that takes the original function as input, and returns another function with the caching mechanism built in.

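def memoize[A, B](f: A => B): A => B = new (A => B) {
  // Results computed so far, keyed by argument
  val cache = scala.collection.mutable.Map[A, B]()

  // Return the cached result for x, computing and storing f(x) on a miss
  def apply(x: A): B = cache.getOrElseUpdate(x, f(x))
}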

The Scala implementation is remarkably simple. We accept a generic function f and return an anonymous function with the same type parameters. The cache is a plain mutable map (a HashMap by default). The function’s apply method consults the cache and returns the stored result if present; otherwise it computes the desired result, stores it in the cache and returns it. All of this is neatly handled by the getOrElseUpdate method of mutable.Map:

def getOrElseUpdate(key: A, op: => B): B

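Note that the second parameter is passed by name (op: => B): the expression f(x) is only evaluated when the key is missing from the cache. In Java this is awkward to express, because Java parameters always have call-by-value semantics; you have to wrap the computation in an explicit function object instead.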
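A tiny sketch makes the by-name behaviour visible. The names ByNameDemo, expensive and lookup are made up for illustration; the counter shows that the expression passed to getOrElseUpdate only runs on a cache miss.

```scala
import scala.collection.mutable

object ByNameDemo {
  // Counts how often the "expensive" expression is actually evaluated.
  var evaluations = 0

  def expensive(x: Int): Int = {
    evaluations += 1
    x * x
  }

  val cache = mutable.Map[Int, Int]()

  // getOrElseUpdate takes its second argument by name (op: => B),
  // so expensive(x) only runs when x is not yet in the cache.
  def lookup(x: Int): Int = cache.getOrElseUpdate(x, expensive(x))

  def main(args: Array[String]): Unit = {
    lookup(7); lookup(7); lookup(7)
    println(s"evaluations = $evaluations") // prints "evaluations = 1"
  }
}
```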

So, are we done? Can we speed up our fibonacci function by memoizing it? Let’s find out…

def main(args: Array[String]): Unit = {
  val fib_memoized = memoize(fib)
  val t1 = currentTimeMillis
  val r  = fib_memoized(38)
  val t2 = currentTimeMillis

  println("result %s calculated in %s milliseconds".format(r, t2-t1))
}

result 39088169 calculated in 8480 milliseconds

&%#@!! Curses and profanity! It turns out the cache is only consulted once, for the top-level call. Why is that? What is wrong? Is there something special about our fibonacci implementation?

Ah, yes…

It’s recursive.
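Here’s a small instrumented sketch of what goes wrong (WhyMemoizeFails and the calls counter are hypothetical names). memoize only wraps the outermost call: inside fib, the recursive calls go straight to the original, uncached fib, so the cache ends up holding just one entry per top-level argument.

```scala
import scala.collection.mutable

object WhyMemoizeFails {
  var calls = 0L

  // The naive recursive fib from above, instrumented with a call counter.
  def fib(n: BigInt): BigInt = {
    calls += 1
    if (n < 2) 1 else fib(n - 2) + fib(n - 1)
  }

  // The first memoize from above: it wraps f, but f's own recursive calls never see the cache.
  def memoize[A, B](f: A => B): A => B = new (A => B) {
    val cache = mutable.Map[A, B]()
    def apply(x: A): B = cache.getOrElseUpdate(x, f(x))
  }

  val fibMemoized: BigInt => BigInt = memoize[BigInt, BigInt](fib)

  def main(args: Array[String]): Unit = {
    calls = 0
    fibMemoized(20)
    println(s"first call:  $calls calls to fib") // every sub-result is recomputed
    calls = 0
    fibMemoized(20)
    println(s"second call: $calls calls to fib") // cache hit for 20 only
  }
}
```

The second call for 20 is a cache hit, but asking for 19 afterwards recomputes everything from scratch, even though fib(19) was already computed (and thrown away) during the first call.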

Our problem turns out to be trickier than expected. We somehow need to weave the memoization into every recursive call of the function. Let’s call a knowledgeable gentleman on stage to help us out here. Please welcome… Haskell Curry!

Haskell Curry, a pioneer in the field of combinatory logic, discovered a fascinating function called the Y combinator. If you’ve got a red pill handy and want to know more about it, the rabbit hole goes deep. Pretty mind-blowing stuff!

In a nutshell, what is it and how can we use it? The Y combinator is a higher-order function that takes a non-recursive function and returns a recursive version of that function (technically, its fixed point).
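Here is a minimal sketch of such a fixed-point combinator in Scala, assuming we’re allowed to use Scala’s own named recursion to define it. The pure lambda-calculus Y would loop forever in a strict language like Scala (and doesn’t type-check without recursive types); the extra a => … wrapper below delays evaluation until an argument arrives. The names fix and factStep are illustrative, and factorial is just a warm-up example.

```scala
object FixDemo {
  // Fixed-point combinator: given a "step" function that receives its own
  // recursive call as the first argument, return the recursive function.
  // Uses Scala's named recursion for brevity; the lambda delays evaluation.
  def fix[A, B](f: (A => B) => A => B): A => B =
    a => f(fix(f))(a)

  // Non-recursive factorial "step": it never calls itself, only `self`.
  def factStep(self: BigInt => BigInt)(n: BigInt): BigInt =
    if (n <= 1) 1 else n * self(n - 1)

  val factorial: BigInt => BigInt = fix[BigInt, BigInt](factStep)

  def main(args: Array[String]): Unit =
    println(factorial(20)) // prints 2432902008176640000
}
```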

Hey! Wait a minute! But our fibonacci function is recursive, so we can’t feed it into the Y combinator, right?

Correct. The trick goes as follows: if we could somehow turn our fibonacci function into a non-recursive function, we could use the Y combinator to turn it back into a recursive one with memoization built in, and we’d be done!

Let’s begin with the first step. We need to pull the recursion out of our beloved fibonacci function. This means eliminating the two recursive calls to itself.

def fib(n: BigInt): BigInt = if (n < 2) 1 else fib(n-2) + fib(n-1)

We replace the recursive calls with a new function g: BigInt => BigInt, which we pass to the fib function as an extra argument.

def fib(g: BigInt => BigInt)(n: BigInt): BigInt = if (n < 2) 1 else g(n-2) + g(n-1)
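A quick sketch to convince ourselves that this fib really is non-recursive: it only does one level of work and delegates everything else to g. Plugging in a dummy g shows that, while plugging in a plain, non-caching fixed-point combinator (called fix here, an illustrative name) gives back the original, still exponential, fibonacci.

```scala
object OpenFib {
  // The open-recursive fib from above: the recursive calls go through g.
  def fib(g: BigInt => BigInt)(n: BigInt): BigInt =
    if (n < 2) 1 else g(n - 2) + g(n - 1)

  // Plain fixed-point combinator without any caching.
  def fix[A, B](f: (A => B) => A => B): A => B = a => f(fix(f))(a)

  // Tying the knot recovers the original (exponential) fib.
  val fibRec: BigInt => BigInt = fix[BigInt, BigInt](fib)

  // With a dummy g, fib evaluates only a single level.
  val oneLevel: BigInt => BigInt = fib(_ => 0)

  def main(args: Array[String]): Unit = {
    println((0 to 10).map(n => fibRec(BigInt(n))).mkString(", "))
    println(oneLevel(10)) // prints 0
  }
}
```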

We redefine the memoize function as follows:

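def memoize[A, B](f: (A => B) => A => B): (A => B) = {
  val cache = scala.collection.mutable.Map[A, B]()
  // Y combinator with caching: y(f) is handed to f as its own recursive call,
  // so every recursive call goes through the cache as well
  def y(f: (A => B) => A => B): (A => B) = 
    a => cache.getOrElseUpdate(a, f(y(f)(_))(a))
  y(f)
}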

The memoize function defines the nested Y combinator function y, which implements the caching strategy. Crucially, y passes y(f) itself as the argument g, so every recursive call of fib is routed through the cache too. The result of the memoization is the Y combinator applied to our non-recursive function f.
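To see that the cache now works at every level, we can count how often the open fib is evaluated. This is a sketch with a hypothetical stepCalls counter added to the fib from above; memoize is the same as above, just placed in an object (MemoYCount is an illustrative name).

```scala
import scala.collection.mutable

object MemoYCount {
  var stepCalls = 0

  // Open-recursive fib from above, instrumented to count evaluations.
  def fib(g: BigInt => BigInt)(n: BigInt): BigInt = {
    stepCalls += 1
    if (n < 2) 1 else g(n - 2) + g(n - 1)
  }

  // The memoizing Y combinator from above.
  def memoize[A, B](f: (A => B) => A => B): A => B = {
    val cache = mutable.Map[A, B]()
    def y(f: (A => B) => A => B): A => B =
      a => cache.getOrElseUpdate(a, f(y(f)(_))(a))
    y(f)
  }

  def main(args: Array[String]): Unit = {
    val fibMemo = memoize[BigInt, BigInt](fib)
    println(fibMemo(BigInt(90)))
    println(s"step evaluations: $stepCalls") // prints "step evaluations: 91"
  }
}
```

Each n from 0 to 90 is evaluated exactly once, so the exponential blow-up turns into a linear number of steps.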

Let’s find out if it does work faster now!

def main(args: Array[String]): Unit = {
  val fib_memoized = memoize(fib)
  val t1 = currentTimeMillis
  val r  = fib_memoized(1000)
  val t2 = currentTimeMillis
  println("result %s calculated in %s milliseconds".format(r, t2-t1))
}

result …139373125598767690091902245245323403501 calculated in 73 milliseconds (leading digits of the result omitted)

Err… nope, Mr. Spock. Your code isn’t perfect. Depending on your JVM configuration (in particular the thread stack size), this implementation will overflow the stack once n gets large enough. This problem plagues every recursive function that is not tail recursive and therefore cannot benefit from tail call optimization.
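One way out, assuming we’re happy to give up the memoization and the Y combinator, is to make the recursion tail recursive by carrying the last two values along as accumulators. The sketch below takes an Int instead of a BigInt for simplicity and keeps the convention fib(0) = fib(1) = 1; FibTailRec and go are illustrative names. The @tailrec annotation makes scalac reject the method if it can’t compile the recursion into a loop.

```scala
import scala.annotation.tailrec

object FibTailRec {
  // Invariant: a = fib(i), b = fib(i + 1). The recursive call is the last
  // action, so scalac turns go into a loop with constant stack usage.
  def fib(n: Int): BigInt = {
    require(n >= 0, "n must be non-negative")
    @tailrec
    def go(i: Int, a: BigInt, b: BigInt): BigInt =
      if (i == n) a else go(i + 1, b, a + b)
    go(0, 1, 1)
  }

  def main(args: Array[String]): Unit =
    println(fib(10000).toString.length) // number of decimal digits, no stack overflow
}
```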

In a real-world situation, an imperative implementation with a loop instead of recursion (or almost-recursion with the Y combinator) is obviously the better choice.
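For completeness, here is a minimal imperative sketch of such a loop, again with fib(0) = fib(1) = 1 and an Int argument (FibLoop is an illustrative name). It performs n additions and needs only constant stack space.

```scala
object FibLoop {
  // Imperative fib with a while loop: O(n) additions, constant stack depth.
  def fib(n: Int): BigInt = {
    require(n >= 0, "n must be non-negative")
    var a: BigInt = 1 // fib(i)
    var b: BigInt = 1 // fib(i + 1)
    var i = 0
    while (i < n) {
      val next = a + b
      a = b
      b = next
      i += 1
    }
    a
  }

  def main(args: Array[String]): Unit =
    for (n <- 0 to 10) println(s"fib($n) = ${fib(n)}")
}
```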

Nonetheless, functional programming is powerful and a lot of fun!

