I'm confused by your notation for feed-forward layers.
What justifies re-using the same labels ("apple" etc.) for the intermediate basis (2)?
If we want to express what the individual components of basis (2) mean in terms of the original space, we can either talk about which vectors/semes are mapped to them by A, or which vectors/semes they get mapped to by B.
But your labels don't correspond to either of these interpretations. Instead, it looks like you are following rules of the form "the 4th component of every basis is called 'yum'," which leads you to label a coordinate "yum" even though it's neither mapped from "yum" by A, nor mapped to "yum" by B.
This notation also seems to require the basis (2) to have the same number of elements as (1), which generally will not be the case. In transformers, (2) is typically larger by a factor of 4. The logic of your example, meanwhile, can be expressed using a smaller nonlinearity basis of 3 elements:
with some arbitrary choices about which multiplicative constants to absorb into A and a vs. which to absorb into B.
Thanks for your comments/questions, they're very insightful.
In general, there are as many encoding spaces in a Transformer as there are computational nodes, and a traditional Transformer will have little incentive to use the same semantics for any two of the spaces. (There's a little bit of an incentive because of the residual connections, which will (I think?) kind of tie the semantics of the various hidden-size-sized embedding spaces.)
In particular, the middle layer of the dense-relu-dense feedforward layer is usually chosen to be significantly larger (4x) than the hidden size, so it's not even theoretically possible to represent it using the same basis. For this reason, I've found that it sometimes makes sense to use anonymous seme names like x1, x2, x3, etc. in the feed-forward layer. In my experience so far, the feed-forward layers are most useful for conjunctions and disjunctions, and the number of possible conjunctions and disjunctions is already quadratic in the number of neurons when combining just two at a time, let alone 3 or 4. So it seems to me that this might give a tiny hint as to why people have found that the intermediate embedding space of the feed-forward layer needs to be so large.
Of course, there is a potentially huge gap between what I am clever enough to think of as a use for them and what good old gradient descent is clever enough to think of. We can only easily lower-bound the potential uses of them; upper-bounding the capabilities of a component will prove much more challenging.
Re: how this interacts with Alignment Research:
I think that these ideas could prove useful in alignment research - if we understand how a language model works in excruciating detail, it seems drastically more likely that we will be able to reason about and predict various misunderstandings rooted in the ambiguity of language.
Another use is for sanity checking existing interpretability techniques. For example, to check if particular neurons identified as curve detectors via interpretability techniques were indeed curve detectors, Chris Olah spent a few hours replacing the curve-detecting neurons with handwritten curve detector neurons. (He found that the interpretability techniques were able to give qualitatively similar results for both the original neurons and the handwritten neurons. More impressively, he also found that replacing the curve detecting neurons with his handwritten neurons was able to recover ~60% of the drop in accuracy compared to removing the original neurons entirely [reported in footnote 9].)
Very nice post. It is certainly useful to do this exercise of manually encoding language rules into the weights of a transformer in order to better understand the machinery involved.
"The ultimate ambition of this work would be to go toe-to-toe with a comparably-sized Transformer model trained in the traditional way on a modern-sized data set. This might require several people-years of focused effort though."
There is a long history of attempting to parse natural language with hand-designed rules and heuristics. The general consensus now is that hand engineering is insufficient, and some learning from data is necessary. To me it seems that this direction inherits the problems of these old-fashioned language systems, since you are codifying your own hand-designed heuristics and rules into the network weights.
Do you see a way to introduce learning from data without sacrificing the interpretability that your approach provides?
There are a number of ways to combine this approach with learning, but I haven't had time to try any of them yet. Some ideas I have thought of:
Previously: Teaser: Hard-coding Transformer Models
Introduction
Transformer models are incredibly powerful for natural language tasks (and they are starting to find uses in many other fields of machine learning). Unfortunately, it is nigh-impossible to interpret what goes on inside them. OR IS IT???
I have found that I can, with a fair amount of effort, hard-code the weights of a transformer model in order to perform some very crude versions of linguistic tasks. So far I have achieved English-to-French translation (on a toy corpus of about 150 sentences), text classification (deciding whether a sentence is grammatical, on a toy corpus of a couple hundred sentences), and sentiment analysis (again on a limited corpus). These results are obviously not impressive compared to the state of the machine learning field, but I am pretty sure that they can all be drastically scaled up with the investment of some time and energy. Unfortunately, I have a fairly demanding day job, and haven't found the time and energy yet.
All of this is done by inspection (no gradient descent!). The process is a lot like programming, although it is more difficult than programming, at least right now for me. I am fairly certain that better tools and better notation can be developed to make the process easier. It is also almost certainly possible to combine hard-coding with gradient descent approaches to be able to scale these methods up in a slightly less labor-intensive way.
I think that these ideas could prove useful in alignment research - if we understand how a language model works in excruciating detail, it seems drastically more likely that we will be able to reason about and predict various misunderstandings rooted in the ambiguity of language. Given that language is (arguably) a fully general means of interacting with an artificial intelligence, it seems plausible to me that this work is on the critical path to alignment.
Doneness Status
This post is a work-in-progress. I will be editing it as I go, mostly appending more content to the end, but I will also try to fix any errors or unclear parts as I notice them or commenters point them out.
So let's hard-code some neural computation! I have a very, very messy github repository where I've done my initial experiments, if you prefer to just jump into semi-working code. Otherwise, I will do my very best to explain the ideas from scratch in this post. I'm aiming this post at anyone who's willing to put in the work to understand it, so I'll try to at least give pointers to necessary background material, of which there is a fair amount.
What Can We Do Already?
Some primitive sentiment analysis using a Vanilla RNN:
Some very simple translation with a Transformer model:
How Should We Measure Success?
Before we explain how we get results, it seems worthwhile to talk about how to measure the performance of such a system. As in traditional machine learning, this can only be measured with respect to some dataset of inputs labeled with desired outputs. For a crude metric, we can look at the fraction of inputs that receive the desired output, versus the fraction that receive some other output. We can also look at more complex metrics, such as BLEU or ROUGE for sequence-to-sequence tasks.
In traditional machine learning, performance can only be (meaningfully) measured on a holdout set that was not used to train the algorithm. This is because performance tends to be much, much higher (if you're using the right architecture and hyperparameters for your task) on the training set (the set of data that was used to train the algorithm) than it will be for the test set (the set of data that has been held out). The whole purpose and challenge of machine learning, of course, is to build models that generalize to unseen data.
A similar phenomenon occurs in this work: data that the programmer has examined, run through the algorithm, and used to inform updates to the rules will typically be data that the resulting network does disproportionately well on. After all, if you miss some edge case, but then see it in your testing, you have the opportunity to fix it.
On the other hand, the programmer will presumably have a native command of at least one language, so it is at least possible for the programmer to anticipate some phenomena before seeing them in the "training data". Thus, it seems unfair to gradient descent and deep learning to compare accuracies in the low-data regime where I have been stuck so far by my lack of free time.
The ultimate ambition of this work would be to go toe-to-toe with a comparably-sized Transformer model trained in the traditional way on a modern-sized data set. This might require several people-years of focused effort though.
Some Useful Notation
The first thing we are going to do is introduce some very unconventional notation for vectors and matrices. (We won't need any more information about linear algebra than is contained in the Wikipedia article, but we will assume that you are either familiar with them or have paused and read that article.)
We will pick a set of "axes" that we will call "semes". (This word comes from semiotics, as do a few other terms we will use. I believe I'm using them in a way that is compatible with their technical meaning in semiotics, but feel free to think of this as a nonsense word that we are coining.) Each seme will be identified with a short string, often a word. So, we might have semes "wombat", "peregrine", and "pig". These play a role very similar to variable names in traditional programming, so we will generally choose them to be meaningful. Common semes that I actually use are "noun", "verb", etc.
We then will write vectors using these semes, for example ⟨⟨+pig−wombat⟩⟩ for the vector that is 1 in the direction ⟨⟨pig⟩⟩ and -1 in the direction ⟨⟨wombat⟩⟩. We can also use coefficients, so that ⟨⟨+2.1pig−3.2peregrine⟩⟩ denotes the vector that is 2.1 in the direction ⟨⟨pig⟩⟩ and -3.2 in the direction ⟨⟨peregrine⟩⟩. There are two ways of thinking about this - you can either think of the various semes as being completely orthogonal to each other, forming an orthonormal basis of whatever vector space we are in. Or you can think of them as arbitrary vectors that we are using as a (possibly overcomplete) basis. In general, both will be useful; I generally think of syntactic information as being best represented in a fully orthonormal basis, while semantic information makes much more sense as being drawn from a very overcomplete basis.
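Matrices will be written in the form {{1.1pig→wombat+2.3wombat→pig−4.5pig→peregrine+0.9peregrine→peregrine}} for a matrix that would be conventionally represented (ordering both rows and columns as pig, wombat, peregrine, with rows indexing the source seme and columns the target seme) as

$$\begin{bmatrix} 0 & 1.1 & -4.5 \\ 2.3 & 0 & 0 \\ 0 & 0 & 0.9 \end{bmatrix}.$$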
In code, we will write them like this:
for the above vector and matrix.
As with vectors, there are two ways to think of the matrix notation. In the first way, the semes form an orthonormal basis, and we are just using them to identify which pairs of coordinates get which coefficient. But, we can also think of {{1.1pig→wombat}} as being 1.1 times the outer product of ⟨⟨pig⟩⟩ and ⟨⟨wombat⟩⟩. This second view will not be necessary for the contents of this post, but it is necessary to understand some of the ways I envision being able to combine this work with gradient descent-based learning.
It is also worth pointing out that, if we multiply with the vector on the left, {{pig→wombat}} actually maps the vector ⟨⟨pig⟩⟩ to the vector ⟨⟨wombat⟩⟩. (Although, potentially confusingly, {{1.1pig→wombat}} maps ⟨⟨pig⟩⟩ to ⟨⟨1.1wombat⟩⟩.) For this reason, we will prefer left-multiplication in our neural networks later, because it makes this particular notation much easier to think with.
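The notation maps naturally onto plain Python dictionaries. This sketch uses a hypothetical representation, not necessarily the one used in the repository. A vector is a dict from seme to coefficient, and a matrix is a dict from (source seme, target seme) pairs to coefficients. The code reproduces the conventional matrix above and the left-multiplication rule.

```python
from collections import defaultdict

# Hypothetical representation: vector = {seme: coef}, matrix = {(src, dst): coef}.
Vector = dict[str, float]
Matrix = dict[tuple[str, str], float]


def left_multiply(v: Vector, m: Matrix) -> Vector:
    """Row-vector product v . M: each entry src->dst sends v[src] * coef to dst."""
    out = defaultdict(float)
    for (src, dst), coef in m.items():
        if src in v:
            out[dst] += v[src] * coef
    # Drop exact zeros so the result reads like the sparse notation.
    return {seme: c for seme, c in out.items() if c != 0.0}


def to_dense(m: Matrix, order: list[str]) -> list[list[float]]:
    """Conventional matrix with rows = source semes and columns = target semes."""
    index = {seme: i for i, seme in enumerate(order)}
    dense = [[0.0] * len(order) for _ in order]
    for (src, dst), coef in m.items():
        dense[index[src]][index[dst]] += coef
    return dense


def fmt(v: Vector) -> str:
    """Render a vector in the ⟨⟨...⟩⟩ notation (ASCII '-' in place of the minus sign)."""
    terms = []
    for seme, c in v.items():
        if c == 1:
            terms.append(f"+{seme}")
        elif c == -1:
            terms.append(f"-{seme}")
        else:
            terms.append(f"{c:+g}{seme}")
    return "⟨⟨" + "".join(terms) + "⟩⟩"


# {{1.1pig->wombat + 2.3wombat->pig - 4.5pig->peregrine + 0.9peregrine->peregrine}}
M = {
    ("pig", "wombat"): 1.1,
    ("wombat", "pig"): 2.3,
    ("pig", "peregrine"): -4.5,
    ("peregrine", "peregrine"): 0.9,
}

print(to_dense(M, ["pig", "wombat", "peregrine"]))
# [[0.0, 1.1, -4.5], [2.3, 0.0, 0.0], [0.0, 0.0, 0.9]]
print(fmt(left_multiply({"pig": 1.0}, M)))
# ⟨⟨+1.1wombat-4.5peregrine⟩⟩
```

With the vector on the left, each row of the dense matrix is exactly where the corresponding source seme gets sent. This is why the arrow notation reads naturally under left-multiplication.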
Tokenization and Word Embeddings
In deep NLP, the first couple steps are about getting rid of words and replacing them with inputs that can actually be understood by a deep network. The first step is to take a string and break it into some number of discrete chunks called "tokens". In principle, we could feed things in letter-by-letter, and people have gotten semi-decent results doing that in the past, but it's a lot less labor-intensive in this context to use full words as the unit of tokenization. This is actually a mild break from most Transformer models used today, which generally make use of a "subword vocabulary" which contains a mixture of whole words and parts of words like "ing" or "particul".
Let's take an example sentence and tokenize it, just to be sure that we understand this process. Consider
Some things worth emphasizing:
Additionally, we will case-normalize our inputs by making everything lower-case. This cuts down on some repetitive work and is relatively common with deep models. We also include a special SOS (start of sentence) token and an EOS (end of sentence) token. So the above example should really look like this:
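Here is a minimal sketch of this preprocessing. It assumes a simple regular-expression word splitter, not the tokenizer from the repository. It also assumes hypothetical special-token spellings "<sos>" and "<eos>".

```python
import re

# Hypothetical special-token names; the post only calls them SOS and EOS.
SOS, EOS = "<sos>", "<eos>"


def tokenize(text: str) -> list[str]:
    """Lower-case the text, split it into word and punctuation tokens, add SOS/EOS.

    Assumption: a word is a run of letters, digits, or apostrophes; every other
    non-space character (".", ",", "!", ...) becomes a token of its own.
    """
    pieces = re.findall(r"[a-z0-9']+|[^\sa-z0-9']", text.lower())
    return [SOS, *pieces, EOS]


print(tokenize("The cat sat on the mat."))
# ['<sos>', 'the', 'cat', 'sat', 'on', 'the', 'mat', '.', '<eos>']
```

This splitter treats any character outside a-z, 0-9 and the apostrophe as a separate token. Accented letters, which matter for French, would therefore need a wider character class.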
The second step in transforming a string of text into the sort of inputs that a deep network prefers is to do a "word embedding lookup". Here, each token is replaced by a fixed vector, so that we get a matrix of shape [num_tokens, word_embedding_dim]. Because the first axis (the "sequence dimension") is not semantically the same as the second axis (the "embedding dimension"), we will not use our special matrix notation, but will instead think of this as a list of vectors, one for each token.
So let's look at some word embeddings! Here are some pronouns (note that we're describing here a fragment of a flavor of English that includes the gender-neutral singular "they" in addition to the plural "they"):
Here are some verbs:
You'll note something strange about verbs (and nouns and other "content" words): they almost all have one seme that's just themselves again! What gives? These are semantic semes, which are much harder to reason about than the other semes, which are syntactic. As we said earlier, syntactic semes should be thought of as an orthonormal basis of whatever size, but semantic semes are more usefully thought of as living in a low-dimensional space, where they aren't mutually orthogonal. (But! All semantic semes should be thought of as perfectly orthogonal to all syntactic semes, and vice versa.) For the time being, I restrict myself to a relatively limited vocabulary and just use the semantic semes as if they were orthogonal. For grammaticality classification, which is the domain I have worked hardest on, the semantic semes are not particularly relevant. For translation, the only really important thing is that they are able to pick out the corresponding word in the target language. (This assumes there is a straightforward single-word translation in the target language, which has mostly been the case in the toy examples I have considered thus far, but which in general is not the case.)
Let's embed a sentence! Consider:
This tokenizes to
Which then embeds to
In mathematical notation, we would write this as [⟨⟨+sos⟩⟩,⟨⟨+det⟩⟩,⟨⟨+cat+sg+noun⟩⟩,⟨⟨+sit+verb+preterite+agentlack⟩⟩,⟨⟨+on+prep⟩⟩,⟨⟨+det⟩⟩,⟨⟨+mat+sg+noun⟩⟩,⟨⟨+punct+period⟩⟩,⟨⟨+eos⟩⟩]
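The lookup itself is just a dictionary access. This sketch assumes the example sentence is "The cat sat on the mat." (as the embeddings above suggest). It stores each vector as a sparse dict from seme to coefficient. The table is a hypothetical fragment covering only these tokens.

```python
# Hypothetical embedding table: each token maps to a sparse seme vector.
# The entries mirror the list above; all coefficients are +1.
EMBEDDINGS: dict[str, dict[str, float]] = {
    "<sos>": {"sos": 1.0},
    "the": {"det": 1.0},
    "cat": {"cat": 1.0, "sg": 1.0, "noun": 1.0},
    "sat": {"sit": 1.0, "verb": 1.0, "preterite": 1.0, "agentlack": 1.0},
    "on": {"on": 1.0, "prep": 1.0},
    "mat": {"mat": 1.0, "sg": 1.0, "noun": 1.0},
    ".": {"punct": 1.0, "period": 1.0},
    "<eos>": {"eos": 1.0},
}


def embed(tokens: list[str]) -> list[dict[str, float]]:
    """Word-embedding lookup: one vector per token (copied so callers can modify it)."""
    return [dict(EMBEDDINGS[token]) for token in tokens]


tokens = ["<sos>", "the", "cat", "sat", "on", "the", "mat", ".", "<eos>"]
for token, vector in zip(tokens, embed(tokens)):
    print(f"{token:>6}  ⟨⟨" + "".join(f"+{seme}" for seme in vector) + "⟩⟩")
```

An unknown word raises `KeyError` here. A real vocabulary would need some out-of-vocabulary handling.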
For further clarity, let's give a gloss for each seme we're using:
Sentiment Analysis
Sentiment analysis refers to the task of extracting from a piece of natural language the overall sentiment that the speaker has towards whatever thing they're talking about. For instance, in a movie or product review, does the author recommend the movie or product? This is generally considered a pretty straightforward task for machine learning algorithms.
An extremely interpretable algorithm for sentiment analysis is given in VADER: A Parsimonious Rule-based Model for Sentiment Analysis of Social Media Text [PDF], by C. Hutto and E. Gilbert (2014). We have implemented a similar algorithm inside a two-layer vanilla RNN, which we will describe below. However, we want to note first that this algorithm is only a crude sketch of VADER, and its shortcomings should not be held against Hutto and Gilbert.
Why a vanilla RNN rather than the OMG SO MUCH BETTER LSTM? Well, vanilla RNNs are significantly simpler to understand, and their disadvantages (vanishing and exploding gradients, primarily) are only really relevant when you're actually using gradient descent! So let's do this the easy way and stick to a vanilla RNN.
First, a few preliminaries about the architecture we will be using:
Let $H^{\ell}_t$ denote the output of layer $\ell$ at time-step $t$. Note that the superscript isn't an exponent, it's just a convenient place to put another index. There will be no exponents anywhere in this network; all superscripts are indices. (Here time-step just means the index of the token. So time-step 1 will be the first token, time-step 2 will be the second token, and so on.) $H^{\ell}_0$ will be a special initial state before we read any tokens, and we will set it to be the zero vector. $H^0_t$ will be the output of the word-embedding layer, so it will just be the embedding of the $t$-th token after we look it up.
We then define the recurrence (for ℓ=1,2):
$$H^{\ell}_t = \sigma\left[H^{\ell-1}_t (I + A^{\ell}) + H^{\ell}_{t-1} B^{\ell} + b^{\ell}\right]$$
Here σ is the good old logistic sigmoid function, and I is the identity matrix. (Practitioners will note that the use of the identity matrix here is some sort of residual-like connection.)
We further define a pooling layer, and then a fully-connected or dense layer.
$$X^0 = \operatorname{mean}_t H^2_t$$
$$X^1 = \sigma\left[X^0 C^1 + c^1\right]$$
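Below is a small pure-Python sketch of this network; it is not the code behind the scores above. It uses lists for vectors and lists of rows for matrices. It assumes the hidden size equals the embedding size, so that I + A^ℓ makes sense. The weights are made-up toy values over two semes, ordered [good, bad].

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def vec_mat(x: list[float], M: list[list[float]]) -> list[float]:
    """Row vector times matrix (left-multiplication, as in the rest of the post)."""
    return [sum(x[i] * M[i][j] for i in range(len(x))) for j in range(len(M[0]))]


def vec_add(*vectors: list[float]) -> list[float]:
    return [sum(parts) for parts in zip(*vectors)]


def plus_identity(A: list[list[float]]) -> list[list[float]]:
    """I + A, the residual-like term applied to the layer below."""
    return [[a + (1.0 if i == j else 0.0) for j, a in enumerate(row)] for i, row in enumerate(A)]


def rnn_score(embeddings, layers, C1, c1):
    """embeddings: the H^0_t vectors; layers: [(A^1, B^1, b^1), (A^2, B^2, b^2)]."""
    below = embeddings
    for A, B, b in layers:
        IA = plus_identity(A)
        h = [0.0] * len(b)  # H^l_0 is the zero vector
        outputs = []
        for x in below:
            # H^l_t = sigma[H^{l-1}_t (I + A^l) + H^l_{t-1} B^l + b^l]
            h = [sigmoid(z) for z in vec_add(vec_mat(x, IA), vec_mat(h, B), b)]
            outputs.append(h)
        below = outputs
    X0 = [sum(column) / len(below) for column in zip(*below)]  # mean over t of H^2_t
    return [sigmoid(z) for z in vec_add(vec_mat(X0, C1), c1)]  # X^1


# Toy weights over two semes, ordered [good, bad]; purely illustrative.
WORDS = {"great": [1.0, 0.0], "awful": [0.0, 1.0]}
zeros = [[0.0, 0.0], [0.0, 0.0]]
carry = [[0.5, 0.0], [0.0, 0.5]]  # each seme carries a little into the next time-step
layers = [(zeros, carry, [-2.0, -2.0]), (zeros, carry, [-2.0, -2.0])]
C1, c1 = [[4.0], [-4.0]], [0.0]  # score rises with "good" and falls with "bad"


def score(sentence: str) -> float:
    vectors = [WORDS.get(w, [0.0, 0.0]) for w in sentence.split()]
    return rnn_score(vectors, layers, C1, c1)[0]


print(score("what a great movie"), score("what an awful movie"), score("just a movie"))
```

A sentence with no sentiment words scores exactly 0.5. The toy weights treat good and bad symmetrically, so a positive sentence and its mirror-image negative sentence sum to 1.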
This is sufficient to generate the scores at the beginning of this post. The scores on the given examples are not all that inaccurate. Lots more work could obviously be done on this network, and I'd love it if people feel like working on this network or other later-discussed networks in the comments.
Transformer Overview
We give here a brief overview of the transformer architecture, for those unfamiliar with it. This will essentially be an accelerated recap of Jay Alammar's Illustrated Transformer, which I consider to be the best friendly introduction to the architecture.
The Transformer, in its full sequence-to-sequence glory, has an encoder stack and a decoder stack. For text classification purposes, one generally just uses the encoder stack with a few simple layers at the end. The encoder stack is made up of a bunch of Transformer encoder blocks, each of which is the same architecturally, but each of which has its own learnable/settable weights that allow it to specialize and do its own particular task in the grand scheme of the network.
The decoder stack is also made up of a series of architecturally identical Transformer layers, again each with their own learnable/settable weights that allow them to specialize into their own unique role. The decoder layers are similar to the encoder layers, but a little bit more complex.
So now let's dive inside the Transformer layer and see how it ticks!
So we have self-attention layers, encoder-decoder attention layers, and the feed-forward layers. The two types of attention layers are generally considered to be the innovative, "important" part of Transformer, but I have found in trying to hard-code weights for transformer (and years of research by many people has also found) that the feed-forward layers are crucial to being able to learn complex functions. (Well, that's not entirely true. I think someone in some paper managed to sort of smuggle the feed-forward layer into a computation that looks like self-attention over a learnable set of parameters, without losing any accuracy. But for our purposes the feed-forward layer is important.)
We'll next dive into the various layers and see how to hard-code their parameters. We'll start with the easiest layer to understand: the feed-forward layer.
Transformer Feed-Forward Layers
The standard transformer feed-forward layer can be described pretty simply as:
$$\mathrm{FFN}(x) = b + \mathrm{ReLU}(a + x \cdot A) \cdot B$$

Here ReLU is a new kind of non-linearity, the "rectified linear unit". A and B are matrices called "weights", and a and b are vectors called "biases". "Parameters" refers to either weights or biases, although people will often refer to biases as "weights" also.
Like most non-linearities in deep learning, ReLU is an element-wise function, i.e., you apply it to each coordinate of the vector independently. ReLU(x)=max(x,0), so if x is negative, ReLU(x)=0, and otherwise ReLU(x)=x.
So let's hard-code a feed-forward layer! This just requires picking values for A,B,a,b.
The semantics we are trying to encode here is that ⟨⟨apple⟩⟩ OR ⟨⟨banana⟩⟩ should be mapped to ⟨⟨yum⟩⟩, while ⟨⟨cherry⟩⟩ AND ⟨⟨durian⟩⟩ should be mapped to ⟨⟨yuck⟩⟩. This results in the following output:
You could argue that ⟨⟨yuck⟩⟩ should displace ⟨⟨yum⟩⟩. This can be done, but it seems to require a third layer (I haven't given it much thought just now, so maybe it's doable in two layers), so it would have to be split across two different Transformer layers.
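One way to pick these values uses three anonymous hidden semes, x1, x2 and x3. These particular weights are an illustration, not necessarily the ones behind the output above. x1 and x2 both fire on apple or banana, but x2 is shifted down by a bias of 1. For 0/1 inputs, x1 − x2 is therefore exactly 1 whenever at least one of them is present, which gives an OR. x3 only clears its bias of −1 when cherry and durian are both present, which gives an AND. Vectors are plain dicts from seme to coefficient.

```python
from collections import defaultdict


def vec_mat(v, m):
    """Left-multiply a seme vector by a seme matrix {(src, dst): coef}."""
    out = defaultdict(float)
    for (src, dst), coef in m.items():
        out[dst] += v.get(src, 0.0) * coef
    return dict(out)


def vec_add(u, v):
    out = dict(u)
    for seme, c in v.items():
        out[seme] = out.get(seme, 0.0) + c
    return out


def relu(v):
    return {seme: max(c, 0.0) for seme, c in v.items()}


def ffn(x, A, a, B, b):
    """FFN(x) = b + ReLU(a + x.A).B, with exact zeros dropped from the result."""
    hidden = relu(vec_add(a, vec_mat(x, A)))
    out = vec_add(b, vec_mat(hidden, B))
    return {seme: c for seme, c in out.items() if c != 0.0}


# Hypothetical weights: x1 - x2 computes (apple OR banana); x3 computes (cherry AND durian).
A = {
    ("apple", "x1"): 1.0, ("banana", "x1"): 1.0,
    ("apple", "x2"): 1.0, ("banana", "x2"): 1.0,
    ("cherry", "x3"): 1.0, ("durian", "x3"): 1.0,
}
a = {"x1": 0.0, "x2": -1.0, "x3": -1.0}
B = {("x1", "yum"): 1.0, ("x2", "yum"): -1.0, ("x3", "yuck"): 1.0}
b = {}

for fruits in (["apple"], ["apple", "banana"], ["cherry"], ["cherry", "durian"]):
    print(fruits, ffn({f: 1.0 for f in fruits}, A, a, B, b))
# ['apple'] {'yum': 1.0}
# ['apple', 'banana'] {'yum': 1.0}
# ['cherry'] {}
# ['cherry', 'durian'] {'yuck': 1.0}
```

ReLU(c·z) = c·ReLU(z) for any c > 0. Positive scale factors can therefore be moved between A and a on one side and B on the other without changing the result.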
Preparing for Self-Attention: Positional Encodings
Before we can really dive into self-attention layers, it is useful to talk about positional encodings. This is a fairly technical aspect of the Transformer architecture, but I've found that it can be made fairly interpretable by thinking about it in the right way.
The traditional way to think about positional embeddings is just to read the following code, and then add some handwaving around "relative positional offsets can be encoded as linear combinations of sines and cosines". This is all correct, but it doesn't really yield enough understanding (for me at least) to hard-code things around the positional embeddings.
I prefer to think of the various sines and cosines as the hands of a bunch of clocks running at different speeds. If we have 512 dimensions in our positional embeddings, then there will be 256 clocks. The sine is just the y-coordinate of the clock hand, and the cosine is the x-coordinate of the clock hand.
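Here is a sketch of the clock picture using the standard sinusoidal encoding from the original Transformer paper (base 10000, with sine and cosine interleaved). The function names are made up for illustration. Each (sin, cos) pair is one clock hand. In this layout clock 0 is the fastest, which may be the reverse of the left-to-right order in the figures.

```python
import math


def clock_speeds(d_model: int, base: float = 10000.0) -> list[float]:
    """Angular speed, in radians per time-step, of each of the d_model // 2 clocks.

    Standard sinusoidal encoding: clock i turns at base ** (-2i / d_model), so
    clock 0 is the fastest (1 radian per step) and the last clock is the slowest.
    """
    return [base ** (-2 * i / d_model) for i in range(d_model // 2)]


def positional_encoding(t: int, d_model: int) -> list[float]:
    """Interleaved [sin, cos] pairs: the (y, x) coordinates of every clock hand at time t."""
    encoding = []
    for speed in clock_speeds(d_model):
        angle = speed * t
        encoding += [math.sin(angle), math.cos(angle)]
    return encoding


speeds = clock_speeds(512)
print(len(speeds))            # 256 clocks for 512 dimensions
print(speeds[0], speeds[-1])  # fastest clock: 1.0; slowest: about 1.04e-4 radians per step
```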
At time t=0, we look at the first word in the sequence. All clocks point in the same direction.
At time t=1, we look at the second word in the sequence. The slowest clock (all the way on the left) has advanced a tiny bit, while the fastest clock (all the way on the right) has advanced more.
At time t=4, the slowest clock has advanced a decent amount, while the fastest clock has advanced about a quarter rotation from where it started.
Now that we have a more grounded understanding, we can ask questions like, if I have the positional encoding of "very", how do I get the positional encoding of the next word? Well, the next word is one time-step further, so each clock should advance by the amount that that particular clock advances over one time step. We can "point" to that positional embedding by using a particular angle offset for each clock, which then translates into the specific linear combination of sines and cosines referenced earlier. Thus, I have an intentionally quite redundant way to refer to a specific number of time steps in the future.
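The claim that moving k steps forward is a fixed linear map can be checked directly. Rotating each hand by its own angle w·k uses coefficients cos(w·k) and sin(w·k), and these don't depend on the current position. The sketch below uses 8 clocks (as for a 16-dimensional encoding) purely to keep it small.

```python
import math

# Eight clock speeds with the same geometric spacing as the standard encoding (d_model = 16).
SPEEDS = [10000.0 ** (-2 * i / 16) for i in range(8)]


def clock_hands(t: int) -> list[tuple[float, float]]:
    """(x, y) = (cos, sin) position of every clock hand at time-step t."""
    return [(math.cos(w * t), math.sin(w * t)) for w in SPEEDS]


def advance(hands, k: int):
    """Rotate each hand by the angle its clock sweeps in k steps.

    The new coordinates are fixed linear combinations of each hand's old cos and sin,
    with coefficients that depend only on k, never on the current time-step.
    """
    rotated = []
    for (x, y), w in zip(hands, SPEEDS):
        c, s = math.cos(w * k), math.sin(w * k)
        rotated.append((c * x - s * y, s * x + c * y))
    return rotated


# "Point" from the word at position 5 to the next word, at position 6.
predicted = advance(clock_hands(5), 1)
actual = clock_hands(6)
print(max(abs(p - q) for ph, qh in zip(predicted, actual) for p, q in zip(ph, qh)))
```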
The different speeds also provide me the ability to point to ranges of time-steps relatively easily. If I want to point to a very narrow range, I can use the fastest clock, which will pick out a very specific time-step, with little room for error. If I want to refer to a broader time-range, I can use a slower clock, which will have the effect that my pointing will be somewhat evenly spread over a wide range of time-steps. Since I have so many clocks to choose from, I can be quite precise in what sort of time interval I point at.
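Here is a quick numerical illustration of narrow versus broad pointing. Suppose a query is aimed at the target's clock hand. For a clock of speed w, its dot product with the hand at position t is cos(w·(t − target)). The sketch compares a fast clock (1 radian per step, the fastest in the standard encoding) with a slow one (0.01 radians per step, chosen arbitrarily). The fast clock is sharply peaked, but it wraps around every 2π steps, so its similarity climbs back above 0.9 about six steps away. The slow clock stays near 1 across the whole window.

```python
import math


def pointing_profile(target: int, speeds: list[float], positions: range) -> list[float]:
    """Average dot product between the clock hands at `target` and at each position.

    For one clock of speed w the dot product is cos(w * (t - target)): a fast clock
    drops off quickly around the target, while a slow clock stays near 1 over a wide range.
    """
    return [sum(math.cos(w * (t - target)) for w in speeds) / len(speeds) for t in positions]


positions = range(0, 41)
fast = pointing_profile(20, [1.0], positions)    # narrow, but periodic (period 2*pi steps)
slow = pointing_profile(20, [0.01], positions)   # broad
print([round(v, 2) for v in fast[16:27]])
print([round(v, 2) for v in slow[0:41:10]])
```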
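Finally, as a technical note, the range of clock speeds should be chosen so that the slowest clocks never complete a full loop over the sequence lengths we care about. Otherwise, our pointing might accidentally land on a position that is a full loop away without our realizing it. (In the standard encoding the fastest clocks do wrap around, roughly every 2π ≈ 6.3 time-steps, so it is the slower clocks that keep distant positions from being confused.)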
The notion of "pointing" will prove to be a very apt way of thinking about the mechanism of self-attention. Basically, words that interact with each other in the self-attention layer can be thought of as pointing at the word they interact with.
Self-Attention
Finally, we come to the heart of the Transformer model, the self-attention layer. This can be mathematically expressed as
$$\mathrm{SA}(Q,K,V) = \mathrm{SOFTMAX}\left(\frac{QK^T}{\sqrt{d}}\right) V$$

where Q, K, V are matrices called the queries, keys, and values, respectively, and d is the dimension of the query and key vectors. The $\sqrt{d}$ factor, often called a "temperature", is important for gradient descent training, where it improves stability. (I would guess that it improves stability by keeping the attention matrix relatively balanced early on in training.) For us, the problem is the opposite: it's much easier to think about approximately sparse matrices, so we will actually use something like this, with α being a fairly large scalar:

$$\mathrm{SA}(Q,K,V) = \mathrm{SOFTMAX}\left(\alpha\, QK^T\right) V$$

There's a lot to unpack here, so let's come up with a very simple example and work through it together.
Suppose we want to distinguish between these two sentences:
We will try not to go too far down the rabbit hole of what makes something grammatical or not - here we're just trying to encode the simple rule (which does have limited exceptions) that adjectives can't just hang out without modifying anything, or being used with a linking verb, or in some other way being "licensed" by the other words in the sentence.
So we would like to be able to detect when an adjective seems to be modifying a noun in the usual way, versus when it is not. In my special programming language for self-attention layers, this can be done like this:
The queries and keys live in a low-dimensional embedding space that I call key-space. (Typical size in a standard transformer is 64 dimensions per attention head, much smaller than the 512/768/1024 hidden size.) There is one query vector for each token, and one key vector for each token. We take the dot product of every query vector with every key vector, and that gives us what are called "attention logits". Applying the softmax function to the attention logits (over the keys, separately for each query) gives us attention probabilities, which make up what is generally called the "attention matrix". So let's compute all of these values for the above program and our example sentences to get a sense of how all of this works.
Let's suppose that we have just run the word embedding lookup layer, so that we have the default embedding for each word, but no information yet about how the words are interacting with each other. That might look like this (omitting the positional embeddings for the time being):
The queries will then be computed as follows (again, omitting the positional components):
The keys will look like this (again, omitting the positional components):
Supposing that x1, x2, x3, x4, and x5 are orthonormal, this creates the following attention logits, where we omit zeros for brevity (and we're still omitting positional encodings):
We will continue to ignore the positional encodings for the rest of this example, since they're not needed, and don't drastically change the attention matrix. (They would be needed if there were two nouns: the above weights would then produce a tie between "red" pointing to "apple" and "red" pointing to the other noun. We would want to break that tie using the rule that, in English, an adjective usually comes shortly before the noun it modifies, with some limited counter-examples.)
Now let's look at the values! For simplicity in our programs, we actually combine the value projection and what is usually called the out-projection. We call the combined quantity the "interpretant" (a term from semiotics). The interpretants for us are as follows:
Multiplying the interpretants $V$ by the attention matrix $\mathrm{SOFTMAX}(\alpha QK^T)$, we get the following outputs (for large α):
Using the residual connection that surrounds the self-attention layer, we then receive final outputs:
Thus, this self-attention layer has managed to compute the fact that the word "red" is licensed in the first, grammatical sentence, but not licensed in the second, questionable sentence. This can be used by downstream layers to declare the second sentence to be ungrammatical.
Phew!
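Here is a minimal end-to-end sketch of this kind of layer. It uses made-up sentences ("the red cat sat ." vs. "the red sat .") and made-up weights rather than the program above. Seme vectors are dicts. A hypothetical SOS "sink" key gives an adjective with no noun somewhere to attend. Like the walkthrough, the sketch ignores positional encodings.

```python
import math
from collections import defaultdict


def vec_mat(v, m):
    """Left-multiply a seme vector by a seme matrix {(src, dst): coef}."""
    out = defaultdict(float)
    for (src, dst), coef in m.items():
        if src in v:
            out[dst] += v[src] * coef
    return dict(out)


def vec_add(u, v):
    out = dict(u)
    for seme, c in v.items():
        out[seme] = out.get(seme, 0.0) + c
    return out


def dot(u, v):
    return sum(c * v.get(seme, 0.0) for seme, c in u.items())


def softmax(logits):
    top = max(logits)  # subtract the max so a large alpha doesn't overflow exp()
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(X, W_q, q_bias, W_k, W_i, alpha):
    """X + SOFTMAX(alpha Q K^T) V, where V holds the interpretants; no positions."""
    Q = [vec_add(q_bias, vec_mat(x, W_q)) for x in X]
    K = [vec_mat(x, W_k) for x in X]
    V = [vec_mat(x, W_i) for x in X]
    outputs = []
    for x, q in zip(X, Q):
        probs = softmax([alpha * dot(q, k) for k in K])
        attended = {}
        for p, v in zip(probs, V):
            attended = vec_add(attended, {seme: p * c for seme, c in v.items()})
        outputs.append(vec_add(x, attended))  # residual connection
    return outputs


# Hypothetical program: an adjective queries x1, which only nouns carry as a key.
# Every token also weakly queries x2, carried only by SOS, as a fallback "sink".
W_q = {("adj", "x1"): 2.0}
q_bias = {"x2": 1.0}
W_k = {("noun", "x1"): 1.0, ("sos", "x2"): 1.0}
W_i = {("noun", "licensed"): 1.0}  # any token attending to a noun receives +licensed

EMB = {
    "<sos>": {"sos": 1.0}, "the": {"det": 1.0}, "red": {"red": 1.0, "adj": 1.0},
    "cat": {"cat": 1.0, "noun": 1.0}, "sat": {"sit": 1.0, "verb": 1.0},
    ".": {"punct": 1.0}, "<eos>": {"eos": 1.0},
}


def run(tokens, alpha=20.0):
    return self_attention([EMB[t] for t in tokens], W_q, q_bias, W_k, W_i, alpha)


good = run(["<sos>", "the", "red", "cat", "sat", ".", "<eos>"])
bad = run(["<sos>", "the", "red", "sat", ".", "<eos>"])
print(round(good[2]["licensed"], 6), bad[2].get("licensed", 0.0))  # 1.0 0.0
```

The SOS sink is one common way to give a query a harmless default target. Without it, an unlicensed adjective would spread its attention uniformly over the sentence.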
More to Come!
We'll next see how to hard-code the weights of an actual Transformer. This will involve explaining the structures of Transformer layers, which will take a fair amount of time. In the meantime, please check out The Illustrated Transformer and Transformers from Scratch to get a head start on understanding them, or dive face-first into my 1800-line poorly-commented grammaticality classifier.