Automatic differentiation

Mar 8, 2019   #FP  #Haskell 

At Pramati Chennai, we’ve been having a series of sessions on math. The purpose is to try and connect many concepts usually treated as separate. We’re currently on a track to understand functions of multiple variables and their calculus.

Along the way, I thought it might be a good idea to try and introduce automatic differentiation in a programmatic way so people have a taste of how to precisely capture their ideas. This is not an attempt to implement the AD algorithms (at least not just yet), but to make the key idea concrete.

We started with functions of a single variable. Though I worked out a version in JavaScript for the group, I’d done it first in Haskell to get some clarity myself, which is what I want to share here. First we’ll do functions of one variable and then move on to multi-variate stuff.

Note: The code described in this post is available on GitHub.

The main purpose of this exercise is to be able to calculate the derivative of a function given as an expression. To do this, we need to keep track, at each point in the evaluation of the expression, of both the expression we’re evaluating and its derivative.

What I mean is that our “differentiable function” is a tuple of the form -

$$(f(x), f'(x))$$

The interesting thing here is that the second part of the tuple, the derivative, is itself a function. We’d like to be able to differentiate that too to get the second derivative. Therefore, what we really want is for the tuple to be defined as a recursive type.

data DF x = DF (x -> x) (DF x)

where we use DF as an abbreviation for “differentiable function”.

In order to be able to calculate the value of this function at \(x\), we’ll pick the first part of the tuple. To calculate the value of the derivative at \(x\), we’ll use the second part.

f (DF fx _) = fx
df (DF _ dfx) = dfx

So now, if a is a given differentiable function, we can evaluate it at \(x\) using f a x and the value of its derivative using f (df a) x.
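Before any type-class machinery, it can help to see the representation used directly. Below is a minimal sketch that writes \(x^2\) out by hand as a DF value, so that f and df can be exercised on their own. The names zeroDF and square, and the type signatures, are illustrative additions rather than part of the original code.

```haskell
-- The representation from the post: a value function paired with a
-- differentiable function that represents the derivative.
data DF a = DF (a -> a) (DF a)

-- Value of the function at a point.
f :: DF a -> a -> a
f (DF fx _) = fx

-- The derivative, itself a DF.
df :: DF a -> DF a
df (DF _ dfx) = dfx

-- Hypothetical helper: the constant zero function. Its derivative is
-- itself, so the structure is infinite; laziness unfolds it on demand.
zeroDF :: Num a => DF a
zeroDF = DF (const 0) zeroDF

-- Hypothetical example: x^2 written out by hand.
-- (x^2)' = 2x, (2x)' = 2, and every derivative after that is 0.
square :: Num a => DF a
square = DF (\v -> v * v)
            (DF (\v -> 2 * v)
                (DF (const 2) zeroDF))
```

Here f square 3 evaluates to 9, f (df square) 3 to 6, and f (df (df square)) 3 to 2. Spelling out every level by hand is tedious and error-prone, which is the motivation for letting type class instances build these towers for us.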

Due to the recursive nature of the definition, we can also define a recursive function for the n-th derivative -

dfn 0 = id
dfn n = dfn (n-1) . df
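As a quick check of dfn, the sketch below repeats the minimal definitions, builds \(x^3\) by hand, and lists its derivatives at a point. The helpers zeroDF, cube and derivativesAt are hypothetical names used only for this illustration, and the signature on dfn is an assumption (the post leaves it to type inference).

```haskell
data DF a = DF (a -> a) (DF a)

f :: DF a -> a -> a
f (DF fx _) = fx

df :: DF a -> DF a
df (DF _ dfx) = dfx

-- n-th derivative, as in the post. n must be non-negative; otherwise
-- the recursion never reaches the base case.
dfn :: Int -> DF a -> DF a
dfn 0 = id
dfn n = dfn (n - 1) . df

-- Hypothetical helper: constant zero, whose derivative is again zero.
zeroDF :: Num a => DF a
zeroDF = DF (const 0) zeroDF

-- Hypothetical example: x^3 by hand, with derivatives 3x^2, 6x, 6, 0, ...
cube :: Num a => DF a
cube = DF (\v -> v * v * v)
          (DF (\v -> 3 * v * v)
              (DF (\v -> 6 * v)
                  (DF (const 6) zeroDF)))

-- Values of the 0th through n-th derivatives of g at the point v.
derivativesAt :: Int -> DF a -> a -> [a]
derivativesAt n g v = [f (dfn k g) v | k <- [0 .. n]]
```

With these definitions, derivativesAt 5 cube 2 gives [8,12,12,6,0,0], i.e. the values of \(x^3\), \(3x^2\), \(6x\), 6 and then zeros at \(x = 2\).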

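The independent variable \(x\) is simply the identity function, whose derivative is the constant function 1.

-- The signature keeps x polymorphic in the number type. Without it,
-- the monomorphism restriction may default x to DF Integer, and an
-- expression such as sin x would then fail to type-check.
x :: Real n => DF n
x = DF id 1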

As numbers

In order to make such a straightforward definition of \(x\) possible, we need to be able to treat differentiable functions as ordinary numbers. Haskell lets us do that by making DF an instance of the standard numeric type classes, starting with Num.

instance (Real n) => Num (DF n) where
    -- The sum rule
    a + b = DF (\x -> f a x + f b x) (df a + df b)
    -- The product rule
    a * b = DF (\x -> f a x * f b x) (a * df b + df a * b)

    negate a = DF (negate . f a) (negate (df a))
    abs a = DF (abs . f a) (df a * signum a)
    signum a = DF (signum . f a) 0 -- Ignore the anomaly at 0

    -- Enable conversion from plain number constants to 
    -- differentiable functions. Now the definition for 
    -- `x` above becomes possible.
    fromInteger x = DF (\_ -> fromInteger x) 0

In the above definitions, notice that the derivative part of each operation is defined recursively, in terms of the very operators we’re defining (and, in some rules, the operands a and b themselves). Haskell’s laziness makes this simple to express. In an eager language like JavaScript, you’ll need some extra machinery, such as delaying the derivative part behind a function call, to express the same idea.
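To see the instance at work, here is a self-contained sketch that repeats the definitions so far and differentiates a small polynomial. The type signature on x and the example polynomial poly are additions made for this illustration. The product rule’s derivative terms are written in the order \(a'b + ab'\), which is equivalent.

```haskell
data DF a = DF (a -> a) (DF a)

f :: DF a -> a -> a
f (DF fx _) = fx

df :: DF a -> DF a
df (DF _ dfx) = dfx

instance Real n => Num (DF n) where
  -- Sum rule: the derivative part uses the (+) being defined.
  a + b = DF (\v -> f a v + f b v) (df a + df b)
  -- Product rule: uses both (*) and (+) on DF values.
  a * b = DF (\v -> f a v * f b v) (df a * b + a * df b)
  negate a = DF (negate . f a) (negate (df a))
  abs a = DF (abs . f a) (df a * signum a)
  signum a = DF (signum . f a) 0 -- ignores the anomaly at 0
  -- Constants have zero derivative. The literal 0 here is itself
  -- fromInteger 0 from this instance, so this is lazily recursive too.
  fromInteger k = DF (const (fromInteger k)) 0

-- The identity function, with derivative 1.
x :: Real n => DF n
x = DF id 1

-- Hypothetical example polynomial: x^3 + 2x.
poly :: Real n => DF n
poly = x * x * x + 2 * x
```

At type Integer, f poly 2, f (df poly) 2 and f (df (df poly)) 2 give 12, 14 and 12, matching \(x^3 + 2x\), \(3x^2 + 2\) and \(6x\) at \(x = 2\). The derivative parts deeper than the one requested stay unevaluated.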

This looks straightforward but is so cool that it is worth dissecting at least one of the above expressions to understand how it is constructed. Let’s take the sum rule first.

The sum rule explained

a + b = DF (\x -> f a x + f b x) (df a + df b)

The + being defined on the LHS is the operator of addition for the type DF x. This means both a and b are of type DF x.

The first argument of the DF constructor is straightforward. It is a function that takes a concrete value, passes it to both a and b, and returns the sum of the two concrete results.

The second part is more interesting. Firstly, it is expected to be a differentiable function itself - i.e. of type DF x. So we can use any (correct) expression that results in a DF x type here. Since we know that \((a+b)' = a' + b'\), we can use the fact that df a is basically \(a'\) to simply write df a + df b.

The key point is that the + used in the derivative part is the very + that we’re in the middle of defining. Thanks to Haskell’s laziness, this doesn’t cause a non-terminating recursion, since the second part of a DF value is evaluated only when it is needed.
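To see why the recursion stays finite, here is the sum rule unfolded by hand for \(x + x\) at an arbitrary point \(v\). Recall that df x is the constant 1, and that the derivative of any constant built by fromInteger is 0.

$$
\begin{aligned}
\mathtt{f}\,(x + x)\,v &= \mathtt{f}\,x\,v + \mathtt{f}\,x\,v = v + v = 2v \\
\mathtt{f}\,(\mathtt{df}\,(x + x))\,v &= \mathtt{f}\,(\mathtt{df}\,x + \mathtt{df}\,x)\,v = \mathtt{f}\,(1 + 1)\,v = 1 + 1 = 2 \\
\mathtt{f}\,(\mathtt{df}\,(\mathtt{df}\,(x + x)))\,v &= \mathtt{f}\,(\mathtt{df}\,1 + \mathtt{df}\,1)\,v = \mathtt{f}\,(0 + 0)\,v = 0
\end{aligned}
$$

Each line is only computed when something asks for that level of derivative; the remainder of the DF value is never forced.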

The product rule explained

a * b = DF (\x -> f a x * f b x) (df a * b + a * df b)

The product rule is a bit more intricate than the sum rule, but it is essentially the same idea, where we’re expressing our knowledge of \((ab)' = a'b + ab'\).

This time, not only are we relying on the currently-being-defined * operator for DF x types, we also make use of the + operator we just defined.
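Unfolding the product rule for \(x \cdot x\) in the same way shows both operators cooperating: the first step uses the product rule, and the second step needs the sum rule as well as the product rule again. Here \(\cdot\) stands for the * being defined on DF values, with the terms in the order used above.

$$
\begin{aligned}
\mathtt{df}\,(x \cdot x) &= \mathtt{df}\,x \cdot x + x \cdot \mathtt{df}\,x = 1 \cdot x + x \cdot 1 \\
\mathtt{f}\,(\mathtt{df}\,(x \cdot x))\,v &= 1 \cdot v + v \cdot 1 = 2v \\
\mathtt{df}\,(\mathtt{df}\,(x \cdot x)) &= \mathtt{df}\,(1 \cdot x) + \mathtt{df}\,(x \cdot 1) \\
&= (\mathtt{df}\,1 \cdot x + 1 \cdot \mathtt{df}\,x) + (\mathtt{df}\,x \cdot 1 + x \cdot \mathtt{df}\,1) \\
&= (0 \cdot x + 1 \cdot 1) + (1 \cdot 1 + x \cdot 0) \\
\mathtt{f}\,(\mathtt{df}\,(\mathtt{df}\,(x \cdot x)))\,v &= (0 + 1) + (1 + 0) = 2
\end{aligned}
$$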

As rationals

We want to be able to write ordinary mathematical expressions, so at a minimum we also need the ability to divide.

instance (Real n, Fractional n) => Fractional (DF n) where
    -- Add the ability to convert rational numbers to
    -- constant differentiable functions.
    fromRational x = DF (\_ -> fromRational x) 0

    -- Specifying how to calculate the reciprocal is enough: the default
    -- definition of division is then a / b = a * recip b
    recip a = DF (\x -> 1 / f a x) (- df a / (a * a))

Notice again that we’re using just what we already know about how to differentiate functions of a single variable and encoding that knowledge in the derivative part.
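Here is a self-contained sketch exercising division: it repeats the Num and Fractional instances from above and differentiates \(1/x\) and a quotient. The example name ratio and the type signatures are illustrative additions.

```haskell
data DF a = DF (a -> a) (DF a)

f :: DF a -> a -> a
f (DF fx _) = fx

df :: DF a -> DF a
df (DF _ dfx) = dfx

instance Real n => Num (DF n) where
  a + b = DF (\v -> f a v + f b v) (df a + df b)
  a * b = DF (\v -> f a v * f b v) (df a * b + a * df b)
  negate a = DF (negate . f a) (negate (df a))
  abs a = DF (abs . f a) (df a * signum a)
  signum a = DF (signum . f a) 0
  fromInteger k = DF (const (fromInteger k)) 0

instance (Real n, Fractional n) => Fractional (DF n) where
  fromRational r = DF (const (fromRational r)) 0
  -- (1/a)' = -a' / a^2. Division falls back to the class default,
  -- a / b = a * recip b, so the product rule handles quotients.
  recip a = DF (\v -> 1 / f a v) (- df a / (a * a))

x :: Real n => DF n
x = DF id 1

-- Hypothetical example: x / (1 + x^2), whose derivative is
-- (1 - x^2) / (1 + x^2)^2.
ratio :: (Real n, Fractional n) => DF n
ratio = x / (1 + x * x)
```

For example, f (df (recip x)) 2 evaluates to -0.25, and f (df ratio) 0 to 1, in agreement with the quotient rule \(\left(\frac{x}{1+x^2}\right)' = \frac{1-x^2}{(1+x^2)^2}\).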

As common scientific functions

We can continue to do the same thing for all the normal scientific functions.

instance (Real n, Floating n) => Floating (DF n) where
    pi = DF (\_ -> pi) 0
    exp a = DF (exp . f a) (df a * exp a)
    log a = DF (log . f a) (df a / a)
    sin a = DF (sin . f a) (df a * cos a)
    cos a = DF (cos . f a) (- df a * sin a)
    asin a = DF (asin . f a) (df a / sqrt (1 - a * a))
    acos a = DF (acos . f a) (- df a / sqrt (1 - a * a))
    atan a = DF (atan . f a) (df a / (1 + a * a))
    sinh a = DF (sinh . f a) (df a * cosh a)
    cosh a = DF (cosh . f a) (df a * sinh a)
    asinh a = DF (asinh . f a) (df a / sqrt (1 + a * a))
    acosh a = DF (acosh . f a) (df a / sqrt (a * a - 1))
    atanh a = DF (atanh . f a) (df a / (1 - a * a))

Note: The definitions for sin and cos are mutually recursive, as are the definitions for sinh and cosh.
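The sketch below assembles the three instances into one self-contained module and spot-checks a few derivatives against known formulas, using \((\operatorname{atanh} x)' = 1/(1 - x^2)\). The central finite difference helper finiteDiff is a hypothetical addition used only as an independent sanity check; it is not part of the AD mechanism.

```haskell
data DF a = DF (a -> a) (DF a)

f :: DF a -> a -> a
f (DF fx _) = fx

df :: DF a -> DF a
df (DF _ dfx) = dfx

instance Real n => Num (DF n) where
  a + b = DF (\v -> f a v + f b v) (df a + df b)
  a * b = DF (\v -> f a v * f b v) (df a * b + a * df b)
  negate a = DF (negate . f a) (negate (df a))
  abs a = DF (abs . f a) (df a * signum a)
  signum a = DF (signum . f a) 0
  fromInteger k = DF (const (fromInteger k)) 0

instance (Real n, Fractional n) => Fractional (DF n) where
  fromRational r = DF (const (fromRational r)) 0
  recip a = DF (\v -> 1 / f a v) (- df a / (a * a))

-- Chain rule throughout: each derivative part is df a times the
-- derivative of the outer function, evaluated at a.
instance (Real n, Floating n) => Floating (DF n) where
  pi = DF (const pi) 0
  exp a = DF (exp . f a) (df a * exp a)
  log a = DF (log . f a) (df a / a)
  sin a = DF (sin . f a) (df a * cos a)
  cos a = DF (cos . f a) (- df a * sin a)
  asin a = DF (asin . f a) (df a / sqrt (1 - a * a))
  acos a = DF (acos . f a) (- df a / sqrt (1 - a * a))
  atan a = DF (atan . f a) (df a / (1 + a * a))
  sinh a = DF (sinh . f a) (df a * cosh a)
  cosh a = DF (cosh . f a) (df a * sinh a)
  asinh a = DF (asinh . f a) (df a / sqrt (1 + a * a))
  acosh a = DF (acosh . f a) (df a / sqrt (a * a - 1))
  atanh a = DF (atanh . f a) (df a / (1 - a * a))

x :: Real n => DF n
x = DF id 1

-- Hypothetical helper: central finite difference, for sanity checks only.
finiteDiff :: (Double -> Double) -> Double -> Double
finiteDiff h v = (h (v + eps) - h (v - eps)) / (2 * eps)
  where eps = 1e-6
```

For instance, f (df (sin x)) 0.5 agrees with cos 0.5, f (df (df (sin x))) 0.5 with -sin 0.5, and f (df (asin x)) 0.5 with finiteDiff asin 0.5 to within the finite-difference error. Note that sqrt is not defined above; it comes from the Floating class default, which is expressed via exp and log.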

Getting it all together

For a start, that is basically it. You can put all of the above code into a Haskell module and load it into GHCi.

With just that much, you can calculate as many derivatives as you want for an expression like -

let silly = 2 * x + x * x * sin x

You can, for instance, evaluate the 5th derivative of the silly function at \(x = 2.34\) using f (dfn 5 silly) 2.34.
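To double-check such a result, the fifth derivative can be compared against a closed form. By the Leibniz rule, \((x^2 \sin x)^{(5)} = x^2 \cos x + 10x \sin x - 20 \cos x\), and the \(2x\) term contributes nothing. The sketch below is self-contained; the explicit signatures on x, dfn and silly and the helper sillyD5 are assumptions made for the illustration.

```haskell
data DF a = DF (a -> a) (DF a)

f :: DF a -> a -> a
f (DF fx _) = fx

df :: DF a -> DF a
df (DF _ dfx) = dfx

dfn :: Int -> DF a -> DF a
dfn 0 = id
dfn n = dfn (n - 1) . df

instance Real n => Num (DF n) where
  a + b = DF (\v -> f a v + f b v) (df a + df b)
  a * b = DF (\v -> f a v * f b v) (df a * b + a * df b)
  negate a = DF (negate . f a) (negate (df a))
  abs a = DF (abs . f a) (df a * signum a)
  signum a = DF (signum . f a) 0
  fromInteger k = DF (const (fromInteger k)) 0

instance (Real n, Fractional n) => Fractional (DF n) where
  fromRational r = DF (const (fromRational r)) 0
  recip a = DF (\v -> 1 / f a v) (- df a / (a * a))

instance (Real n, Floating n) => Floating (DF n) where
  pi = DF (const pi) 0
  exp a = DF (exp . f a) (df a * exp a)
  log a = DF (log . f a) (df a / a)
  sin a = DF (sin . f a) (df a * cos a)
  cos a = DF (cos . f a) (- df a * sin a)
  asin a = DF (asin . f a) (df a / sqrt (1 - a * a))
  acos a = DF (acos . f a) (- df a / sqrt (1 - a * a))
  atan a = DF (atan . f a) (df a / (1 + a * a))
  sinh a = DF (sinh . f a) (df a * cosh a)
  cosh a = DF (cosh . f a) (df a * sinh a)
  asinh a = DF (asinh . f a) (df a / sqrt (1 + a * a))
  acosh a = DF (acosh . f a) (df a / sqrt (a * a - 1))
  atanh a = DF (atanh . f a) (df a / (1 - a * a))

x :: Real n => DF n
x = DF id 1

silly :: (Real n, Floating n) => DF n
silly = 2 * x + x * x * sin x

-- Hypothetical helper: closed form of the 5th derivative of silly,
-- from the Leibniz rule applied to x^2 * sin x.
sillyD5 :: Double -> Double
sillyD5 v = v * v * cos v + 10 * v * sin v - 20 * cos v
```

The difference abs (f (dfn 5 silly) 2.34 - sillyD5 2.34) should be tiny, on the order of floating-point rounding error.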

Challenge: Try to express piecewise functions using this mechanism. That is, one expression computes the function for \(x < x_0\), a different one for \(x_0 \le x < x_1\), yet another for \(x_1 \le x < x_2\), and so on.

Next up: Differentiable functions of multiple variables.