At Pramati Chennai, we’ve been having a series of sessions on math. The purpose is to try and connect many concepts usually considered separate. We’re currently on a track to understand functions of multiple variables and their calculus.
Along the way, I thought it might be a good idea to try and introduce automatic differentiation in a programmatic way so people have a taste of how to precisely capture their ideas. This is not an attempt to implement the AD algorithms (at least not just yet), but to make the key idea concrete.
Note: The code described in this post is available on GitHub.
The main purpose of this exercise is to be able to calculate the derivative of a function given as an expression. To do this, we need to keep track of both the expression we’re evaluating and its derivative at each point in the evaluation.
What I mean is that our “differentiable function” is a tuple of the form \((f, f')\): a function paired with its derivative.
The interesting thing here is that the second part of the tuple, the derivative, is itself a function. We’d like to be able to differentiate that too to get the second derivative. Therefore, what we really want is for the tuple to be defined as a recursive type.
```haskell
data DF x = DF (x -> x) (DF x)
```

where we use `DF` as an abbreviation for “differentiable function”.
In order to be able to calculate the value of this function at \(x\), we’ll pick the first part of the tuple. To calculate the value of the derivative at \(x\), we’ll use the second part.
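```haskell
-- The value part of the tuple.
f :: DF x -> x -> x
f (DF fx _) = fx

-- The derivative part, itself a differentiable function.
df :: DF x -> DF x
df (DF _ dfx) = dfx
```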
So now, if `a` is a given differentiable function, we can evaluate it at \(x\) using `f a x`, and the value of its derivative using `f (df a) x`.
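Before any arithmetic is defined, the type and the two accessors are already enough to describe some infinite derivative towers by hand. The values `expDF`, `sinDF`, `cosDF`, `negSinDF` and `negCosDF` below are hypothetical examples, not part of the original code: since \(\exp' = \exp\), a single self-referential value suffices, and since the derivatives of \(\sin\) repeat with period four, four mutually recursive values do.

```haskell
-- A differentiable function: its value function plus,
-- lazily, the differentiable function that is its derivative.
data DF x = DF (x -> x) (DF x)

-- Value part of the tuple.
f :: DF x -> x -> x
f (DF fx _) = fx

-- Derivative part of the tuple.
df :: DF x -> DF x
df (DF _ dfx) = dfx

-- Hypothetical example: exp is its own derivative, so the
-- whole infinite tower is one self-referential value.
expDF :: DF Double
expDF = DF exp expDF

-- Hypothetical example: sin -> cos -> -sin -> -cos -> sin ...
-- Four mutually recursive values describe the whole tower.
sinDF, cosDF, negSinDF, negCosDF :: DF Double
sinDF    = DF sin cosDF
cosDF    = DF cos negSinDF
negSinDF = DF (negate . sin) negCosDF
negCosDF = DF (negate . cos) sinDF
```

For example, `f (df (df expDF)) 1` is just `exp 1`, and `f (df sinDF) 0` is `cos 0`. Laziness is what makes these definitions usable: a level of the tower is only unfolded when `df` asks for it.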
Due to the recursive nature of the definition, we can also define a recursive function for the n-th derivative:

```haskell
-- Apply df n times; dfn 0 is the function itself.
dfn :: Int -> DF x -> DF x
dfn 0 = id
dfn n = dfn (n-1) . df
```
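A small sketch of `dfn` in use, again on a hand-built value so that no instances are needed yet. `scaledExp` and `derivativesAt` are hypothetical helpers introduced only for this illustration: `scaledExp c k` represents \(c\,e^{kt}\), whose derivative is \((ck)\,e^{kt}\), so each level of the tower just updates the coefficient. Like the definition above, this `dfn` assumes \(n \ge 0\); a negative argument would recurse forever.

```haskell
data DF x = DF (x -> x) (DF x)

f :: DF x -> x -> x
f (DF fx _) = fx

df :: DF x -> DF x
df (DF _ dfx) = dfx

-- n-th derivative: apply df n times (assumes n >= 0).
dfn :: Int -> DF x -> DF x
dfn 0 = id
dfn n = dfn (n - 1) . df

-- Hypothetical example: c * exp (k * t). Its derivative is
-- (c * k) * exp (k * t), so the tower is generated by updating c.
scaledExp :: Double -> Double -> DF Double
scaledExp c k = DF (\t -> c * exp (k * t)) (scaledExp (c * k) k)

-- Hypothetical helper: the values of the first n derivatives
-- (starting with the function itself) at the point x0.
derivativesAt :: Int -> DF x -> x -> [x]
derivativesAt n a x0 = [f (dfn k a) x0 | k <- [0 .. n - 1]]
```

With these, `derivativesAt 4 (scaledExp 1 2) 0` evaluates to `[1.0,2.0,4.0,8.0]`, the successive derivatives of \(e^{2t}\) at \(t = 0\).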
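The independent variable \(x\) is simply the identity function, whose derivative is the constant function 1.

```haskell
-- The signature keeps x polymorphic in the number type, so it
-- can later be used at Double (e.g. together with sin).
x :: (Real n) => DF n
x = DF id 1
```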
To make such a straightforward definition of \(x\) possible, we need to be able to treat differentiable functions as ordinary numbers, so that the literal `1` stands for the constant function 1. In Haskell we can, by making `DF n` an instance of `Num`:
```haskell
instance (Real n) => Num (DF n) where
  -- The sum rule
  a + b = DF (\x -> f a x + f b x) (df a + df b)
  -- The product rule
  a * b = DF (\x -> f a x * f b x) (a * df b + df a * b)
  negate a = DF (negate . f a) (negate (df a))
  abs a = DF (abs . f a) (df a * signum a)
  signum a = DF (signum . f a) 0 -- Ignore the anomaly at 0
  -- Enable conversion from plain number constants to
  -- differentiable functions. Now the definition for
  -- `x` above becomes possible.
  fromInteger x = DF (\_ -> fromInteger x) 0
```
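Here is a minimal, self-contained sketch that gathers the pieces so far (the type, the accessors, `dfn`, the `Num` instance and `x`) into one loadable module. The lambda variable is renamed to `t` only to avoid shadowing the top-level `x`. The polynomial `p` is a hypothetical example chosen because its derivatives are easy to check by hand: \(p(t) = t^3 + 2t\), so \(p(3) = 33\), \(p'(3) = 29\), \(p''(3) = 18\), \(p'''(3) = 6\) and \(p''''(3) = 0\).

```haskell
-- A differentiable function: a value function paired with
-- (lazily) the differentiable function that is its derivative.
data DF x = DF (x -> x) (DF x)

-- Value of the function.
f :: DF x -> x -> x
f (DF fx _) = fx

-- The derivative, itself a differentiable function.
df :: DF x -> DF x
df (DF _ dfx) = dfx

-- n-th derivative (assumes n >= 0).
dfn :: Int -> DF x -> DF x
dfn 0 = id
dfn n = dfn (n - 1) . df

instance (Real n) => Num (DF n) where
  a + b = DF (\t -> f a t + f b t) (df a + df b)          -- sum rule
  a * b = DF (\t -> f a t * f b t) (a * df b + df a * b)  -- product rule
  negate a = DF (negate . f a) (negate (df a))
  abs a = DF (abs . f a) (df a * signum a)
  signum a = DF (signum . f a) 0                          -- ignores the jump at 0
  fromInteger c = DF (\_ -> fromInteger c) 0              -- constants: derivative 0

-- The independent variable: the identity, with derivative 1.
x :: (Real n) => DF n
x = DF id 1

-- Hypothetical example: p(t) = t^3 + 2t.
p :: DF Double
p = x * x * x + 2 * x
```

Loaded into GHCi, `map (\k -> f (dfn k p) 3) [0 .. 4]` evaluates to `[33.0,29.0,18.0,6.0,0.0]`. Every intermediate value here is a small integer, so the results come out exact.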
This looks straightforward but is so cool that it is worth dissecting at least one of the above expressions to understand how it is constructed. Let’s take the sum rule first.
The sum rule explained
```haskell
a + b = DF (\x -> f a x + f b x) (df a + df b)
```
The `+` being defined on the LHS is the addition operator for the type `DF x`. This means both `a` and `b` are of type `DF x`.
The first argument of the `DF` constructor is straightforward. It specifies a function that takes a concrete value, passes it down to both `a` and `b`, adds the concrete results and presents the sum as the result of the function.
The second part is more interesting. It is expected to be a differentiable function itself, i.e. of type `DF x`, so we can use any (correct) expression of type `DF x` here. Since we know that \((a+b)' = a' + b'\), and `df a` is basically \(a'\), we can simply write `df a + df b`.
The key point is that the `+` used here is the very `+` we’re attempting to define. Thanks to Haskell’s laziness, this doesn’t result in non-terminating recursion, since the second part of the `DF` construction is evaluated only when it is needed.
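Unrolling the definition by hand shows what laziness buys: each application of `df` exposes exactly one more layer, built from the corresponding layers of `a` and `b`, and nothing deeper is constructed until `df` is applied again.

```latex
\[
\begin{aligned}
\mathtt{df}\,(a + b) &= \mathtt{df}\,a + \mathtt{df}\,b \\
\mathtt{df}\,(\mathtt{df}\,(a + b)) &= \mathtt{df}\,(\mathtt{df}\,a + \mathtt{df}\,b)
  = \mathtt{df}\,(\mathtt{df}\,a) + \mathtt{df}\,(\mathtt{df}\,b) \\
&\;\;\vdots \\
(a + b)^{(n)} &= a^{(n)} + b^{(n)}
\end{aligned}
\]
```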
The product rule explained
```haskell
a * b = DF (\x -> f a x * f b x) (df a * b + a * df b)
```
The product rule is a bit more intricate than the sum rule, but it is essentially the same idea: we’re expressing our knowledge that \((ab)' = a'b + ab'\).
This time, not only are we relying on the `*` currently being defined for `DF x`, we also make use of the `+` operator we just defined.
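Applying the same definition twice recovers the familiar formula for the second derivative of a product; the `+` that combines the terms at each level is the one from the sum rule.

```latex
\[
\begin{aligned}
(ab)'  &= a'b + ab' \\
(ab)'' &= (a'b)' + (ab')' \\
       &= (a''b + a'b') + (a'b' + ab'') \\
       &= a''b + 2a'b' + ab''
\end{aligned}
\]
```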
We want to be able to write ordinary mathematical expressions, so at a minimum we also need the ability to divide.
```haskell
instance (Real n, Fractional n) => Fractional (DF n) where
  -- Add the ability to convert rational numbers to
  -- constant differentiable functions.
  fromRational x = DF (\_ -> fromRational x) 0
  -- Specifying how to calculate the reciprocal automatically
  -- defines the operation of division as a / b = (recip b) * a
  recip a = DF (\x -> 1 / f a x) (- df a / (a * a))
```
Notice again that we’re using just what we already know about how to differentiate functions of a single variable and encoding that knowledge in the derivative part.
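Here is a self-contained sketch that repeats the `Num` instance and adds the `Fractional` instance, so the reciprocal rule can be checked numerically. The functions `g` and `h` are hypothetical examples: for \(g(t) = 1/t\) the derivatives at \(t = 2\) are \(-1/4\), \(1/4\) and \(-3/8\); for \(h(t) = 1/(1 + t^2)\) we have \(h'(t) = -2t/(1+t^2)^2\) and \(h''(t) = (6t^2 - 2)/(1+t^2)^3\), i.e. \(-1/2\) and \(1/2\) at \(t = 1\).

```haskell
data DF x = DF (x -> x) (DF x)

f :: DF x -> x -> x
f (DF fx _) = fx

df :: DF x -> DF x
df (DF _ dfx) = dfx

-- n-th derivative (assumes n >= 0).
dfn :: Int -> DF x -> DF x
dfn 0 = id
dfn n = dfn (n - 1) . df

instance (Real n) => Num (DF n) where
  a + b = DF (\t -> f a t + f b t) (df a + df b)
  a * b = DF (\t -> f a t * f b t) (a * df b + df a * b)
  negate a = DF (negate . f a) (negate (df a))
  abs a = DF (abs . f a) (df a * signum a)
  signum a = DF (signum . f a) 0
  fromInteger c = DF (\_ -> fromInteger c) 0

instance (Real n, Fractional n) => Fractional (DF n) where
  -- Rational literals become constant functions.
  fromRational r = DF (\_ -> fromRational r) 0
  -- (1/a)' = -a' / a^2. Division then comes from the class
  -- default, which is written in terms of recip.
  recip a = DF (\t -> 1 / f a t) (- df a / (a * a))

x :: (Real n) => DF n
x = DF id 1

-- Hypothetical examples: g(t) = 1/t and h(t) = 1/(1 + t^2).
g, h :: DF Double
g = 1 / x
h = 1 / (1 + x * x)
```

In GHCi, `f (df g) 2` should give `-0.25` and `f (dfn 2 h) 1` should give (up to rounding) `0.5`. Note that the derivative part of `recip` itself uses `/`, and therefore `recip`, again; as with the sum rule, laziness keeps this from looping.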
Common scientific functions
We can continue to do the same thing for all the normal scientific functions.
```haskell
instance (Real n, Floating n) => Floating (DF n) where
  pi = DF (\_ -> pi) 0
  exp a = DF (exp . f a) (df a * exp a)
  log a = DF (log . f a) (df a / a)
  sin a = DF (sin . f a) (df a * cos a)
  cos a = DF (cos . f a) (- df a * sin a)
  asin a = DF (asin . f a) (df a / sqrt (1 - a * a))
  acos a = DF (acos . f a) (- df a / sqrt (1 - a * a))
  atan a = DF (atan . f a) (df a / (1 + a * a))
  sinh a = DF (sinh . f a) (df a * cosh a)
  cosh a = DF (cosh . f a) (df a * sinh a)
  asinh a = DF (asinh . f a) (df a / sqrt (1 + a * a))
  acosh a = DF (acosh . f a) (df a / sqrt (a * a - 1))
  -- d/dx atanh u = u' / (1 - u^2): a minus sign, unlike atan
  atanh a = DF (atanh . f a) (df a / (1 - a * a))
```
Note: The definitions for `sin` and `cos` are mutually recursive, as are the definitions for `sinh` and `cosh`.
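Every entry in the `Floating` instance follows one pattern, the chain rule: for an outer function \(g\), the value part is `g . f a` and the derivative part is `df a` multiplied by \(g'\) applied to `a`. Written out, the derivative parts encode the following standard identities. The \(\operatorname{artanh}\) rule differs from the \(\arctan\) rule only in the sign in front of \(a^2\).

```latex
\[
\begin{aligned}
(g \circ a)' &= a' \cdot (g' \circ a) \\[4pt]
(\exp a)' &= a' \exp a & (\log a)' &= \frac{a'}{a} \\
(\sin a)' &= a' \cos a & (\cos a)' &= -a' \sin a \\
(\arcsin a)' &= \frac{a'}{\sqrt{1 - a^2}} & (\arccos a)' &= -\frac{a'}{\sqrt{1 - a^2}} \\
(\arctan a)' &= \frac{a'}{1 + a^2} & (\operatorname{artanh} a)' &= \frac{a'}{1 - a^2} \\
(\sinh a)' &= a' \cosh a & (\cosh a)' &= a' \sinh a \\
(\operatorname{arsinh} a)' &= \frac{a'}{\sqrt{1 + a^2}} & (\operatorname{arcosh} a)' &= \frac{a'}{\sqrt{a^2 - 1}}
\end{aligned}
\]
```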
Getting it all together
For a start, that is basically it. You can load up all of the above code into a Haskell module.
With just that much, you can calculate as many derivatives as you want for an expression like this one:

```haskell
let silly = 2 * x + x * x * sin x
```
You can, for instance, evaluate the 5th derivative of the `silly` function at \(x = 2.34\) using `f (dfn 5 silly) 2.34`.
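As a sanity check, here is a self-contained sketch that collects all three instances (with `atanh` using \(1 - a^2\)) and compares the automatically computed derivatives of `silly` against closed forms worked out by hand. Differentiating \(2x + x^2 \sin x\) once gives \(2 + 2x\sin x + x^2\cos x\), and by the Leibniz rule the fifth derivative is \(x^2\cos x + 10x\sin x - 20\cos x\). The names `silly'` and `silly5` for these closed forms are hypothetical.

```haskell
data DF x = DF (x -> x) (DF x)

f :: DF x -> x -> x
f (DF fx _) = fx

df :: DF x -> DF x
df (DF _ dfx) = dfx

-- n-th derivative (assumes n >= 0).
dfn :: Int -> DF x -> DF x
dfn 0 = id
dfn n = dfn (n - 1) . df

instance (Real n) => Num (DF n) where
  a + b = DF (\t -> f a t + f b t) (df a + df b)
  a * b = DF (\t -> f a t * f b t) (a * df b + df a * b)
  negate a = DF (negate . f a) (negate (df a))
  abs a = DF (abs . f a) (df a * signum a)
  signum a = DF (signum . f a) 0
  fromInteger c = DF (\_ -> fromInteger c) 0

instance (Real n, Fractional n) => Fractional (DF n) where
  fromRational r = DF (\_ -> fromRational r) 0
  recip a = DF (\t -> 1 / f a t) (- df a / (a * a))

-- sqrt, tan, tanh and (**) fall back to the class defaults,
-- which are written in terms of exp, log, sin, cos, sinh, cosh.
instance (Real n, Floating n) => Floating (DF n) where
  pi = DF (\_ -> pi) 0
  exp a = DF (exp . f a) (df a * exp a)
  log a = DF (log . f a) (df a / a)
  sin a = DF (sin . f a) (df a * cos a)
  cos a = DF (cos . f a) (- df a * sin a)
  asin a = DF (asin . f a) (df a / sqrt (1 - a * a))
  acos a = DF (acos . f a) (- df a / sqrt (1 - a * a))
  atan a = DF (atan . f a) (df a / (1 + a * a))
  sinh a = DF (sinh . f a) (df a * cosh a)
  cosh a = DF (cosh . f a) (df a * sinh a)
  asinh a = DF (asinh . f a) (df a / sqrt (1 + a * a))
  acosh a = DF (acosh . f a) (df a / sqrt (a * a - 1))
  atanh a = DF (atanh . f a) (df a / (1 - a * a))

x :: (Real n) => DF n
x = DF id 1

-- The example from the text, fixed to Double.
silly :: DF Double
silly = 2 * x + x * x * sin x

-- Hypothetical closed forms, derived by hand, for comparison.
silly' :: Double -> Double
silly' t = 2 + 2 * t * sin t + t * t * cos t

silly5 :: Double -> Double
silly5 t = t * t * cos t + 10 * t * sin t - 20 * cos t
```

Comparing `f (dfn 5 silly) 2.34` with `silly5 2.34` should show agreement up to floating-point rounding. The same module also confirms the `atanh` rule: `f (df (atanh x)) 0.5` is \(1/(1 - 0.25) = 4/3\).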
Challenge: Try to express piecewise functions using this mechanism. By that, I mean that the expression used to calculate the function for \(x < x_0\) is different from the one used for \(x_0 \le x < x_1\), yet another is used for \(x_1 \le x < x_2\), and so on.
Next up: Differentiable functions of multiple variables.