Automatic Differentiation: Higher ranked beings

Mar 13, 2019   #FP  #Haskell 

(Status: Draft)

In the first and second posts on automatic differentiation, we saw how we can calculate derivatives of a function while evaluating the function itself. Those posts dealt only with functions of a single real number. While that is illustrative of the approach, it was merely a warm-up so we can deal with differentiable functions of multiple variables that yield multiple values - i.e. vector valued functions of vectors … or, more generally, tensor valued functions of tensors.

To start with, we can treat a vector space (of whatever rank) as a “collection” type which uses indices to refer to the components of the object along whatever chosen basis. To make it concrete, we could treat them as functions from indices to a field (say, real numbers). So a type expression for these might be -

type Vec i x = i -> x

We’re in general free to use any representable thing which has the same kind of “mapping indices to a field” characteristic, but we’ll live with the simplest of these - plain functions - for now until we complete a first picture. After we do that, we can ruminate on other implementations.
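
To see what this representation buys us, here is a quick sketch - a 3-component vector and a matrix (a rank-2 “vector” indexed by pairs), both as plain functions. The names v3 and m22 are just for illustration.

-- A vector with components 1, 2, 3 (and 0 everywhere else).
v3 :: Vec Int Double
v3 0 = 1.0
v3 1 = 2.0
v3 2 = 3.0
v3 _ = 0.0

-- A 2x2 identity matrix, addressed by (row, column) pairs.
m22 :: Vec (Int, Int) Double
m22 (r, c) = if r == c then 1.0 else 0.0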

Reminder: The purpose of these AD posts is to present a running commentary of attempts to understand the structure of applying AD to functions. Efficient algorithms will get attention last .. if I feel like it.

Similar to functions of a single variable, we can define a “differentiable vector function” type as -

data DF i j x = DF (Vec i x -> Vec j x) (i -> DF i j x)

The first part of the DF is clear enough - it is a function that takes a vector and calculates another vector. We’ve kept the idea of an “index” loose enough that these Vec types cover higher ranked vectors - i.e. tensors - as well.

The second part is a bit different because we now need to specify the dimension along which we’re differentiating in order to get a derivative. An alternative would be to encode a function that computes the “directional derivative” - which is the value of the derivative along a given vector argument. I wanted to explore this direction first, so directional derivatives will have to wait :)

Note: The full module’s code is available on github.

The usual accessors are easy -

f (DF f0 _) = f0
df i (DF _ df0) = df0 i

An alternative way of looking at the derivative is as an operation that increases the rank of the beast being differentiated. This is still easy enough to express -

dft :: (Eq i, Eq j) => DF i j x -> DF i (j,i) x
dft a = DF f' df'
    where 
        f' x (j,i) = f (df i a) x j
        df' i = dft (df i a)

Notice that the “output index” of the resulting function is a compound one, indicating a higher rank.
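
To make the rank bookkeeping concrete, here is a sketch of the types involved, picking Int indices and a unit output index purely for illustration -

-- For a scalar field g :: DF Int () Double, the gradient is
--   dft g :: DF Int ((), Int) Double
-- and applying dft once more yields the Hessian -
--   dft (dft g) :: DF Int (((), Int), Int) Double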

We’ll need to be able to address individual components of the vectors we’ll be supplying to our functions - similar to how we made x available. We can define a function to help us with that -

dirac :: (Eq i, Num x) => i -> i -> x
dirac i j = if i == j then 1 else 0

v :: (Eq i, Eq j, Real x) => i -> DF i j x
v i = DF (\x _ -> x i) (dirac i)

Basic arithmetic

With just this much, the basic arithmetic of numbers and scalar function applications can be satisfied in a straightforward manner.

-- Basic arithmetic operations
instance (Eq i, Eq j, Real n) => Num (DF i j n) where
    a + b = DF (\x i -> f a x i + f b x i) (\i -> df i a + df i b)
    a * b = DF (\x i -> f a x i * f b x i) (\i -> a * df i b + df i a * b)
    negate a = DF (\x i -> negate (f a x i)) (\i -> negate (df i a))
    abs a = DF (\x i -> abs (f a x i)) (\i -> df i a * signum a)
    signum a = DF (\x i -> signum (f a x i)) (\i -> 0)
    fromInteger x = DF (\_ _ -> fromInteger x) (\_ -> 0)
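
As a quick sanity check of the product rule encoded above, here is a tiny scalar field built from v - the names g and x are mine, and the example assumes Int input indices and a unit output index -

-- g(x) = x_0 * x_1
g :: DF Int () Double
g = v 0 * v 1

-- With x i = fromIntegral i + 2.0, i.e. x_0 = 2 and x_1 = 3:
--   f g x ()        == 6.0
--   f (df 0 g) x () == 3.0   -- dg/dx_0 = x_1
--   f (df 1 g) x () == 2.0   -- dg/dx_1 = x_0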

-- Reciprocal
instance (Real n, Fractional n, Eq i, Eq j) => Fractional (DF i j n) where
    fromRational x = DF (\_ _ -> fromRational x) (\_ -> 0)
    recip a = DF (\x i -> 1 / f a x i) (\i -> - df i a / (a * a))   

-- Scientific functions
instance (Real n, Floating n, Eq i, Eq j) => Floating (DF i j n) where
    pi = DF (\_ _ -> pi) (\_ -> 0)
    exp a = DF (\x i -> exp (f a x i)) (\i -> df i a * exp a)
    log a = DF (\x i -> log (f a x i)) (\i -> df i a / a)
    sin a = DF (\x i -> sin (f a x i)) (\i -> df i a * cos a)
    cos a = DF (\x i -> cos (f a x i)) (\i -> - df i a * sin a)
    asin a = DF (\x i -> asin (f a x i)) (\i -> df i a / sqrt (1 - a * a))
    acos a = DF (\x i -> acos (f a x i)) (\i -> - df i a / sqrt (1 - a * a))
    atan a = DF (\x i -> atan (f a x i)) (\i -> df i a / (1 + a * a))
    sinh a = DF (\x i -> sinh (f a x i)) (\i -> df i a * cosh a)
    cosh a = DF (\x i -> cosh (f a x i)) (\i -> df i a * sinh a)
    asinh a = DF (\x i -> asinh (f a x i)) (\i -> df i a / sqrt (1 + a * a))
    acosh a = DF (\x i -> acosh (f a x i)) (\i -> df i a / sqrt (a * a - 1))
    atanh a = DF (\x i -> atanh (f a x i)) (\i -> df i a / (1 - a * a))
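
Before moving on, the examples below will want the identity function as a DF. A minimal sketch (idv is my name for it; the constant zero in the derivative leans on fromInteger from the Num instance above) -

-- The identity as a differentiable vector function - output component j
-- is just input component j. Its derivative along i is the constant
-- dirac row, whose own derivatives all vanish.
idv :: (Eq i, Real x) => DF i i x
idv = DF (\x j -> x j) d
    where
        d i = DF (\_ j -> dirac i j) (\_ -> 0)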

Code branches

-- Using if-then-else directly is not so straightforward, since it needs
-- vectors to be reduced to scalars first. Instead, we can parameterize
-- conditionals using a "region" function. While the region function takes
-- two arguments, it is free to ignore either if it so chooses.
cond :: (Vec i x -> j -> Bool) -> DF i j x -> DF i j x -> DF i j x
cond region a b = DF f' df'
    where
        f' x j = if region x j then f a x j else f b x j
        df' i = cond region (df i a) (df i b)
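
For instance, here is |x_0| as a piecewise scalar field - absX0 is a name I made up. The region function inspects the raw input vector, and differentiation simply follows whichever branch it picks (so the kink at x_0 = 0 resolves to the then-branch) -

absX0 :: DF Int () Double
absX0 = cond (\x _ -> x 0 >= 0) (v 0) (negate (v 0))
-- f (df 0 absX0) x () == if x 0 >= 0 then 1.0 else -1.0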

Outer product

-- There are many kinds of products we can create with vectors.  The outer
-- product increases the rank of the vectors and is a useful operation before
-- many kinds of reductions.
outer :: (Eq i, Eq j, Eq k, Real x) => DF i j x -> DF i k x -> DF i (j,k) x
outer a b = DF f' df'
    where
        f' x (j,k) = f a x j * f b x k
        df' i = outer (df i a) b + outer a (df i b)
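
A small sketch using idv from earlier - the outer product of the input with itself, a rank-2 field whose (j,k) component is x_j * x_k -

xxT :: DF Int (Int, Int) Double
xxT = outer idv idv
-- f xxT x (j,k) == x j * x k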

Inner product

-- The inner product usually ends up reducing the rank of the input vectors by
-- summing over a part of the index space. To generalize this idea, we just
-- parameterize the index range of the summation into an enumeration function
-- named `dot` in the argument.
inner :: (Eq i, Eq j, Eq k, Eq l, Real x) => (l -> Int -> Maybe (j,k)) -> DF i j x -> DF i k x -> DF i l x
inner dot a b = DF f' df'
    where
        f' x l = sum' x (dot l) 0 0
        sum' x enum ix result = case enum ix of
            Nothing -> result
            Just (j,k) -> sum' x enum (ix+1) (result + f a x j * f b x k)
        df' i = inner dot (df i a) b + inner dot a (df i b)
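
Here is a hypothetical dot product over the first three components - the enumeration function dot3 walks ix = 0, 1, 2, pairing up matching indices -

dot3 :: () -> Int -> Maybe (Int, Int)
dot3 () ix = if ix < 3 then Just (ix, ix) else Nothing

-- Sum of squares x_0^2 + x_1^2 + x_2^2 as a scalar field.
sumSq :: DF Int () Double
sumSq = inner dot3 idv idv
-- f (df 0 sumSq) x () == 2 * x 0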

Collapsing a vector

-- An inner product between two vectors can be computed as a summation
-- reduction of the outer product of the vectors. Here we code up "collapse",
-- which does such a summation reduction. Like `inner`, `collapse` also reduces
-- the rank of the input.
collapse :: (Eq i, Eq j, Eq k, Real x) => (k -> Int -> Maybe j) -> DF i j x -> DF i k x
collapse dot a = DF f' df'
    where
        f' x k = sum' x (dot k) 0 0
        sum' x enum ix result = case enum ix of
            Nothing -> result
            Just j -> sum' x enum (ix+1) (result + f a x j)
        df' i = collapse dot (df i a)
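
As a sketch of that claim, here is the same sum of squares from above, this time expressed as a collapse of an outer product -

sumSq' :: DF Int () Double
sumSq' = collapse diag (outer idv idv)
    where
        diag () ix = if ix < 3 then Just (ix, ix) else Nothing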

Chain rule

-- Vector function composition. The derivative is expressed using the chain rule.
chain :: (Eq i, Eq j, Eq k, Real x) => (Int -> Maybe j) -> DF j k x -> DF i j x -> DF i k x
chain js a b = DF f' df'
    where
        f' x = f a (f b x)
        df' i = inner (dot i) (chain js (dft a) b) (dft b)
        dot i k ix = case js ix of
            Nothing -> Nothing
            Just j -> Just ((k,j),(j,i))
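
A hypothetical composition to exercise this - h_j(x) = exp (x_j * x_j), wired through the chain rule, where js3 restricts the summation to the first three intermediate components -

js3 :: Int -> Maybe Int
js3 ix = if ix < 3 then Just ix else Nothing

h :: DF Int Int Double
h = chain js3 (exp idv) (idv * idv)
-- f h x j          == exp (x j * x j)
-- f (df 0 h) x 0   == 2 * x 0 * exp (x 0 * x 0)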

Slicing, dicing and reshaping

type Slice j i = i -> Maybe j

-- A utility to take a slice of a vector. For simplicity, we model the
-- slice as a partial mapping from the new index space back into the
-- original one - components with no source index simply read as zero.
slice :: (Eq i, Eq j, Eq k, Real x) => Slice j k -> DF i j x -> DF i k x
slice s a = DF f' df'
    where
        f' x k = case s k of
            Nothing -> 0
            Just j -> f a x j
        df' i = slice s (df i a)

stride :: Num i => i -> i -> Slice i i
stride start step i = Just (start + i * step)

range :: Ord i => i -> i -> Slice i i
range lo hi i = if i >= lo && i <= hi then Just i else Nothing
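
For instance, assuming idv from earlier, slicing can pick out every other component, or a contiguous window (everything outside the slice reads as zero) -

evens :: DF Int Int Double
evens = slice (stride 0 2) idv     -- components 0, 2, 4, ...

window :: DF Int Int Double
window = slice (range 2 5) idv     -- components 2 through 5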

-- Sometimes, it is also useful to be able to change the shape of
-- a vector .. which basically means we change the way its dimensions
-- are addressed.
reshape :: (Eq i, Eq j, Eq k, Real x) => (k -> j) -> DF i j x -> DF i k x
reshape shaper a = DF f' df'
    where
        f' x k = f a x (shaper k)
        df' i = reshape shaper (df i a)
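
For example, a flat 9-component vector can be viewed as a 3x3 matrix just by re-addressing its components in row-major order - asMat3 is an illustrative name -

asMat3 :: DF Int Int Double -> DF Int (Int, Int) Double
asMat3 = reshape (\(r, c) -> 3 * r + c)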

Simple convolutions

-- A dead simple notion of convolution - a reduction operation that,
-- unlike the inner product, does not reduce rank.
conv :: (Eq i, Eq j, Eq k, Real x) => (k -> [(j,k)]) -> DF i j x -> DF i k x -> DF i k x
conv stride kernel a = DF f' df'
    where
        f' x k = sum (map (\(j,k') -> f kernel x j * f a x k') (stride k))
        df' i = conv stride (df i kernel) a + conv stride kernel (df i a)
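
A hypothetical 1-D, 3-tap convolution as a sketch - output component k pairs kernel components 0, 1, 2 with input components k, k+1, k+2 -

taps3 :: Int -> [(Int, Int)]
taps3 k = [(j, k + j) | j <- [0 .. 2]]

-- Usage: conv taps3 kernel a, where kernel and a share the input index type.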

Thoughts

While most of the numeric calculation aspects seem fine, the way we had to define the inner product and convolution leaves much to be desired. I hope to explore better representations, via abstraction, in these areas.