There are many excellent AI papers and tutorials that explain the attention pattern in Large Language Models. But this essentially simple pattern is often obscured by implementation details and optimizations. In this post I will try to cut to the essentials.

In a nutshell, the attention machinery tries to get at the meaning of a word (more precisely, a token). This should be easy in principle: we could just look it up in the dictionary. For instance, the word “triskaidekaphobia” means “extreme superstition regarding the number thirteen.” Simple enough. But consider the question: What does “it” mean in the sentence “look it up in the dictionary”? You could look up the word “it” in the dictionary, but that wouldn’t help much. More often than not, we guess the meaning of words from their context, sometimes based on a whole conversation.

The attention mechanism is a way to train a language model to derive the meaning of a word from its context.

The first step is the rough encoding of words (tokens) as multi-dimensional vectors. We’re talking about 12,288 dimensions for GPT-3. That’s an enormous semantic space. It means that you can adjust the meaning of a word by nudging it in thousands of different directions by varying amounts (usually 32-bit floating-point numbers). This “nudging” is like adding little footnotes to a word, explaining what its precise meaning was in a particular situation.

(Note: In practice, working with such huge vectors would be prohibitively expensive, so they are split between multiple heads of attention and processed in parallel. For instance, GPT-3 has 96 heads of attention, and each head works within a 128-dimensional vector space.)

We are embedding the input word as a 12,288-dimensional vector \vec E, and we are embedding N words of context as 12,288-dimensional vectors, \vec E_i, i = 1..N (for GPT-3, N=2048 tokens). Initially, the mapping from words to embeddings is purely random, but as the training progresses, it is gradually refined using backpropagation.

(Note: In most descriptions, you’ll see a whole batch of words being embedded all at once, and the context often being identical to that same batch, but this is just an optimization.)

The goal is to refine the embedding \vec E by adding to it a small delta \Delta \vec E that is derived from the context \vec E_i.

This is usually described as the vector \vec E querying the context, and the context responding to this query.

First, we apply a gigantic trainable 12,288 by 12,288 matrix W_Q to the embedding \vec E, to get the query vector:

\vec Q = W_Q \vec E

You may think of \vec Q as the question: “Who am I with respect to the Universe?” The entries in the matrix W_Q are the weights that are learned by running the usual backpropagation.

We then apply another trainable matrix W_K to every vector \vec E_i of the context to get a batch of key vectors:

\vec K_i = W_K \vec E_i

You may think of \vec K_i as a possible response from the i-th component of the context to all kinds of questions.

The next step is crucial: We calculate how relevant each element of the context is to our word. In other words, we are focusing our attention on particular elements of the context. We do it by taking N scalar products \vec K_i \cdot \vec Q.

A scalar product can vary widely. A large positive number means that the i-th element of the context is highly relevant to the meaning of \vec E. A large negative number, on the other hand, means that it has very little to contribute.

What we really need is to normalize these numbers to a range between zero (not relevant) and one (extremely relevant) and make sure that they all add up to one. This is normally done using the softmax procedure: we first raise e to the power of the given number to make the result non-negative:

a_i = \exp(\vec K_i \cdot \vec Q)

We then divide it by the total sum, to normalize it:

A_i = \dfrac{a_i}{\sum_{j=1}^N a_j}

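These are our attention weights. (Note: Before performing the softmax, we divide all the scalar products by \sqrt{d_k}, the square root of the dimension of the key vectors. This scaling keeps the scores from growing with the dimension, which would otherwise push the softmax into saturation, where its gradients become vanishingly small.)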
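The softmax step and the \sqrt{d_k} scaling fit in a few lines. This is a minimal sketch, not taken from any library. Vectors are plain Haskell lists, and the names `dot`, `softmax` and `attentionWeights` are made up for illustration. It also subtracts the largest score before exponentiating. That is a standard trick: the shift cancels between numerator and denominator, but it keeps `exp` from overflowing.

```haskell
type Vec = [Double]

-- Scalar product of two vectors.
dot :: Vec -> Vec -> Double
dot u v = sum (zipWith (*) u v)

-- Softmax: exponentiate, then divide by the total.
-- Shifting by the maximum cancels between numerator and denominator,
-- so the result is unchanged, but exp cannot overflow.
softmax :: [Double] -> [Double]
softmax xs = map (/ total) es
  where
    mx    = maximum xs                    -- assumes a non-empty context
    es    = map (\x -> exp (x - mx)) xs
    total = sum es

-- Attention weights A_i for one query Q and keys K_1..K_N,
-- with every score K_i . Q scaled by 1 / sqrt d_k.
attentionWeights :: Vec -> [Vec] -> [Double]
attentionWeights q ks = softmax [dot k q / sqrt dk | k <- ks]
  where
    dk = fromIntegral (length q)
```

In GHCi, `attentionWeights [1, 0] [[1, 0], [0, 1], [-1, 0]]` gives the largest weight to the first key, which points in the same direction as the query.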

Attention weights tell us how much each element of the context can contribute to the meaning of \vec E, but we still don’t know what it contributes. We figure this out by multiplying each element of the context by yet another trainable matrix, W_V. The result is a batch of value vectors \vec V_i:

\vec V_i = W_V \vec E_i

We now accumulate all these contributions, weighting them by their attention weights A_i. Then we add the sum to the original embedding \vec E to get its adjusted meaning (this addition is called the residual connection):

\vec E' = \vec E + \sum_{i = 1}^N A_i \vec V_i
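Putting the pieces together, here is a sketch of one attention step for a single embedding. Lists stand in for vectors, and matrices are stored as lists of rows. The function `attend` and its helpers are hypothetical names. The value matrix is assumed to map back into the embedding space, so that the residual addition makes sense.

```haskell
type Vec = [Double]
type Mat = [[Double]]   -- a matrix as a list of rows

matVec :: Mat -> Vec -> Vec
matVec m v = [sum (zipWith (*) row v) | row <- m]

dot :: Vec -> Vec -> Double
dot u v = sum (zipWith (*) u v)

-- Numerically stable softmax (shift by the maximum, then normalize).
softmax :: [Double] -> [Double]
softmax xs = map (/ total) es
  where
    mx    = maximum xs
    es    = map (\x -> exp (x - mx)) xs
    total = sum es

-- One attention step for embedding e against a non-empty context ctx:
--   Q = W_Q E,  K_i = W_K E_i,  V_i = W_V E_i,
--   A_i = softmax_i (K_i . Q / sqrt d_k),  E' = E + sum_i A_i V_i
attend :: Mat -> Mat -> Mat -> Vec -> [Vec] -> Vec
attend wq wk wv e ctx = zipWith (+) e delta
  where
    q     = matVec wq e
    ks    = map (matVec wk) ctx
    vs    = map (matVec wv) ctx
    dk    = fromIntegral (length q)
    ws    = softmax [dot k q / sqrt dk | k <- ks]
    delta = foldr1 (zipWith (+)) (zipWith (\a v -> map (a *) v) ws vs)
```

With identity matrices and a context made of two copies of [1, 0], both weights are 0.5, so the delta is [1, 0] and the embedding [1, 0] becomes [2, 0].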

The result \vec E' is infused with the additional information gained from a larger context. It’s closer to the actual meaning. For instance, the word “it” would be nudged towards the noun that it stands for.

And that’s essentially the basic block of the attention system. The rest is just optimization and composition.

One major optimization is gained by processing a whole window of tokens at once (2048 tokens for GPT-3). In particular, in the self-attention pattern we use the same batch of tokens for both the input and the context. In general, though, the context can be distinct from the input. It could, for instance, be a sentence in a different language.

Another optimization is the partitioning of all three vectors, \vec Q, \vec K_i, and \vec V_i, between the heads of attention. Each of these heads operates inside a smaller subspace of the vector space. For GPT-3 these are 128-dimensional subspaces. The heads produce lower-dimensional deltas, which are concatenated into full-size deltas \Delta \vec E. These are then added to the original embeddings through the residual connection.
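The partitioning can be sketched as slicing. The full-size \vec Q, \vec K_i and \vec V_i are cut into equal slices, each head runs the same attention computation on its own slice, and the per-head deltas are concatenated. The sketch assumes two things: the vectors have already been projected by W_Q, W_K and W_V, and the number of heads divides the dimension (96 × 128 = 12,288 for GPT-3). The names `splitHeads`, `headAttend` and `multiHead` are illustrative.

```haskell
import Data.List (transpose)

type Vec = [Double]

dot :: Vec -> Vec -> Double
dot u v = sum (zipWith (*) u v)

softmax :: [Double] -> [Double]
softmax xs = map (/ total) es
  where
    mx    = maximum xs
    es    = map (\x -> exp (x - mx)) xs
    total = sum es

-- Split a vector into h equal slices, one per head.
-- Assumes h > 0 and that h divides the dimension.
splitHeads :: Int -> Vec -> [Vec]
splitHeads h xs
  | h <= 0 || length xs `mod` h /= 0 = error "splitHeads: h must divide the dimension"
  | otherwise = go xs
  where
    size = length xs `div` h
    go [] = []
    go ys = let (slice, rest) = splitAt size ys in slice : go rest

-- Attention inside one head, on already-projected slices of Q, K_i, V_i.
headAttend :: Vec -> [Vec] -> [Vec] -> Vec
headAttend q ks vs = foldr1 (zipWith (+)) (zipWith (map . (*)) ws vs)
  where
    dk = fromIntegral (length q)
    ws = softmax [dot k q / sqrt dk | k <- ks]

-- Multi-head: slice Q and every K_i and V_i, run each head on its own
-- slice, and concatenate the per-head deltas into one full-size vector.
multiHead :: Int -> Vec -> [Vec] -> [Vec] -> Vec
multiHead h q ks vs = concat (zipWith3 headAttend qs ksByHead vsByHead)
  where
    qs       = splitHeads h q
    ksByHead = transpose (map (splitHeads h) ks)   -- indexed [head][context]
    vsByHead = transpose (map (splitHeads h) vs)
```

Because each head computes its own softmax, different heads can attend to different context elements. For example, in a two-head setup the first head may pick the first key while the second head picks the second.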

In GPT-3, each multi-headed attention block is followed by a multi-layer perceptron (MLP), and this transformation is repeated 96 times. Most steps in an LLM are just linear algebra, except for the softmax and the activation function in the MLP, which are non-linear.

All this work is done just to produce the next word in a sentence.


Introduction

Neural networks are an example of composable systems, so it’s no surprise that they can be modeled in category theory, which is the ultimate science of composition. Moreover, the categorical ideas behind neural networks can be immediately implemented and tested in a programming language. In this post I will present the Haskell implementation of parametric lenses, generalize them to pre-lenses and introduce their profunctor representation. Using the profunctor representation I will build a working multi-layer perceptron.

In the second part of this post I will introduce the bicategory \mathbf{PreLens} of pre-lenses and the bicategory of triple Tambara profunctors, and show how the latter relate to pre-lenses.

The complete Haskell implementation is available on GitHub, where you can also find the PDF version of this post, complete with the categorical picture.

Haskell Implementation

Every component of a neural network can be thought of as a system that transforms input to output, and whose action depends on some parameters. In the language of neural networks, this is called the forward pass. It takes a bunch of parameters p, combines them with the input s, and produces the output a. It can be described by a Haskell function:

fwd :: (p, s) -> a

But the real power of neural networks is in their ability to learn from mistakes. If we don’t like the output of the network, we can nudge it towards a better solution. If we want to nudge the output by some da, what change dp to the parameters should we make? The backward pass partitions the blame for the perceived error in direct proportion to the impact each parameter had on the result.

Because neural networks are composed of layers of neurons, each with their own sets of parameters, we might also ask the question: What change ds to this layer’s inputs (which are the outputs of the previous layer) should we make to improve the result? We could then back-propagate this information to the previous layer and let it adjust its own parameters. The backward pass can be described by another Haskell function:

bwd :: (p, s, da) -> (dp, ds)

The combination of these two functions forms a parametric lens:

data PLens a da p dp s ds = 
  PLens { fwd :: (p, s) -> a
        , bwd :: (p, s, da) -> (dp, ds) }
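To make the two passes concrete, here is a hypothetical parametric lens for a single weight that multiplies its input, a = p * s (the name `scaleL` is made up). It is followed by one gradient-descent step. The step assumes the loss (a - target)^2 / 2, so the error fed into the backward pass is da = a - target.

```haskell
data PLens a da p dp s ds =
  PLens { fwd :: (p, s) -> a
        , bwd :: (p, s, da) -> (dp, ds) }

-- Hypothetical example: a single weight p that scales its input, a = p * s.
-- The backward pass multiplies da by the partial derivatives
-- da/dp = s and da/ds = p.
scaleL :: PLens Double Double Double Double Double Double
scaleL = PLens
  { fwd = \(p, s) -> p * s
  , bwd = \(p, s, da) -> (da * s, da * p)
  }

-- One gradient-descent step on the parameter, assuming the loss
-- (a - target)^2 / 2, whose derivative with respect to a is a - target.
step :: Double -> Double -> Double -> Double -> Double
step lr target p s = p - lr * dp
  where
    a       = fwd scaleL (p, s)
    (dp, _) = bwd scaleL (p, s, a - target)
```

With p = 3, s = 2 and a target of 10, the output 6 is too small. The step therefore increases p, to 3.8 with a learning rate of 0.1.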

In this representation it’s not immediately obvious how to compose parametric lenses, so I’m going to present a variety of other representations that may be more convenient in some applications.

Existential Parametric Lens

Notice that the backward pass re-uses the arguments (p, s) of the forward pass. Although some information from the forward pass is needed for the backward pass, it’s not always clear that all of it is required. It makes more sense for the forward pass to produce some kind of a care package to be delivered to the backward pass. In the simplest case, this package would just be the pair (p, s). But from the perspective of the user of the lens, the exact type of this package is an internal implementation detail, so we might as well hide it as an existential type m. We thus arrive at a more symmetric representation:

data ExLens a da p dp s ds = 
  forall m . ExLens ((p, s)  -> (m, a))  
                    ((m, da) -> (dp, ds))

The type m is often called the residue of the lens.

These existential lenses can be composed in series. The result of the composition is parameterized by the product (a tuple) of the original parameters. We’ll see it more clearly in the next section.

But since the product of types is associative only up to isomorphism, the composition of parametric lenses is associative only up to isomorphism.
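Here is a sketch of serial composition for existential lenses. The names `composeEx`, `runEx` and the example `scaleEx` are illustrative, not necessarily those of the GitHub code. The composite's residue is the pair of the component residues, so it stays hidden inside the new `ExLens`. The parameters are tupled in the order (q, p) used by `preCompose` later in the post. The backward passes run in the opposite order from the forward passes.

```haskell
{-# LANGUAGE ExistentialQuantification #-}

data ExLens a da p dp s ds =
  forall m . ExLens ((p, s)  -> (m, a))
                    ((m, da) -> (dp, ds))

-- Serial composition (a sketch): the first lens maps s to b, the second b to a.
-- Parameters are tupled; the hidden residues are tupled too and stay hidden.
composeEx :: ExLens b db p dp s ds
          -> ExLens a da q dq b db
          -> ExLens a da (q, p) (dq, dp) s ds
composeEx (ExLens f1 g1) (ExLens f2 g2) = ExLens fw bw
  where
    fw ((q, p), s) =
      let (m, b) = f1 (p, s)
          (n, a) = f2 (q, b)
      in ((m, n), a)
    bw ((m, n), da) =
      let (dq, db) = g2 (n, da)   -- backward passes run in reverse order
          (dp, ds) = g1 (m, db)
      in ((dq, dp), ds)

-- Run the forward pass, then the backward pass with the given da.
runEx :: ExLens a da p dp s ds -> (p, s) -> da -> (a, (dp, ds))
runEx (ExLens f g) ps da = let (m, a) = f ps in (a, g (m, da))

-- Hypothetical example lens: a = p * s, keeping (p, s) as the residue.
scaleEx :: ExLens Double Double Double Double Double Double
scaleEx = ExLens (\(p, s) -> ((p, s), p * s))
                 (\((p, s), da) -> (da * s, da * p))
```

Composing `scaleEx` with itself computes a = q * p * s. At ((2, 3), 5) it returns 30, and with da = 1 the backward pass yields the partial derivatives ((15, 10), 6).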

There is also an identity lens:

identityLens :: ExLens a da () () a da
identityLens = ExLens id id

but, again, the categorical identity laws are satisfied only up to isomorphism. This is why parametric lenses cannot be interpreted as hom-sets in a traditional category. Instead they are part of a bicategory that arises from the \mathbf{Para} construction.

Pre-Lenses

Notice that there is still an asymmetry in the treatment of the parameters and the residues. The parameters are accumulated (tupled) during composition, while the residues are traced over (categorically, an existential type is described by a coend, which is a generalized trace). There is no reason why we shouldn’t accumulate the residues during composition and postpone the taking of the trace until the very end.

We thus arrive at a fully symmetrical definition of a pre-lens:

data PreLens a da m dm p dp s ds =
  PreLens ((p, s)   -> (m, a))
          ((dm, da) -> (dp, ds))

We now have two separate types: m describing the residue, and dm describing the change of the residue.


If all we need at the end is to trace over the residues, we’ll identify the two types.

Notice that the role of parameters and residues is reversed between the forward and the backward pass. The forward pass, given the parameters and the input, produces the output plus the residue. The backward pass answers the question: How should we nudge the parameters and the inputs (dp, ds) if we want the residues and the outputs to change by (dm, da). In neural networks this will be calculated using gradient descent.

The composition of pre-lenses accumulates both the parameters and the residues into tuples:

preCompose ::
    PreLens a' da' m dm p dp s ds -> 
    PreLens a da n dn q dq a' da' ->
    PreLens a da (m, n) (dm, dn) (q, p) (dq, dp) s ds
preCompose (PreLens f1 g1) (PreLens f2 g2) = PreLens f3 g3
  where
    f3 = unAssoc . second f2 . assoc . first sym . 
         unAssoc . second f1 . assoc
    g3 = unAssoc . second g1 . assoc . first sym . 
         unAssoc . second g2 . assoc

We use associators and symmetrizers to rearrange various tuples. Notice the separation of forward and backward passes. In particular, the backward pass of the composite lens depends only on backward passes of the composed lenses.
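The post does not show `assoc`, `unAssoc` and `sym`. The block below supplies definitions that are an assumption, chosen so that `preCompose` type-checks as written. It also adds a small hypothetical test lens, `scaleP` (a = p * s), so the composite can be checked by hand. The comments list the intermediate tuple shapes.

```haskell
import Data.Bifunctor (first, second)
import Data.Tuple (swap)

data PreLens a da m dm p dp s ds =
  PreLens ((p, s)   -> (m, a))
          ((dm, da) -> (dp, ds))

-- Tuple rearrangements used by preCompose (assumed definitions).
assoc :: ((a, b), c) -> (a, (b, c))
assoc ((a, b), c) = (a, (b, c))

unAssoc :: (a, (b, c)) -> ((a, b), c)
unAssoc (a, (b, c)) = ((a, b), c)

sym :: (a, b) -> (b, a)
sym = swap

preCompose ::
    PreLens a' da' m dm p dp s ds ->
    PreLens a da n dn q dq a' da' ->
    PreLens a da (m, n) (dm, dn) (q, p) (dq, dp) s ds
preCompose (PreLens f1 g1) (PreLens f2 g2) = PreLens f3 g3
  where
    -- ((q, p), s) -> (q, (m, a')) -> ((m, q), a') -> (m, (n, a)) -> ((m, n), a)
    f3 = unAssoc . second f2 . assoc . first sym .
         unAssoc . second f1 . assoc
    -- ((dm, dn), da) -> (dm, (dq, da')) -> ((dq, dm), da') -> (dq, (dp, ds)) -> ((dq, dp), ds)
    g3 = unAssoc . second g1 . assoc . first sym .
         unAssoc . second g2 . assoc

-- Accessors for the two passes.
runFw :: PreLens a da m dm p dp s ds -> (p, s) -> (m, a)
runFw (PreLens f _) = f

runBw :: PreLens a da m dm p dp s ds -> (dm, da) -> (dp, ds)
runBw (PreLens _ g) = g

-- Hypothetical test lens: a = p * s, with residue (p, s).
scaleP :: PreLens Double Double (Double, Double) (Double, Double) Double Double Double Double
scaleP = PreLens (\(p, s) -> ((p, s), p * s))
                 (\((p, s), da) -> (da * s, da * p))
```

Composing `scaleP` with itself gives a = q * p * s. At da = 1 the backward pass returns (p * s, q * s) for the parameters (q, p) and q * p for the input.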

There is also an identity pre-lens:

idPreLens :: PreLens a da () () () () a da
idPreLens = PreLens id id

Pre-lenses thus form a bicategory that combines the \mathbf{Para} and the \mathbf{coPara} constructions in one.

There is also a monoidal structure in this category induced by parallel composition. In parallel composition we tuple the respective inputs and outputs, as well as the parameters and residues, both in the forward and the backward passes.
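Parallel composition is not spelled out in the post. Here is one way to write it for `PreLens`, as a sketch. The name `preParallel` and the exact tuple layout are assumptions, and `scaleP` and `biasP` are hypothetical test lenses.

```haskell
data PreLens a da m dm p dp s ds =
  PreLens ((p, s)   -> (m, a))
          ((dm, da) -> (dp, ds))

-- Parallel (monoidal) product: inputs, outputs, parameters and residues
-- are all tupled, in both the forward and the backward pass.
preParallel :: PreLens a da m dm p dp s ds
            -> PreLens a' da' m' dm' p' dp' s' ds'
            -> PreLens (a, a') (da, da') (m, m') (dm, dm')
                       (p, p') (dp, dp') (s, s') (ds, ds')
preParallel (PreLens f1 g1) (PreLens f2 g2) = PreLens fw bw
  where
    fw ((p, p'), (s, s')) =
      let (m, a)   = f1 (p, s)
          (m', a') = f2 (p', s')
      in ((m, m'), (a, a'))
    bw ((dm, dm'), (da, da')) =
      let (dp, ds)   = g1 (dm, da)
          (dp', ds') = g2 (dm', da')
      in ((dp, dp'), (ds, ds'))

runFw :: PreLens a da m dm p dp s ds -> (p, s) -> (m, a)
runFw (PreLens f _) = f

runBw :: PreLens a da m dm p dp s ds -> (dm, da) -> (dp, ds)
runBw (PreLens _ g) = g

-- Hypothetical test lenses: a = p * s (residue (p, s)) and a = p + s (no residue).
scaleP :: PreLens Double Double (Double, Double) (Double, Double) Double Double Double Double
scaleP = PreLens (\(p, s) -> ((p, s), p * s)) (\((p, s), da) -> (da * s, da * p))

biasP :: PreLens Double Double () () Double Double Double Double
biasP = PreLens (\(p, s) -> ((), p + s)) (\(_, da) -> (da, da))
```

The two halves never interact, which is what makes this product monoidal rather than a form of sequencing.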

The existential lens can be obtained from the pre-lens at any time by tracing over the residues:

data ExLens a da p dp s ds = 
  forall m. ExLens (PreLens a da m m p dp s ds)

Notice, however, that the tracing can be performed after we are done with all the (serial and parallel) compositions. In particular, we could dedicate one pipeline to perform forward passes, gathering both parameters and residues, and then send this data over to another pipeline that performs backward passes. The data is produced and consumed in LIFO order.
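Tracing can be sketched directly. Once the residue and its delta are identified, wrapping the pre-lens in the existential `ExLens` hides the residue type. The only way left to use it is to run the forward pass and hand its residue straight to the backward pass. In this sketch, `traceL`, `runEx` and the test lens `scaleP` are hypothetical names. The error `da` is computed from the output by a caller-supplied function, for example the derivative of a loss.

```haskell
{-# LANGUAGE ExistentialQuantification #-}

data PreLens a da m dm p dp s ds =
  PreLens ((p, s)   -> (m, a))
          ((dm, da) -> (dp, ds))

data ExLens a da p dp s ds =
  forall m. ExLens (PreLens a da m m p dp s ds)

-- Tracing over the residue: identify m with dm and hide it.
traceL :: PreLens a da m m p dp s ds -> ExLens a da p dp s ds
traceL = ExLens

-- Run the forward pass, compute da from the output, and feed the stored
-- residue to the backward pass. The residue type never escapes this function.
runEx :: ExLens a da p dp s ds -> (p, s) -> (a -> da) -> (a, (dp, ds))
runEx (ExLens (PreLens f g)) ps blame =
  let (m, a) = f ps
  in (a, g (m, blame a))

-- Hypothetical test lens: a = p * s, with residue (p, s).
scaleP :: PreLens Double Double (Double, Double) (Double, Double) Double Double Double Double
scaleP = PreLens (\(p, s) -> ((p, s), p * s)) (\((p, s), da) -> (da * s, da * p))
```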

Pre-Neuron

As an example, let’s implement the basic building block of neural networks, the neuron. In what follows, we’ll use the following type synonyms:

type D = Double
type V = [D]

A neuron can be decomposed into three mini-layers. The first layer is the linear transformation, which calculates the scalar product of the input vector and the vector of parameters:

a = \sum_{i = 1}^n p_i \times s_i

It also produces the residue which, in this case, consists of the tuple (V, V) of inputs and parameters:

fw :: (V, V) -> ((V, V), D)
fw (p, s) = ((s, p), sumN n $ zipWith (*) p s)

The backward pass has the general signature:

bw :: ((dm, da) -> (dp, ds))

Because we’re eventually going to trace over the residues, we’ll use the same type for dm as for m. And because we are going to do arithmetic over the parameters, we reuse the type of p for the delta dp. Thus the signature of the backward pass is:

bw :: ((V, V), D) -> (V, V)

In the backward pass we’ll encode the gradient descent. The steepest gradient direction and slope is given by the partial derivatives:

\frac{\partial{ a}}{\partial p_i} = s_i

\frac{\partial{ a}}{\partial s_i} = p_i

We multiply them by the desired change in the output da:

dp = fmap (da *) s
ds = fmap (da *) p

Here’s the resulting lens:

linearL :: Int -> PreLens D D (V, V) (V, V) V V V V
linearL n = PreLens fw bw
  where
    fw :: (V, V) -> ((V, V), D)
    fw (p, s) = ((s, p), sumN n $ zipWith (*) p s)
    bw :: ((V, V), D) -> (V, V)
    bw ((s, p), da) = (fmap (da *) s
                      ,fmap (da *) p)

The linear transformation is followed by a bias, which uses a single number as the parameter, and generates no residue:

biasL :: PreLens D D () () D D D D
biasL = PreLens fw bw 
  where 
    fw :: (D, D) -> ((), D)
    fw (p, s) = ((), p + s)
    -- da/dp = 1, da/ds = 1
    bw :: ((), D) -> (D, D)
    bw (_, da) = (da, da)

Finally, we implement the non-linear activation layer using the tanh function:

activL :: PreLens D D D D () () D D
activL = PreLens fw bw
  where
    fw (_, s) = (s, tanh s)
    -- da/ds = 1 - (tanh s)^2
    bw (s, da) = ((), da * (1 - (tanh s)^2))
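A quick way to validate these backward passes is to compare them with central finite differences. This matters in particular for the tanh derivative 1 - \tanh^2 s. The sketch below copies `linearL` and `activL` from above. Since the post does not define `sumN`, a hypothetical stand-in that sums the first n products is assumed. The helpers `runFw`, `runBw` and `numDeriv` are illustrative.

```haskell
type D = Double
type V = [D]

data PreLens a da m dm p dp s ds =
  PreLens ((p, s)   -> (m, a))
          ((dm, da) -> (dp, ds))

-- Hypothetical stand-in for the post's sumN: sum the first n terms.
sumN :: Int -> V -> D
sumN n = sum . take n

linearL :: Int -> PreLens D D (V, V) (V, V) V V V V
linearL n = PreLens fw bw
  where
    fw :: (V, V) -> ((V, V), D)
    fw (p, s) = ((s, p), sumN n $ zipWith (*) p s)
    bw :: ((V, V), D) -> (V, V)
    bw ((s, p), da) = (fmap (da *) s, fmap (da *) p)

activL :: PreLens D D D D () () D D
activL = PreLens fw bw
  where
    fw (_, s) = (s, tanh s)
    bw (s, da) = ((), da * (1 - tanh s ^ 2))   -- d tanh s / ds = 1 - tanh^2 s

runFw :: PreLens a da m dm p dp s ds -> (p, s) -> (m, a)
runFw (PreLens f _) = f

runBw :: PreLens a da m dm p dp s ds -> (dm, da) -> (dp, ds)
runBw (PreLens _ g) = g

-- Central finite difference of a scalar function.
numDeriv :: (D -> D) -> D -> D
numDeriv f x = (f (x + h) - f (x - h)) / (2 * h)
  where
    h = 1e-6
```

With the sign flipped to 1 + \tanh^2 s, the first comparison below would fail by roughly 2 \tanh^2 s.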

A neuron with m inputs is a composition of the three components, modulo some monoidal rearrangements:

neuronL :: Int -> 
    PreLens D D ((V, V), D) ((V, V), D) Para Para V V
neuronL mIn = PreLens f' b'
  where 
    PreLens f b = 
      preCompose (preCompose (linearL mIn) biasL) activL
    f' :: (Para, V) -> (((V, V), D), D)
    f' (Para bi wt, s) =
      let (((vv, ()), d), a) = f (((), (bi, wt)), s)
      in ((vv, d), a)
    b' :: (((V, V), D), D) -> (Para, V)
    b' ((vv, d), da) =
      let (((), (d', w')), ds) = b (((vv, ()), d), da)
      in (Para d' w', ds)

The parameters for the neuron can be conveniently packaged into one data structure:

data Para = Para { bias   :: D
                 , weight :: V }

mkPara (b, v) = Para b v
unPara p = (bias p, weight p)
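To see the neuron learn, here is a self-contained sketch that assembles `neuronL` from the three mini-layers and performs gradient-descent updates on a single example. The tuple helpers and `sumN` are assumptions, since the post does not show them. The loss (a - y)^2 / 2 and the learning rate are arbitrary choices for illustration, and `loss` and `trainStep` are hypothetical names.

```haskell
import Data.Bifunctor (first, second)
import Data.Tuple (swap)

type D = Double
type V = [D]

data PreLens a da m dm p dp s ds =
  PreLens ((p, s) -> (m, a)) ((dm, da) -> (dp, ds))

data Para = Para { bias :: D, weight :: V }

-- Assumed tuple helpers (not shown in the post); swap plays the role of sym.
assoc :: ((a, b), c) -> (a, (b, c))
assoc ((a, b), c) = (a, (b, c))

unAssoc :: (a, (b, c)) -> ((a, b), c)
unAssoc (a, (b, c)) = ((a, b), c)

preCompose :: PreLens a' da' m dm p dp s ds -> PreLens a da n dn q dq a' da'
           -> PreLens a da (m, n) (dm, dn) (q, p) (dq, dp) s ds
preCompose (PreLens f1 g1) (PreLens f2 g2) = PreLens f3 g3
  where
    f3 = unAssoc . second f2 . assoc . first swap . unAssoc . second f1 . assoc
    g3 = unAssoc . second g1 . assoc . first swap . unAssoc . second g2 . assoc

-- Hypothetical stand-in for sumN: sum the first n terms.
sumN :: Int -> V -> D
sumN n = sum . take n

linearL :: Int -> PreLens D D (V, V) (V, V) V V V V
linearL n = PreLens (\(p, s) -> ((s, p), sumN n (zipWith (*) p s)))
                    (\((s, p), da) -> (map (da *) s, map (da *) p))

biasL :: PreLens D D () () D D D D
biasL = PreLens (\(p, s) -> ((), p + s)) (\(_, da) -> (da, da))

activL :: PreLens D D D D () () D D
activL = PreLens (\(_, s) -> (s, tanh s))
                 (\(s, da) -> ((), da * (1 - tanh s ^ 2)))

neuronL :: Int -> PreLens D D ((V, V), D) ((V, V), D) Para Para V V
neuronL mIn = PreLens f' b'
  where
    PreLens f b = preCompose (preCompose (linearL mIn) biasL) activL
    f' (Para bi wt, s) =
      let (((vv, ()), d), a) = f (((), (bi, wt)), s)
      in ((vv, d), a)
    b' ((vv, d), da) =
      let (((), (d', w')), ds) = b (((vv, ()), d), da)
      in (Para d' w', ds)

-- Squared-error loss (a - y)^2 / 2 on one example (an assumption of this sketch).
loss :: V -> D -> Para -> D
loss s y para = (a - y) ^ 2 / 2
  where
    PreLens fw _ = neuronL (length s)
    (_, a) = fw (para, s)

-- One gradient-descent step with learning rate lr: the error da = a - y is
-- pushed back through the neuron, and each parameter moves against its gradient.
trainStep :: D -> V -> D -> Para -> Para
trainStep lr s y para =
  Para (bias para - lr * db) (zipWith (\w dw -> w - lr * dw) (weight para) dws)
  where
    PreLens fw bw = neuronL (length s)
    (m, a) = fw (para, s)
    (Para db dws, _) = bw (m, a - y)
```

Consider iterating `trainStep 0.1 [1, 2] 0.5` from `Para 0 [0.1, -0.1]`. Since tanh can reach the target 0.5, the loss is driven toward zero.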

Using parallel composition, we can create whole layers of neurons, and then use sequential composition to model multi-layer neural networks. The loss function that compares the actual output with the expected output can also be implemented as a lens. We’ll perform this construction later using the profunctor representation.

Tambara Modules

As a rule, all optics that have an existential representation also have some kind of profunctor representation. The advantage of profunctor representations is that they are functions, and they compose using function composition.

Lenses, in particular, have a representation using a special category of profunctors called Tambara modules. A vanilla Tambara module is a profunctor p equipped with a family of transformations. It can be implemented as a Haskell class:

class  Profunctor p => Tambara p where
  alpha :: forall a da m. p a da -> p (m, a) (m, da)

The vanilla lens is then represented by the following profunctor-polymorphic function:

type Lens a da s ds = forall p. Tambara p => p a da -> p s ds
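For comparison with the pre-lens case, here is a sketch of the vanilla construction. To stay self-contained it defines its own minimal `Profunctor` and `Tambara` classes. In the profunctors package, the corresponding structure is provided by the `Strong` class. Plain functions, wrapped in a hypothetical `Fun` newtype, form a Tambara module, and a lens built from a getter and a setter can then be run on them. The names `lensFrom`, `firstL` and `over` are illustrative.

```haskell
{-# LANGUAGE RankNTypes #-}

class Profunctor p where
  dimap :: (s' -> s) -> (ds -> ds') -> p s ds -> p s' ds'

class Profunctor p => Tambara p where
  alpha :: p a da -> p (m, a) (m, da)

type Lens a da s ds = forall p. Tambara p => p a da -> p s ds

-- Ordinary functions form a Tambara module: alpha acts on the second component.
newtype Fun a b = Fun { runFun :: a -> b }

instance Profunctor Fun where
  dimap f g (Fun h) = Fun (g . h . f)

instance Tambara Fun where
  alpha (Fun h) = Fun (\(m, a) -> (m, h a))

-- From a getter and a setter: keep the whole s as the residue m,
-- let alpha carry it past the focus, then rebuild with the setter.
lensFrom :: (s -> a) -> (s -> da -> ds) -> Lens a da s ds
lensFrom get set = dimap (\s -> (s, get s)) (\(s, da) -> set s da) . alpha

-- A lens focusing on the first component of a pair.
firstL :: Lens a da (a, c) (da, c)
firstL = lensFrom fst (\(_, c) da -> (da, c))

-- Apply a function through a lens by instantiating p to Fun.
over :: Lens a da s ds -> (a -> da) -> s -> ds
over l f = runFun (l (Fun f))
```

Note that the residue m here is the entire structure s. Pre-lenses make that choice explicit and let it vary.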

A similar representation can be constructed for pre-lenses. A pre-lens, however, has an additional dependency on parameters and residues, so the analog of a Tambara module must also be parameterized by those. We need, therefore, a more complex type constructor t that takes six arguments:

t m dm p dp s ds

This is supposed to be a profunctor in three pairs of arguments: s ds, p dp, and dm m. Profunctoriality in the first two pairs is implemented as two functions, dimapS and dimapP. The inverted order in dm m means that t is covariant in m and contravariant in dm, as seen in the unusual type signature of dimapM:

dimapM  :: (m -> m') -> (dm' -> dm) -> 
  t m dm p dp s ds -> t m' dm' p  dp  s  ds

To generalize Tambara modules we first observe that the pre-lens now has two independent residues, m and dm, and the two should transform separately. Also, the composition of pre-lenses accumulates (through tupling) both the residues and the parameters, so it makes sense to use the additional type arguments to TriProFunctor as accumulators. Thus the generalized Tambara module has two methods, one for accumulating residues, and one for accumulating parameters:

class TriProFunctor t => Trimbara t where
  alpha :: t m dm p dp s ds -> 
           t (m1, m) (dm1, dm) p dp (m1, s) (dm1, ds)
  beta  :: t m dm p dp (p1, s) (dp1, ds) -> 
           t m dm (p, p1) (dp, dp1) s ds

These generalized Tambara modules satisfy some coherence conditions.

One can also define natural transformations that are compatible with the new structures, so that Trimbara modules form a category.

The question arises: can this definition be satisfied by an actual non-trivial TriProFunctor? Fortunately, it turns out that a pre-lens itself is an example of a Trimbara module. Here’s the implementation of alpha for a PreLens:

alpha (PreLens fw bw) = PreLens fw' bw'
  where
    fw' (p, (n, s)) = let (m, a) = fw (p, s)
                      in ((n, m), a)
    bw' ((dn, dm), da) = let (dp, ds) = bw (dm, da)
                         in (dp, (dn, ds))

and this is beta:

beta (PreLens fw bw) = PreLens fw' bw'
  where
    fw' ((p, r), s) = let (m, a) = fw (p, (r, s))
                      in (m, a)
    bw' (dm, da) = let (dp, (dr, ds)) = bw (dm, da)
                   in ((dp, dr), ds)

This result will become important in the next section.

TriLens

Since Trimbara modules form a category, we can define a polymorphic function type (a categorical end) over Trimbara modules. This gives us the (tri-)profunctor representation for a pre-lens:

type TriLens a da m dm p dp s ds =
    forall t. Trimbara t => forall p1 dp1 m1 dm1. 
      t m1 dm1 p1 dp1 a da -> 
      t (m, m1) (dm, dm1) (p1, p) (dp1, dp) s ds

Indeed, given a pre-lens we can construct the requisite mapping of Trimbara modules simply by lifting the two functions (the forward and the backward pass) and sandwiching them between the two Tambara structure maps:

toTamb :: PreLens a da m dm p dp s ds -> 
    TriLens a da m dm p dp s ds
toTamb (PreLens fw bw) = beta . dimapS fw bw . alpha

Conversely, given a mapping between Trimbara modules, we can construct a pre-lens by applying it to the identity pre-lens (modulo some rearrangement of tuples using the monoidal right/left unit laws):

fromTamb :: TriLens a da m dm p dp s ds -> 
    PreLens a da m dm p dp s ds
fromTamb f = dimapM runit unRunit $  
             dimapP unLunit lunit $ 
             f idPreLens 

The main advantage of the profunctor representation is that we can now compose two lenses using simple function composition; again, modulo some associators:

triCompose ::
    TriLens b db m dm p dp s ds -> 
    TriLens a da n dn q dq b db ->
    TriLens a da (m, n) (dm, dn) (q, p) (dq, dp) s ds
triCompose f g = dimapP unAssoc assoc . 
                 dimapM unAssoc assoc . 
                 f . g
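The pieces of this section can be assembled into one runnable sketch. The `Trimbara` instance for `PreLens` follows the definitions of alpha and beta given earlier in the post. The `dimap` implementations, the unit and associativity helpers, and the test lens `scaleP` are assumptions, written to match the types in the post. The test checks two things. First, `fromTamb . toTamb` gives back a lens that behaves like the original. Second, `triCompose` reproduces the hand calculation for a = q * p * s.

```haskell
{-# LANGUAGE RankNTypes #-}

import Data.Bifunctor (first, second)

data PreLens a da m dm p dp s ds =
  PreLens ((p, s) -> (m, a)) ((dm, da) -> (dp, ds))

-- Assumed tuple helpers (associators and unitors).
assoc :: ((a, b), c) -> (a, (b, c))
assoc ((a, b), c) = (a, (b, c))
unAssoc :: (a, (b, c)) -> ((a, b), c)
unAssoc (a, (b, c)) = ((a, b), c)
lunit :: ((), a) -> a
lunit ((), a) = a
unLunit :: a -> ((), a)
unLunit a = ((), a)
runit :: (a, ()) -> a
runit (a, ()) = a
unRunit :: a -> (a, ())
unRunit a = (a, ())

class TriProFunctor t where
  dimapS :: (s' -> s) -> (ds -> ds') -> t m dm p dp s ds -> t m dm p dp s' ds'
  dimapP :: (p' -> p) -> (dp -> dp') -> t m dm p dp s ds -> t m dm p' dp' s ds
  dimapM :: (m -> m') -> (dm' -> dm) -> t m dm p dp s ds -> t m' dm' p dp s ds

class TriProFunctor t => Trimbara t where
  alpha :: t m dm p dp s ds -> t (m1, m) (dm1, dm) p dp (m1, s) (dm1, ds)
  beta  :: t m dm p dp (p1, s) (dp1, ds) -> t m dm (p, p1) (dp, dp1) s ds

instance TriProFunctor (PreLens a da) where
  dimapS f g (PreLens fw bw) = PreLens (fw . second f) (second g . bw)
  dimapP f g (PreLens fw bw) = PreLens (fw . first f) (first g . bw)
  dimapM f g (PreLens fw bw) = PreLens (first f . fw) (bw . first g)

instance Trimbara (PreLens a da) where
  alpha (PreLens fw bw) = PreLens fw' bw'
    where
      fw' (p, (n, s)) = let (m, a) = fw (p, s) in ((n, m), a)
      bw' ((dn, dm), da) = let (dp, ds) = bw (dm, da) in (dp, (dn, ds))
  beta (PreLens fw bw) = PreLens fw' bw'
    where
      fw' ((p, r), s) = fw (p, (r, s))
      bw' (dm, da) = let (dp, (dr, ds)) = bw (dm, da) in ((dp, dr), ds)

type TriLens a da m dm p dp s ds =
  forall t. Trimbara t => forall p1 dp1 m1 dm1.
    t m1 dm1 p1 dp1 a da -> t (m, m1) (dm, dm1) (p1, p) (dp1, dp) s ds

idPreLens :: PreLens a da () () () () a da
idPreLens = PreLens id id

toTamb :: PreLens a da m dm p dp s ds -> TriLens a da m dm p dp s ds
toTamb (PreLens fw bw) = beta . dimapS fw bw . alpha

fromTamb :: TriLens a da m dm p dp s ds -> PreLens a da m dm p dp s ds
fromTamb f = dimapM runit unRunit $ dimapP unLunit lunit $ f idPreLens

triCompose :: TriLens b db m dm p dp s ds -> TriLens a da n dn q dq b db
           -> TriLens a da (m, n) (dm, dn) (q, p) (dq, dp) s ds
triCompose f g = dimapP unAssoc assoc . dimapM unAssoc assoc . f . g

runFw :: PreLens a da m dm p dp s ds -> (p, s) -> (m, a)
runFw (PreLens f _) = f
runBw :: PreLens a da m dm p dp s ds -> (dm, da) -> (dp, ds)
runBw (PreLens _ g) = g

-- Hypothetical test lens: a = p * s, with residue (p, s).
scaleP :: PreLens Double Double (Double, Double) (Double, Double) Double Double Double Double
scaleP = PreLens (\(p, s) -> ((p, s), p * s)) (\((p, s), da) -> (da * s, da * p))
```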

Parallel composition of TriLenses is also relatively straightforward, although it involves a lot of bookkeeping (see the GitHub implementation).

Training a Neural Network

As a proof of concept, I have implemented and trained a simple 3-layer perceptron.

The starting point is the conversion of the individual components of the neuron from their pre-lens representation to the profunctor representation using toTamb. For instance:

linearT :: Int -> TriLens D D (V, V) (V, V) V V V V
linearT n = toTamb (linearL n)

We get a profunctor representation of a neuron by composing its three components:

neuronT :: Int -> 
  TriLens D D ((V, V), D) ((V, V), D) Para Para V V
neuronT mIn = 
  dimapP (second (unLunit . unPara)) 
         (second (mkPara . lunit)) .
  triCompose (dimapM (first runit) (first unRunit) .
  triCompose (linearT mIn) biasT) activT

With parallel composition of tri-lenses, we can build a layer of neurons of arbitrary width.

layer :: Int -> Int -> 
  TriLens V V [((V, V), D)] [((V, V), D)] [Para] [Para] V V
layer mIn nOut = 
  dimapP (second unRunit) (second runit) .
  dimapM (first lunit) (first unLunit) .
  triCompose (branch nOut) (vecLens nOut (neuronT mIn))

The result is again a tri-lens, and such tri-lenses can be composed in series to create a multi-layer perceptron.

makeMlp :: Int -> [Int] -> 
  TriLens V V -- output
          [[((V, V), D)]] [[((V, V), D)]] -- residues
          [[Para]] [[Para]] -- parameters
          V V -- input

Here, the first integer specifies the number of inputs of each neuron in the first layer. The list [Int] specifies the number of neurons in consecutive layers (which is also the number of inputs of each neuron in the following layer).

The training of a neural network is usually done by feeding it a batch of inputs together with a batch of expected outputs. This can be done simply by arranging multiple perceptrons in parallel and accumulating the parameter deltas for the whole batch.

batchN :: (VSpace dp) => Int -> 
    TriLens  a da m dm p dp s ds -> 
    TriLens [a] [da] [m] [dm] p dp [s] [ds]

To make the accumulation possible, the parameters must form a vector space, hence the constraint VSpace dp.
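Here is a sketch of the batching idea on plain pre-lenses. It uses a hypothetical `VSpace` class reduced to a zero and an addition; the actual class on GitHub may look different. Unlike `batchN`, this `batchPre` takes the batch size from the length of the input list. All examples share one parameter value, and their parameter deltas are summed.

```haskell
-- Hypothetical vector-space class: just enough structure to sum gradients.
class VSpace v where
  zeroV :: v
  (^+^) :: v -> v -> v

infixl 6 ^+^

instance VSpace Double where
  zeroV = 0
  (^+^) = (+)

-- Lists add element-wise; the shorter list is padded with zeros.
instance VSpace v => VSpace [v] where
  zeroV = []
  xs ^+^ [] = xs
  [] ^+^ ys = ys
  (x : xs) ^+^ (y : ys) = (x ^+^ y) : (xs ^+^ ys)

instance (VSpace u, VSpace v) => VSpace (u, v) where
  zeroV = (zeroV, zeroV)
  (a, b) ^+^ (c, d) = (a ^+^ c, b ^+^ d)

sumV :: VSpace v => [v] -> v
sumV = foldr (^+^) zeroV

data PreLens a da m dm p dp s ds =
  PreLens ((p, s) -> (m, a)) ((dm, da) -> (dp, ds))

-- Batching (a sketch): outputs, residues and input deltas become lists,
-- while the parameter deltas of all examples are summed into one dp.
batchPre :: VSpace dp
         => PreLens a da m dm p dp s ds
         -> PreLens [a] [da] [m] [dm] p dp [s] [ds]
batchPre (PreLens fw bw) = PreLens fw' bw'
  where
    fw' (p, ss) = unzip [fw (p, s) | s <- ss]
    bw' (dms, das) =
      let (dps, dss) = unzip (zipWith (curry bw) dms das)
      in (sumV dps, dss)

runFw :: PreLens a da m dm p dp s ds -> (p, s) -> (m, a)
runFw (PreLens f _) = f

runBw :: PreLens a da m dm p dp s ds -> (dm, da) -> (dp, ds)
runBw (PreLens _ g) = g

-- Hypothetical test lens: a = p * s, with residue (p, s).
scaleP :: PreLens Double Double (Double, Double) (Double, Double) Double Double Double Double
scaleP = PreLens (\(p, s) -> ((p, s), p * s)) (\((p, s), da) -> (da * s, da * p))
```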

The whole thing is then capped by a square-distance loss lens that is parameterized by the ground truth values:

lossL :: PreLens D D ([V], [V]) ([V], [V]) [V] [V] [V] [V]
lossL = PreLens fw bw 
  where
    fw (gTruth, s) = 
      ((gTruth, s), sqDist (concat s) (concat gTruth))
    bw ((gTruth, s), da) = (fmap (fmap negate) delta', delta')
      where
        delta' = fmap (fmap (da *)) (zipWith minus s gTruth)
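The helpers `sqDist` and `minus` are not shown in the post. The sketch below assumes `minus` is element-wise subtraction and `sqDist` is half the squared Euclidean distance. With that choice, `delta'` is exactly the gradient with respect to the outputs. With the plain squared distance the gradient would be twice as large, which only rescales the learning rate. The accessors `runFw` and `runBw` are illustrative helpers.

```haskell
type D = Double
type V = [D]

data PreLens a da m dm p dp s ds =
  PreLens ((p, s) -> (m, a)) ((dm, da) -> (dp, ds))

-- Hypothetical helpers (assumptions; the GitHub versions may differ).
minus :: V -> V -> V
minus = zipWith (-)

-- Half the squared Euclidean distance.
sqDist :: V -> V -> D
sqDist xs ys = sum (zipWith (\x y -> (x - y) ^ 2) xs ys) / 2

lossL :: PreLens D D ([V], [V]) ([V], [V]) [V] [V] [V] [V]
lossL = PreLens fw bw
  where
    fw (gTruth, s) =
      ((gTruth, s), sqDist (concat s) (concat gTruth))
    -- The output delta is da * (s - gTruth); the ground-truth "parameter"
    -- gets the opposite sign.
    bw ((gTruth, s), da) = (fmap (fmap negate) delta', delta')
      where
        delta' = fmap (fmap (da *)) (zipWith minus s gTruth)

runFw :: PreLens a da m dm p dp s ds -> (p, s) -> (m, a)
runFw (PreLens f _) = f

runBw :: PreLens a da m dm p dp s ds -> (dm, da) -> (dp, ds)
runBw (PreLens _ g) = g
```

For a ground truth of [[1, 2]] and an output of [[3, 5]], the loss is (4 + 9) / 2 = 6.5. At da = 1 the output delta is [[2, 3]].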

In the next post I will present the categorical version of this construction.