(thanks to Tao Lin and Ryan Greenblatt for pointing this out, and to Arthur Conmy, Jenny Nitishinskaya, Thomas Huck, Neel Nanda, Lawrence Chan, Ben Toner, and Chris Olah for comments, and many others for useful discussion.)
In “A Mathematical Framework for Transformer Circuits”, Elhage et al write (among similar sentences):
One layer attention-only transformers are an ensemble of bigram and “skip-trigram” (sequences of the form "A… B C") models. The bigram and skip-trigram tables can be accessed directly from the weights, without running the model.
When I first read this, I (and at least some other readers) interpreted this as a mathematical claim–that the attention layer of a one-layer transformer can be mathematically rewritten as a set of skip-trigrams, and that you can understand the models by reading these skip-trigrams off the model weights (and also reading the bigrams off the embed and unembed matrices, as described in the zero-layer transformer section – I agree with this part).
But this mathematical claim is false: One-layer transformers are more expressive than skip-trigrams, so you can’t understand them by transforming them into a set of skip-trigrams. Also, even if a particular one-layer transformer is actually only representing skip-trigrams and bigrams, you still can’t read these off the weights without reference to the data distribution.
The difference between skip-trigrams and one-layer transformers is that when attention heads attend more to one token, they attend less to another token. This means that even single attention heads can implement nonlinear interactions between tokens earlier in the context.
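A minimal numerical sketch of this effect, assuming a single head whose attention scores and per-token OV outputs are hypothetical scalars (each token writes to one logit direction) chosen only for illustration. The names `head_output`, `run`, `QUERY`, `X`, and `Y` are made up for this example:

```python
import numpy as np

def head_output(scores, values):
    """One attention head from a fixed query: softmax over the scores, then a
    weighted sum of each source token's scalar OV contribution."""
    scores = np.asarray(scores, dtype=float)
    weights = np.exp(scores - scores.max())  # subtract the max for numerical stability
    weights /= weights.sum()
    return float(weights @ np.asarray(values, dtype=float))

# Hypothetical source tokens: (attention score, OV contribution to one logit).
QUERY = (0.0, 0.0)  # the current token itself, always present in the context
X = (0.0, 1.0)      # X pushes the logit up
Y = (3.0, 0.0)      # Y writes nothing, but grabs most of the attention

def run(*tokens):
    scores, values = zip(*tokens)
    return head_output(scores, values)

# How much does adding X change the output, with and without Y in the context?
effect_of_x_alone = run(QUERY, X) - run(QUERY)
effect_of_x_with_y = run(QUERY, X, Y) - run(QUERY, Y)
print(round(effect_of_x_alone, 3), round(effect_of_x_with_y, 3))
```

An additive, skip-trigram-style model would give X the same contribution whether or not Y is present; here X's effect shrinks from 0.5 to roughly 0.045 once Y is in the context, even though Y itself writes nothing.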
In this post, I’ll demonstrate that one-layer attention-only transformers are more expressive than a set of skip-trigrams, then I’ll tell an intuitive story for why I disagree with Elhage et al’s claim that one-layer attention-only transformers can be put in a form where “all parameters are contextualized and understandable”.
(Elhage et al say in a footnote, “Technically, [the attention pattern] is a function of all possible source tokens from the start to the destination token, as the softmax calculates the score for each via the QK circuit, exponentiates and then normalizes”, but they don’t refer to this fact further.)
An example of a task that is impossible for skip-trigrams but is expressible with one-layer attention-only transformers
Consider the task of predicting the 4th character from the first 3 characters in a case where there are only 4 strings:
ACQT
ADQF
BCQF
BDQT
So the strings are always:
- A or B
- C or D
- Q
- The xor of the first character being A and the second being D, encoded as T or F.
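As a quick sanity check that the four strings really follow this rule, here is a small Python sketch (the names `STRINGS` and `xor_label` are just illustrative):

```python
# The four strings from the task above.
STRINGS = ["ACQT", "ADQF", "BCQF", "BDQT"]

def xor_label(s):
    """T iff (first char is A) xor (second char is D), else F."""
    return "T" if (s[0] == "A") != (s[1] == "D") else "F"

for s in STRINGS:
    assert s[2] == "Q"            # the third character is always Q
    assert s[3] == xor_label(s), s  # the fourth character is the xor label
print("all four strings follow the xor rule")
```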
This can’t be solved with skip-trigrams
A skip-trigram (in the sense that Elhage et al are using it) looks at the current token and an earlier token and returns a logit contribution for every possible next token. That is, it’s a pattern of the form
………….X……………………Y -> Z
where you update towards or away from the next token being Z based on the fact that the current token is Y and the token X appeared at a particular location earlier in the context.
(Sometimes the term “skip-trigram” is used to include patterns where Y isn’t immediately before Z. Elhage et al use the narrower definition, in which Y is immediately before Z, because in the context of autoregressive transformers the trigrams you can encode have Y and Z as neighbors.)
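One way to make this definition concrete is to treat an ensemble of skip-trigrams as an additive lookup table. The sketch below is an assumption about representation, not code from Elhage et al: the hypothetical `table` maps (earlier token X, its position, current token Y) to a vector of logit contributions over possible next tokens Z, and the contributions from every earlier position are simply summed.

```python
from collections import defaultdict

import numpy as np

VOCAB = ["A", "B", "C", "D", "Q", "T", "F"]

def skip_trigram_logits(context, table):
    """Sum the logit contributions of every (X at position p, current token Y) pair.

    `context` is the prefix seen so far; its last token is the current token Y.
    `table` maps (X, p, Y) to a length-len(VOCAB) array; missing keys contribute 0.
    """
    current = context[-1]
    logits = np.zeros(len(VOCAB))
    for p, x in enumerate(context[:-1]):
        logits += table[(x, p, current)]
    return logits

# Hypothetical entry: "A at position 0, current token Q" nudges towards T.
table = defaultdict(lambda: np.zeros(len(VOCAB)))
table[("A", 0, "Q")] = np.eye(len(VOCAB))[VOCAB.index("T")]
print(dict(zip(VOCAB, skip_trigram_logits("ACQ", table))))
```

Each term depends on only one earlier token, so in this representation the contributions of different earlier tokens can never interact.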
In the example I gave here, skip-trigrams can’t help, because the probability that the next token after Q is T is 50% after conditioning on the presence of any single earlier token.
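The 50% figure can be checked by brute force over the four strings, assuming (as in the task) that they are equally likely. The helper `p_next_is_t` is hypothetical:

```python
from fractions import Fraction

STRINGS = ["ACQT", "ADQF", "BCQF", "BDQT"]

def p_next_is_t(position, token):
    """P(4th char is T | the char at `position` is `token`), uniform over STRINGS."""
    matching = [s for s in STRINGS if s[position] == token]
    return Fraction(sum(s[3] == "T" for s in matching), len(matching))

for position, token in [(0, "A"), (0, "B"), (1, "C"), (1, "D")]:
    print(position, token, p_next_is_t(position, token))
```

Every line prints 1/2: no single earlier token carries any information about the answer.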
This can be solved by a one-layer, two-headed transformer
We can solve this problem with a one-layer transformer with two heads.
The first attention head has the following behavior, when attending from the token Q (which is the only case we care about):
| Token attending to | Attention score (pre-softmax) | OV behavior |
|---|---|---|
| A | -10000 | 0 |
| B | 10000 | 0 |
| C | 0 | T |
| D | 0 | F |
| Q | -10000 | 0 |
So it attends almost entirely to B if B is present. If B isn’t present (because the first character was A), it will attend almost entirely to C or D. If it attends to C, it writes T; if it attends to D, it writes F. This head therefore writes the correct answer in cases where the first character was A, and writes nothing otherwise.
(By “OV behavior”, I mean W_U W_{OV} embed. So e.g. I’m saying that if you take the embedding for C, then multiply it by this head’s OV, and then unembed that, you’ll get a vector which is in the direction of the unembed for T.)
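To make the notation concrete, here is a toy sketch assuming one-hot embeddings and a one-hot unembedding over the seven-token vocabulary (so the residual stream dimension equals the vocab size, which real models generally don't satisfy). `W_E`, `W_OV`, and `W_U` are hypothetical matrices built by hand to reproduce the first head's OV column from the table above:

```python
import numpy as np

VOCAB = ["A", "B", "C", "D", "Q", "T", "F"]
idx = {tok: i for i, tok in enumerate(VOCAB)}
n = len(VOCAB)

W_E = np.eye(n)  # embed: column i is the embedding of token i
W_U = np.eye(n)  # unembed: residual basis direction i -> logit for token i

# Head 1's OV circuit: C writes towards T, D writes towards F, everything else writes 0.
W_OV = np.zeros((n, n))
W_OV[idx["T"], idx["C"]] = 1.0
W_OV[idx["F"], idx["D"]] = 1.0

def ov_behavior(token):
    """Logit vector W_U @ W_OV @ (embedding of `token`)."""
    return W_U @ W_OV @ W_E[:, idx[token]]

print(VOCAB[int(np.argmax(ov_behavior("C")))])  # T
print(VOCAB[int(np.argmax(ov_behavior("D")))])  # F
print(np.allclose(ov_behavior("A"), 0))         # True
```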
The second attention head handles the case where the first character was B:
| Token attending to | Attention score (pre-softmax) | OV behavior |
|---|---|---|
| A | 10000 | 0 |
| B | -10000 | 0 |
| C | 0 | F |
| D | 0 | T |
| Q | -10000 | 0 |
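Putting the two tables together, the following Python sketch simulates the model's computation at the Q position. It is a simplified toy under stated assumptions: each head's QK and OV behavior is written directly as the per-token scores and outputs from the tables (rather than as weight matrices), each head's output is treated as a logit contribution of size (attention weight) towards T or F, and the direct bigram path is ignored because the current token is always Q. `HEAD_1`, `HEAD_2`, `head_logits`, and `predict` are hypothetical names.

```python
import numpy as np

# (pre-softmax score, OV output) for each source token, copied from the two tables.
HEAD_1 = {"A": (-10000, None), "B": (10000, None), "C": (0, "T"), "D": (0, "F"), "Q": (-10000, None)}
HEAD_2 = {"A": (10000, None), "B": (-10000, None), "C": (0, "F"), "D": (0, "T"), "Q": (-10000, None)}

def head_logits(head, context):
    """Attend from the final token over the whole context; return {"T": x, "F": y}."""
    scores = np.array([head[tok][0] for tok in context], dtype=float)
    weights = np.exp(scores - scores.max())  # stable softmax
    weights /= weights.sum()
    logits = {"T": 0.0, "F": 0.0}
    for w, tok in zip(weights, context):
        out = head[tok][1]
        if out is not None:
            logits[out] += w
    return logits

def predict(prefix):
    """Sum both heads' contributions and return the higher-scoring answer."""
    total = {"T": 0.0, "F": 0.0}
    for head in (HEAD_1, HEAD_2):
        for k, v in head_logits(head, prefix).items():
            total[k] += v
    return max(total, key=total.get)

for s in ["ACQT", "ADQF", "BCQF", "BDQT"]:
    print(s[:3], "->", predict(s[:3]), "(expected", s[3] + ")")
```

All four prefixes get the correct answer: whichever head is “switched off” by the first character attends to A or B and contributes nothing, while the other head copies the right answer from C or D.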
It’s impossible for an ensemble of skip-trigrams to learn this task, if by “ensemble of skip-trigrams” you mean “a logistic regression where all the features are of the form (token X was at position P and token Y is at the current position)”, which is the most reasonable interpretation of how a transformer could be considered a set of skip-trigrams.
Proof sketch: Logistic regressions can only perfectly solve classification problems if there’s a hyperplane separating the positive and negative examples. In this case, the only useful skip-trigrams are the ones that tell us whether A or B was at the first position and whether C or D was at the second position. A is present exactly when B isn’t, so a single feature suffices to represent the token at the first position; likewise for C and D. So we’re now considering a logistic regression with two features: “is A at the first position” and “is C at the second position”. It’s impossible to separate the two classes, for the usual reason that logistic regression can’t learn xor. (Bigrams don’t help here because the current token is Q for every input.) (Thanks to Paul Christiano for help with this proof.) This establishes that one-layer transformers cannot, in general, be rewritten as sets of bigrams and skip-trigrams.
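The xor step can be written out explicitly. Writing $a$ for the feature “A is at the first position”, $c$ for “C is at the second position”, and $z = w_a a + w_c c + b$ for the logit of T (these symbols are introduced here just for this derivation), perfect classification of the four strings would require:

```latex
\begin{aligned}
\text{ACQT } (a=1,\ c=1):&\quad w_a + w_c + b > 0\\
\text{BDQT } (a=0,\ c=0):&\quad b > 0\\
\text{ADQF } (a=1,\ c=0):&\quad w_a + b < 0\\
\text{BCQF } (a=0,\ c=1):&\quad w_c + b < 0\\[4pt]
\text{T rows summed:}&\quad w_a + w_c + 2b > 0\\
\text{F rows summed:}&\quad w_a + w_c + 2b < 0
\end{aligned}
```

The last two lines contradict each other, so no choice of weights separates the classes.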
In this case, the problem could have been solved if the model could express skip-quadgrams. However, we can construct similar problems that one-layer attention-only transformers can solve but skip-quadgrams can’t. (More generally, because softmax induces a nonlinear, infinite-order interaction between all the attention scores, for any fixed n, one-layer attention-only transformers with an unlimited number of heads and an unlimited vocab size can express functions that ensembles of skip-n-grams can’t.)
(One-layer attention-only transformers can express functions that can’t be expressed by skip-n-grams even if they only have a single attention head. I used multiple attention heads in this example because it allowed me to express xor, which is a particularly clean example of a function that skip-trigrams can’t model at all.)
I think that this transformer is probably pretty easy for SGD to learn–it’s not just a pathological counterexample.
IMO, for reasonable definitions of “understanding”, this falsifies Elhage et al’s claim that you can understand the one-layer transformer from its weights
Elhage et al write:
One layer attention-only transformers are an ensemble of bigram and “skip-trigram” (sequences of the form "A… B C") models. The bigram and skip-trigram tables can be accessed directly from the weights, without running the model.
[...]
By multiplying out the OV and QK circuits, we've succeeded in doing this: the neural network parameters are now simple linear or bilinear functions on tokens. The QK circuit determines which "source" token the present "destination" token attends back to and copies information from, while the OV circuit describes what the resulting effect on the "out" predictions for the next token is. Together, the three tokens involved form a "skip-trigram" of the form [source]... [destination][out], and the "out" is modified.[...]
[...] we do have transformers in a form where all parameters are contextualized and understandable. And despite these subtleties, we can simply read off skip-trigrams from the joint OV and QK matrices.
[...] It seems to us that we now understand this simplified model in the same sense that one might look at the weights of a giant linear regression and understand it, or look at a large database and understand what it means to query it. That is a kind of understanding. There's no longer any algorithmic mystery. The contextualization problem of neural network parameters has been stripped away.
I disagree that they have put their one-layer transformers into a form where all parameters are contextualized and understandable.
To me, what it means to say that a parameter is contextualized and understandable is that you can understand “what role the parameter plays” without learning more about the other parameters or the data distribution. This isn’t true in the example transformer I described above–in both of the heads, you couldn’t really understand why a single attention score had the value it had without looking at the others.
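One concrete reason a single attention score isn’t interpretable in isolation: softmax only depends on differences between scores, so adding the same constant to every score leaves the attention pattern unchanged, and whether a score of 0 means “attend here” or “ignore this” depends on which other tokens are present. A small sketch using head 1’s scores from the first table (`attention_pattern` is a hypothetical helper):

```python
import numpy as np

HEAD_1_SCORES = {"A": -10000.0, "B": 10000.0, "C": 0.0, "D": 0.0, "Q": -10000.0}

def attention_pattern(context, scores):
    """Softmax over the scores of every token in the context."""
    s = np.array([scores[tok] for tok in context])
    w = np.exp(s - s.max())
    return w / w.sum()

# C's score of 0 gets nearly all the attention when A came first...
print(attention_pattern("ACQ", HEAD_1_SCORES))
# ...and almost none when B came first, even though C's score is unchanged.
print(attention_pattern("BCQ", HEAD_1_SCORES))

# Adding a constant to every score does not change the pattern at all.
shifted = {tok: v + 123.0 for tok, v in HEAD_1_SCORES.items()}
print(np.allclose(attention_pattern("ACQ", shifted), attention_pattern("ACQ", HEAD_1_SCORES)))
```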
Here’s an intuitive example of how this might come up in a real language model. (I’m not saying that this is actually the best way for the language model to solve the problem I describe, but I am saying that I’m not comfortable assuming this mechanism away without empirical evidence.) Suppose your model has a head that primarily has the responsibility of looking at nouns that appear in news articles and then suggesting related nouns. This head might not do anything on sequences that aren’t news articles (because some of the skip-trigrams it implements are only valid in the context of news articles and not in e.g. Python source files, even though the skip-trigram might appear in the Python source file). It might implement this by attending strongly to Python keywords and then writing nothing. If we just tried to understand the weights of this attention head based on the skip-trigram weights, we’d totally miss the fact that this head turns off its skip-trigrams when in a Python file.
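A toy version of this hypothetical “off switch”, with made-up tokens and scores (nothing here comes from a real model): the head’s weights encode a skip-trigram “election … the → vote”, but a Python keyword in the context soaks up the attention and writes nothing, so the contribution you would read off the weights mostly disappears.

```python
import numpy as np

# Hypothetical head: token -> (pre-softmax score from the current token, logit added to "vote").
HEAD = {
    "the": (0.0, 0.0),       # the current token itself
    "election": (2.0, 1.0),  # news noun: suggests the related noun "vote"
    "def": (10.0, 0.0),      # Python keyword: attention sink that writes nothing
    "import": (10.0, 0.0),
}

def vote_logit(context):
    """Contribution of this head to the "vote" logit at the last position."""
    scores = np.array([HEAD[tok][0] for tok in context])
    values = np.array([HEAD[tok][1] for tok in context])
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return float(weights @ values)

news = ["election", "the"]
python_file = ["import", "election", "def", "the"]
print(round(vote_logit(news), 3))         # news context: the skip-trigram fires
print(round(vote_logit(python_file), 5))  # Python context: it is switched off
```

Nothing in the “election → vote” entry of the weights tells you about this; you only see it by also knowing that the head has large scores on Python keywords and that those keywords appear in the context.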
For language models in particular, is the claim that they’re a combination of bigrams and skip-trigrams empirically true?
We’ve established that you can’t always rewrite a one-layer attention-only model as an ensemble of skip-trigrams, but perhaps it’s nevertheless a fairly good approximation for the language models that we train in practice. Some REMIX participants tried to investigate this empirically. I currently don’t have a better way of summarizing their results other than “the model is somewhat but not excellently described as an ensemble of bigrams and skip-trigrams”; perhaps we’ll write something clearer about this at some point.
Is the bigrams-and-skip-trigrams approximation useful for interpretability in practice?
I have no idea. I haven’t seen any persuasive evidence either way.
Does this have any interesting or important implications?
I think the main important point here is this: I’m generally quite skeptical of approaches to interpretability which hope to eventually understand models without reference to their input distribution; it looks to me like the internals of models are intricately related to facts about the data distribution, and we should think about how to use interpretability for alignment by taking this data-distribution-dependence as a given, rather than trying to fight it.
I moderately disagree with this? I think most induction heads are at least primarily induction heads (and this points strongly at the underlying attentional features and circuits), although there may be some superposition going on. (I also think that the evidence you're providing is mostly orthogonal to this argument.)
I think if you're uncomfortable with induction heads, previous token heads (especially in larger models) are an even crisper example of an attentional feature which appears, at least on casual inspection, to typically be monosemantically represented by attention heads. :)
As a meta point – I've left some thoughts below, but in general, I'd rather advance this dialogue by just writing future papers.
(1) The main evidence I have for thinking that induction heads (or previous token heads) are primarily implementing those attentional features is just informally looking at their behavior on lots of random dataset examples. This isn't something I've done super rigorously, but I have a pretty strong sense that this is at least "the main thing".
(2) I think there's an important distinction between "imprecisely articulating a monosemantic feature" and "a neuron/attention head is polysemantic/doing multiple things". For example, suppose I found a neuron and claimed it was a golden retriever detector. Later, it turns out that it's a U-shaped floppy ear detector which fires for several species of dogs. In that situation, I would have misunderstood something – but the misunderstanding isn't about the neuron doing multiple things, it's about having had an incorrect theory of what the thing is.
It seems to me that your post is mostly refining the hypothesis of what the induction heads you are studying are – not showing that they do lots of unrelated things.
(3) I think our paper wasn't very clear about this, but I don't think your refinements of the induction heads were unexpected. (A) Although we thought that the specific induction head in the 2L model we studied only used a single QK composition term to implement a very simple induction pattern, we always thought that induction heads could do things like match [a][b][c]. Please see the below image with a diagram from when we introduced induction heads that shows richer pattern matching, and then text which describes the K-composition for [a][b] as the "minimal way to create an induction head", and gives the QK-composition term to create an [a][b][c] matching case. (B) We also introduced induction heads as a sub-type of copying head, so them doing some general copying is also not very surprising – they're a copying head which is guided by an induction heuristic. (Just as one observes "neuron splitting" creating more and more specific features as one scales a model, I expect we get "attentional feature splitting" creating more and more precise attentional features.)
(3.A) I think it's exciting that you've been clarifying induction heads! I only wanted to bring these clarifications up here because I keep hearing it cited as evidence against the framework paper and against the idea of monosemantic structures we can understand.
(3.B) I should clarify that I do think we misunderstood the induction heads we were studying in the 2L models in the framework paper. This was due to a bug in the computation of low-rank Frobenius norms in a library I wrote. This is on a list of corrections I'm planning to make to our past papers. However, I don't think this reflects our general understanding of induction heads. The model was chosen to be (as we understood it at the time) the simplest case study of attention head composition we could find, not a representative example of induction heads.
(4) I think attention heads can exhibit superposition. The story is probably a bit different than that of normal neurons, but – drawing on intuition from toy models – I'm generally inclined to think: (a) sufficiently important attentional features will be monosemantic, given enough model capacity; (b) given a privileged basis, there's a borderline regime where important features mostly get a dedicated neuron/attention head; (c) this gradually degrades into being highly polysemantic and us not being able to understand things. (See this progression as an example that gives me intuition here.)
It's hard to distinguish "monosemantic" and "slightly polysemantic with a strong primary feature". I think it's perfectly possible that induction heads are in the slightly polysemantic regime.
(5) Without prejudice to the question of "how monosemantic are induction heads?", I do think that "mostly monosemantic" is enough to get many benefits.
(5.A) Background: I presently think of most circuit research as "case studies where we can study circuits without having resolved superposition, to help us build footholds and skills for when we have". Mostly monosemantic is a good proxy in this case.
(5.B) Mostly monosemantic features / attentional features allow us to study what features exist in a model. A good example of this is the SoLU paper – we believe many of the neurons have other features hiding in correlated small activations, but it also seems like it's revealing the most important features to us.
(5.C) Being mostly monosemantic also means that, for circuit analysis, interference with other circuits will be mild. As such, the naive circuit analysis tells you a lot about the general story (weights for other features will be proportionally smaller). For contrast, compare this to a situation where one believes they've found a neuron (say a "divisible by seven" number detector, continuing my analogy above!) and it turns out that actually, that neuron mostly does other things on a broader distribution (and they even cause stronger activations!). Now, I need to be much more worried about my understanding…