Neural Networks



There are many excellent AI papers and tutorials that explain the attention pattern in Large Language Models. But this essentially simple pattern is often obscured by implementation details and optimizations. In this post I will try to cut to the essentials.

In a nutshell, the attention machinery tries to get at the meaning of a word (more precisely, a token). This should be easy in principle: we could just look it up in the dictionary. For instance, the word “triskaidekaphobia” means “extreme superstition regarding the number thirteen.” Simple enough. But consider the question: What does “it” mean in the sentence “look it up in the dictionary”? You could look up the word “it” in the dictionary, but that wouldn’t help much. More often than not, we guess the meaning of words from their context, sometimes based on a whole conversation.

The attention mechanism is a way to train a language model to derive the meaning of a word from its context.

The first step is the rough encoding of words (tokens) as multi-dimensional vectors. We’re talking about 12,288 dimensions for GPT-3. That’s an enormous semantic space. It means that you can adjust the meaning of a word by nudging it in thousands of different directions by varying amounts (usually 32-bit floating point numbers). This “nudging” is like adding little footnotes to a word, explaining its precise meaning in a particular situation.

(Note: In practice, working with such huge vectors would be prohibitively expensive, so they are split between multiple heads of attention and processed in parallel. For instance, GPT-3 has 96 heads of attention, and each head works within a 128-dimensional vector space.)

We are embedding the input word as a 12,288-dimensional vector \vec E, and we are embedding N words of context as 12,288-dimensional vectors, \vec E_i, i = 1..N (for GPT-3, N=2048 tokens). Initially, the mapping from words to embeddings is purely random, but as the training progresses, it is gradually refined using backpropagation.

(Note: In most descriptions, you’ll see a whole batch of words being embedded all at once, and the context often being identical with that same batch, but this is just an optimization.)

The goal is to refine the embedding \vec E by adding to it a small delta \Delta \vec E that is derived from the context \vec E_i.

This is usually described as the vector \vec E querying the context, and the context responding to this query.

First, we apply a gigantic trainable 12,288 by 12,288 matrix W_Q to the embedding \vec E, to get the query vector:

\vec Q = W_Q \vec E

You may think of \vec Q as the question: “Who am I with respect to the Universe?” The entries in the matrix W_Q are the weights that are learned by running the usual backpropagation.

We then apply another trainable matrix W_K to every vector \vec E_i of the context to get a batch of key vectors:

\vec K_i = W_K \vec E_i

You may think of \vec K_i as a possible response from the i-th component of the context to all kinds of questions.

The next step is crucial: We calculate how relevant each element of the context is to our word. In other words, we are focusing our attention on particular elements of the context. We do it by taking N scalar products \vec K_i \cdot \vec Q.

A scalar product can vary widely. A large positive number means that the i-th element of the context is highly relevant to the meaning of \vec E. A large negative number, on the other hand, means that it has very little to contribute.

What we really need is to normalize these numbers to a range between zero (not relevant) and one (extremely relevant) and make sure that they all add up to one. This is normally done using the softmax procedure: we first raise e to the power of the given number to make the result non-negative:

a_i = exp (\vec K_i \cdot \vec Q)

We then divide it by the total sum, to normalize it:

A_i = \dfrac {a_i}{ \sum_{j = 1}^N a_j}

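These are our attention weights. (Note: Before performing the softmax, we divide all the scalar products by the square root of the dimension of the key space, \sqrt{d_k}. This keeps their magnitudes from growing with the dimension, which would otherwise push the softmax into a regime where a single weight dominates and the gradients become very small.)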
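
The softmax step, including the scaling mentioned in the note, can be sketched in a few lines of Haskell. This is only an illustration: the name scaledSoftmax and its parameter dk are assumptions made here, and the subtraction of the maximum is a common numerical-stability trick that is not part of the formulas above (it does not change the result).

```haskell
-- Scaled softmax over a non-empty list of raw scores K_i . Q.
-- `dk` is the dimension of the key/query space (hypothetical parameter name).
scaledSoftmax :: Int -> [Double] -> [Double]
scaledSoftmax dk scores = map (/ total) exps
  where
    scale  = sqrt (fromIntegral dk)
    scaled = map (/ scale) scores
    -- Subtracting the maximum avoids overflow in exp;
    -- the common factor cancels in the final division.
    m      = maximum scaled
    exps   = map (\s -> exp (s - m)) scaled   -- the a_i, up to a common factor
    total  = sum exps
```

The resulting weights are all positive and add up to one, so equal scores produce equal weights.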

Attention weights tell us how much each element of the context can contribute to the meaning of \vec E, but we still don’t know what it contributes. We figure this out by multiplying each element of the context by yet another trainable matrix, W_V. The result is a batch of value vectors \vec V_i:

\vec V_i = W_V \vec E_i

We now accumulate all these contributions, weighting them by their attention weights A_i. We get the adjusted meaning of our original embedding \vec E (this step is called the residual connection):

\vec E' = \vec E + \sum_{i = 1}^N A_i \vec V_i

The result \vec E' is infused with the additional information gained from a larger context. It’s closer to the actual meaning. For instance, the word “it” would be nudged towards the noun that it stands for.
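
Putting the steps together, here is a minimal single-query sketch in Haskell. It assumes vectors are plain lists of Double and matrices are lists of rows; the names Vector, Matrix, apply, and attend are made up for this illustration, and nothing here is optimized.

```haskell
type Vector = [Double]
type Matrix = [[Double]]   -- a matrix as a list of rows

dot :: Vector -> Vector -> Double
dot u v = sum (zipWith (*) u v)

-- Matrix-vector product: one dot product per row.
apply :: Matrix -> Vector -> Vector
apply w v = map (`dot` v) w

softmax :: [Double] -> [Double]
softmax xs = map (/ sum es) es
  where es = map (\x -> exp (x - maximum xs)) xs

-- E' = E + sum_i A_i V_i, for one input embedding and a non-empty context.
attend :: Matrix -> Matrix -> Matrix -> Vector -> [Vector] -> Vector
attend wq wk wv e ctx = zipWith (+) e delta
  where
    q       = apply wq e                       -- query
    ks      = map (apply wk) ctx               -- keys
    vs      = map (apply wv) ctx               -- values
    dk      = fromIntegral (length q)
    weights = softmax [ dot k q / sqrt dk | k <- ks ]
    -- Weighted sum of the value vectors.
    delta   = foldr (zipWith (+)) (map (const 0) e)
                    [ map (a *) v | (a, v) <- zip weights vs ]
```

With identity matrices and a single context vector, the only attention weight is one, so the result is just the sum of the input and the context vector.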

And that’s essentially the basic block of the attention system. The rest is just optimization and composition.

One major optimization is gained by processing a whole window of tokens at once (2048 tokens for GPT-3). In particular, in the self-attention pattern we use the same batch of tokens for both the input and the context. In general, though, the context can be distinct from the input. It could, for instance, be a sentence in a different language.

Another optimization is the partitioning of all three vectors, \vec Q, \vec K_i, and \vec V_i, between the heads of attention. Each of these heads operates inside a smaller subspace of the vector space. For GPT-3 these are 128-dimensional subspaces. The heads produce smaller-dimensional deltas, \Delta \vec E, which are concatenated into larger vectors, which are then added to the original embeddings through the residual connection.
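
The partitioning itself is just chunking and concatenation. A rough sketch, with the hypothetical names splitHeads and mergeHeads, and assuming the number of heads is positive and divides the vector length evenly (as 96 divides 12,288):

```haskell
-- Split a vector into `h` equal-sized chunks, one per head of attention.
-- Assumes h > 0 and that h divides the length of the vector.
splitHeads :: Int -> [Double] -> [[Double]]
splitHeads h v = go v
  where
    size  = max 1 (length v `div` h)
    go [] = []
    go xs = let (chunk, rest) = splitAt size xs in chunk : go rest

-- Concatenate the per-head results back into one full-size vector.
mergeHeads :: [[Double]] -> [Double]
mergeHeads = concat
```

Each head would run the attention step on its own chunk, and mergeHeads would then glue the per-head deltas together before the residual addition.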

In GPT-3, each multi-headed attention block is followed by a multi-layer perceptron, MLP, and this transformation is repeated 96 times. Most steps in an LLM are just linear algebra, except for softmax and the activation function in the MLP, which are non-linear.

All this work is done just to produce the next word in a sentence.


I will now provide the categorical foundation of the Haskell implementation from the previous post. A PDF version that contains both parts is also available.

The Para Construction

There’s been a lot of interest in categorical foundations of deep learning. The basic idea is that of a parametric category, in which morphisms are parameterized by objects from a monoidal category \mathcal P:

[Figure: diagram of a parametric morphism labeled by its parameter p]

Here, p is an object of \mathcal P.

When two such morphisms are composed, the result is parameterized by the tensor product of the parameters.

[Figure: diagram of the composition of two parametric morphisms, parameterized by the tensor product of their parameters]

An identity morphism is parameterized by the monoidal unit I.

If the monoidal category \mathcal P is not strict, the parametric composition and identity laws are not strict either. They are satisfied up to associators and unitors of \mathcal P. A category with lax composition and identity laws is called a bicategory. The 2-cells in a parametric bicategory are called reparameterizations.

Of particular interest are parameterized bicategories that are built on top of actegories. An actegory \mathcal C is a category in which we define an action of a monoidal category \mathcal P:

\bullet \colon \mathcal P \times \mathcal C \to \mathcal C

satisfying some obvious coherence conditions (unit and composition):

I \bullet c \cong c

p \bullet (q \bullet c) \cong (p \otimes q) \bullet c

There are two basic constructions of a parametric category on top of an actegory, called \mathbf{Para} and \mathbf{coPara}. The first constructs parametric morphisms from a to b as f_p \colon p \bullet a \to b, and the second as g_p \colon a \to p \bullet b.
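
To make this concrete, here is a small Haskell sketch in which the action p \bullet a is taken to be the pair type (p, a), the tensor product is pairing, and the unit is (). These choices, and all the names below (Para, CoPara, composePara, and so on), are assumptions made for illustration.

```haskell
-- A Para morphism from a to b with parameter p:  p • a -> b
newtype Para p a b = Para ((p, a) -> b)

-- Composition is parameterized by the tensor product (here: a pair) of parameters.
composePara :: Para q b c -> Para p a b -> Para (q, p) a c
composePara (Para g) (Para f) = Para (\((q, p), a) -> g (q, f (p, a)))

-- The identity is parameterized by the monoidal unit.
idPara :: Para () a a
idPara = Para snd

runPara :: Para p a b -> p -> a -> b
runPara (Para f) p a = f (p, a)

-- A coPara morphism from a to b with parameter p:  a -> p • b
newtype CoPara p a b = CoPara (a -> (p, b))

composeCoPara :: CoPara q b c -> CoPara p a b -> CoPara (q, p) a c
composeCoPara (CoPara g) (CoPara f) = CoPara $ \a ->
  let (p, b) = f a
      (q, c) = g b
  in ((q, p), c)

-- Example: an affine map built from two one-parameter morphisms.
scale, shift :: Para Double Double Double
scale = Para (\(w, x) -> w * x)
shift = Para (\(b, x) -> b + x)

affine :: Para (Double, Double) Double Double
affine = composePara shift scale
```

Here runPara affine (1, 2) 3 evaluates to 7.0: the composite carries both the bias and the weight as its parameter.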

Parametric Optics

The \mathbf{Para} construction can be extended to optics, where we’re dealing with pairs of objects from the underlying category (or categories, in the case of mixed optics). The parameterized optic is defined as the following coend:

O \langle a, da \rangle \langle p, dp \rangle \langle s, ds \rangle = \int^{m} \mathcal C (p \bullet s, m \bullet a) \times \mathcal C (m \bullet da, dp \bullet ds)

where the residues m are objects of some monoidal category \mathcal M, and the parameters \langle p, dp \rangle come from another monoidal category \mathcal P.

In Haskell, this is exactly the existential lens:

{-# LANGUAGE ExistentialQuantification #-}

data ExLens a da p dp s ds =
  forall m . ExLens ((p, s)  -> (m, a))
                    ((m, da) -> (dp, ds))
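
As a rough illustration of how such lenses are used, the sketch below repeats the definition and adds a composition that pairs up both the parameters and the hidden residues. The names compose, runForward, gradient, and mul are made up for this example and are not necessarily those of the original implementation.

```haskell
{-# LANGUAGE ExistentialQuantification #-}

data ExLens a da p dp s ds =
  forall m . ExLens ((p, s)  -> (m, a))
                    ((m, da) -> (dp, ds))

-- Serial composition: the parameters are paired, and so are the residues.
compose :: ExLens a da p dp s ds -> ExLens b db q dq a da
        -> ExLens b db (p, q) (dp, dq) s ds
compose (ExLens f1 g1) (ExLens f2 g2) = ExLens f3 g3
  where
    f3 ((p, q), s) = let (m, a) = f1 (p, s)
                         (n, b) = f2 (q, a)
                     in ((m, n), b)
    g3 ((m, n), db) = let (dq, da) = g2 (n, db)
                          (dp, ds) = g1 (m, da)
                      in ((dp, dq), ds)

-- Forward pass only: discard the residue.
runForward :: ExLens a da p dp s ds -> (p, s) -> a
runForward (ExLens f _) ps = snd (f ps)

-- Forward pass followed by the backward pass, which consumes the residue.
gradient :: ExLens a da p dp s ds -> (p, s) -> da -> (dp, ds)
gradient (ExLens f g) ps da = let (m, _) = f ps in g (m, da)

-- Example: multiplication by a weight; the residue remembers both arguments.
mul :: ExLens Double Double Double Double Double Double
mul = ExLens (\(w, x) -> ((w, x), w * x))
             (\((w, x), dy) -> (x * dy, w * dy))
```

For the composite compose mul mul with parameters (2, 3) and input 5, the forward pass gives 30, and gradient with da = 1 gives ((15, 10), 6), which are the partial derivatives of the product with respect to the two weights and the input.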

There is, however, a more general bicategory of pre-optics, which underlies existential optics. In it, both the parameters and the residues are treated symmetrically.

The PreLens Bicategory

Pre-optics break the feedback loop in which the residues from the forward pass are fed back to the backward pass. We get the following formula:

\begin{aligned}O & \langle a, da \rangle \langle m, dm \rangle \langle p, dp \rangle \langle s, ds \rangle = \\  &\mathcal C (p \bullet s, m \bullet a) \times \mathcal C (dm \bullet da, dp \bullet ds)  \end{aligned}

We interpret this as a hom-set from a pair of objects \langle s, ds \rangle in \mathcal C^{op} \times \mathcal C to the pair of objects \langle a, da \rangle also in \mathcal C^{op} \times \mathcal C, parameterized by a pair \langle m, dm \rangle in \mathcal M \times \mathcal M^{op} and a pair \langle p, dp \rangle from \mathcal P^{op} \times \mathcal P.
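
In Haskell terms, and again assuming that all actions are given by pairing, a pre-optic is just a pair of functions in which the residue types m and dm are independent. The sketch below also shows how the feedback loop can be closed when the two residue types coincide; the name toExLens is hypothetical.

```haskell
{-# LANGUAGE ExistentialQuantification #-}

data ExLens a da p dp s ds =
  forall m . ExLens ((p, s) -> (m, a)) ((m, da) -> (dp, ds))

-- A pre-optic: the forward residue m and the backward residue dm are unrelated.
data PreLens a da m dm p dp s ds =
  PreLens ((p, s) -> (m, a)) ((dm, da) -> (dp, ds))

-- Closing the loop: when dm is the same type as m, the residue can be hidden
-- behind the existential quantifier, which plays the role of the coend.
toExLens :: PreLens a da m m p dp s ds -> ExLens a da p dp s ds
toExLens (PreLens f g) = ExLens f g
```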

To simplify notation, I’ll use the bold \mathbf C for the category \mathcal C^{op} \times \mathcal C , and bold letters for pairs of objects and (twisted) pairs of morphisms. For instance, \bold f \colon \bold a \to \bold b is a member of the hom-set \mathbf C (\bold a, \bold b) represented by a pair \langle f \colon a' \to a, g \colon b \to b' \rangle.

Similarly, I’ll use the notation \bold m \bullet \bold a to denote the monoidal action of \mathcal M^{op} \times \mathcal M on \mathcal C^{op} \times \mathcal C:

\langle m, dm \rangle \bullet \langle a, da \rangle = \langle m \bullet a, dm \bullet da \rangle

and the analogous action of \mathcal P^{op} \times \mathcal P.

In this notation, the pre-optic can be simply written as:

O\; \bold a\, \bold m\, \bold p\, \bold b = \bold C (\bold m \bullet \bold a, \bold p \bullet \bold b)

and an individual morphism as a triple:

(\bold m, \bold p, \bold f \colon \bold m \bullet \bold a \to \bold p \bullet \bold b)

Pre-optics form hom-sets in the \mathbf{PreLens} bicategory. The composition is a mapping:

\mathbf C (\bold m \bullet \bold b, \bold p \bullet \bold c) \times \mathbf C (\bold n \bullet \bold a, \bold q \bullet \bold b) \to \mathbf C ((\bold m \otimes \bold n) \bullet \bold a, (\bold q \otimes \bold p) \bullet \bold c)

Indeed, since both monoidal actions are functorial, we can lift the first morphism by (\bold q \bullet -) and the second by (\bold m \bullet -):

\mathbf C (\bold m \bullet \bold b, \bold p \bullet \bold c) \times \mathbf C (\bold n \bullet \bold a, \bold q \bullet \bold b) \xrightarrow{(\bold q \bullet) \times (\bold m \bullet)}

\mathbf C (\bold q \bullet \bold m \bullet \bold b, \bold q \bullet \bold p \bullet \bold c) \times \mathbf C (\bold m \bullet \bold n \bullet \bold a,\bold m \bullet \bold q \bullet \bold b)

We can compose these hom-sets in \mathbf C, as long as the two monoidal actions commute, that is, if we have:

\bold q \bullet \bold m \bullet \bold b \to \bold m \bullet \bold q \bullet \bold b

for all \bold q, \bold m, and \bold b.

The identity morphism is a triple:

(\bold 1, \bold 1, \bold{id} )

parameterized by the unit objects in the monoidal categories \mathbf M and \mathbf P. Associativity and identity laws are satisfied modulo the associators and the unitors.
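
Specialized to Haskell types, with pairing as the monoidal action on both sides, the composition and the identity can be sketched as follows. Pairing commutes up to a reshuffling of tuples, which is what the pattern matching below does implicitly. The names preCompose and idPreLens are assumptions of this sketch.

```haskell
data PreLens a da m dm p dp s ds =
  PreLens ((p, s) -> (m, a)) ((dm, da) -> (dp, ds))

-- Composition: the residues combine as (m, n) and the parameters as (q, p),
-- mirroring (m ⊗ n) • a -> (q ⊗ p) • c.
preCompose :: PreLens b db m dm p dp c dc
           -> PreLens a da n dn q dq b db
           -> PreLens a da (m, n) (dm, dn) (q, p) (dq, dp) c dc
preCompose (PreLens f1 g1) (PreLens f2 g2) = PreLens f3 g3
  where
    f3 ((q, p), c) = let (m, b) = f1 (p, c)
                         (n, a) = f2 (q, b)
                     in ((m, n), a)
    g3 ((dm, dn), da) = let (dq, db) = g2 (dn, da)
                            (dp, dc) = g1 (dm, db)
                        in ((dq, dp), dc)

-- The identity is parameterized by the unit objects, here the unit type.
idPreLens :: PreLens a da () () () () a da
idPreLens = PreLens id id
```

Composing with idPreLens changes the parameter and residue types only by an extra unit component, which is the Haskell shadow of the laws holding up to unitors.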

If the underlying category \mathcal C is monoidal, the \mathbf{PreLens} bicategory is also monoidal, with the obvious point-wise parallel composition of pre-optics.

Triple Tambara Modules

A triple Tambara module is a functor:

T \colon \mathbf M^{op} \times \mathbf P \times \mathbf C \to \mathbf{Set}

equipped with two families of natural transformations:

\alpha \colon T \, \bold m \, \bold p \, \bold a \to T \, (\bold n \otimes \bold m) \, \bold p \, (\bold n \bullet \bold a)

\beta \colon T \, \bold m \, \bold p \, (\bold r \bullet \bold a) \to T \, \bold m \, (\bold p \otimes \bold r) \, \bold a

and some coherence conditions. For instance, the two paths from T \, \bold m \, \bold p\, (\bold r \bullet \bold a) to T \, (\bold n \otimes \bold m)\, (\bold p \otimes \bold r) \, (\bold n \bullet \bold a) must give the same result.

One can also define natural transformations between such functors that preserve the two structures, and define a bicategory of triple Tambara modules \mathbf{TriTamb}.

As a special case, if we choose the category \mathcal P to be the trivial one-object monoidal category, we get a version of (double-) Tambara modules. If we then take the coend, P \langle a, b \rangle = \int^m T \langle m, m\rangle \langle a, b \rangle, we get regular Tambara modules.

Pre-optics themselves are an example of a triple Tambara representation. Indeed, for any fixed \bold a, we can define a mapping \alpha from the triple:

(\bold m, \bold p, \bold f \colon \bold m \bullet \bold a \to \bold p \bullet \bold b)

to the triple:

(\bold n \otimes \bold m, \bold p, \bold f' \colon (\bold n \otimes \bold m) \bullet \bold a \to \bold p \bullet (\bold n \bullet \bold b))

by lifting \bold f by (\bold n \bullet -) and rearranging the actions using their commutativity.

Similarly for \beta, we map:

(\bold m, \bold p, \bold f \colon \bold m \bullet \bold a \to \bold p \bullet (\bold r \bullet \bold b))

to:

(\bold m , (\bold p \otimes \bold r), \bold f' \colon \bold m \bullet \bold a \to (\bold p \otimes \bold r) \bullet \bold b)
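
These two mappings can be sketched in Haskell as a type class with an instance for pre-optics. The class name TriTambara is hypothetical, the actions are again assumed to be pairing, and the sketch leaves out the functoriality of T in its three arguments as well as the coherence conditions, which Haskell cannot check anyway.

```haskell
data PreLens a da m dm p dp s ds =
  PreLens ((p, s) -> (m, a)) ((dm, da) -> (dp, ds))

-- Each bold object is a pair of types, so T takes six type arguments.
class TriTambara t where
  alpha :: t m dm p dp s ds -> t (n, m) (dn, dm) p dp (n, s) (dn, ds)
  beta  :: t m dm p dp (r, s) (dr, ds) -> t m dm (p, r) (dp, dr) s ds

-- For a fixed focus <a, da>, pre-optics carry both structures.
instance TriTambara (PreLens a da) where
  -- alpha threads the extra n through the forward pass, and dn through the backward pass.
  alpha (PreLens f g) = PreLens f' g'
    where
      f' (p, (n, s))    = let (m, a) = f (p, s) in ((n, m), a)
      g' ((dn, dm), da) = let (dp, ds) = g (dm, da) in (dp, (dn, ds))
  -- beta moves r from the source into the parameter by reassociating pairs.
  beta (PreLens f g) = PreLens f' g'
    where
      f' ((p, r), s) = f (p, (r, s))
      g' (dm, da)    = let (dp, (dr, ds)) = g (dm, da) in ((dp, dr), ds)
```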

Tambara Representation

The main result is that morphisms in \mathbf{PreLens} can be expressed using triple Tambara modules. An optic:

(\bold m, \bold p, \bold f \colon \bold m \bullet \bold a \to \bold p \bullet \bold b)

is equivalent to a triple end:

\int_{\bold r \colon \mathbf P} \int_{\bold n \colon \mathbf M} \int_{T \colon \mathbf{TriTamb}} \mathbf{Set} \big(T \, \bold n \, \bold r \, \bold a, T \, (\bold m \otimes \bold n) \, (\bold r \otimes \bold p) \, \bold b \big)

Indeed, since pre-optics are themselves triple Tambara modules, we can apply the polymorphic mapping of Tambara modules to the identity optic (\bold 1, \bold 1, \bold{id} ) and get an arbitrary pre-optic.

Conversely, given an optic:

(\bold m, \bold p, \bold f \colon \bold m \bullet \bold a \to \bold p \bullet \bold b)

we can construct the polymorphic mapping of triple Tambara modules:

\begin{aligned} & T \, \bold n \, \bold r \, \bold a \xrightarrow{\alpha} T \, (\bold m \otimes \bold n) \, \bold r \, (\bold m \bullet \bold a) \xrightarrow{T \, \bold f} T \, (\bold m \otimes \bold n) \, \bold r \, (\bold p \bullet \bold b) \xrightarrow{\beta} \\ & T \, (\bold m \otimes \bold n) \, (\bold r \otimes \bold p) \, \bold b  \end{aligned}
