Link to arXiv preprint: Do language models plan for future tokens?, by Wilson Wu, John X. Morris and Lionel Levine.
TLDR
As of 27 June 2024, this is ongoing work. In particular, the integer multiplication and Pythia experiments are not yet described in the current arXiv article. The authors shared a draft containing these latest results.
Also, this is my first distillation post. Any feedback - both what you like and what can be improved - will be much appreciated.
Pre-requisites
For this distillation, I assume basic familiarity with transformer architecture and gradient descent. You do not need any AI safety or mech interp experience.
Pre-caching and breadcrumbs
With the help of the diagram below, I introduce notation. x1, xi, xj (where i < j) represent input tokens, y1, yi, yj represent output logits, each box corresponds to one position/token of the input sequence, and the x’s inside the boxes represent the hidden states.
We have a causal mask, so the hidden states for xi are useful for the hidden states for xj but not vice versa. The question this paper asks is whether this usefulness is intentional or incidental. They introduce terminology for these two possibilities:
- Pre-caching: the transformer deliberately computes features at position i that are not needed for predicting yi but will be useful for predicting later tokens.
- Breadcrumbs: the features that are most useful for predicting yi happen to also be useful at later positions; any benefit to the future is incidental.
Myopic descent
To determine how much of this usefulness comes from pre-caching versus breadcrumbs, they introduce a training scheme in which pre-caching is impossible: they zero out the parts of the gradients that incentivize pre-caching. (By gradients here I mean the gradient of the loss w.r.t. the parameters theta.) They call this ‘myopic descent’, because it is short-sighted.
The main idea is to break up the gradient into a sum of sub-gradients grad[i,j], where grad[i,j] tells you how much the loss due to yj changes if you make a small change to theta, BUT the change to theta is applied only to xi’s hidden states, not to those of any other token. More explicitly, imagine doing a forward pass in which we replace θ with θ+δθ only for the hidden states of the i-th position: we leave θ unchanged for all other sequence positions. Because of the causal architecture, only the hidden states and outputs from the i-th position onwards will be affected.
- grad[i,j] is the change in the loss from yj caused by this small change to theta at position i.
- grad[j,i] is always zero, because the hidden states for xj have no impact on yi.
- The grad[i,i] terms teach the transformer to better predict yi from xi, i.e. directly predicting the next token.
- The grad[i,j] terms (with i < j) teach the transformer to better predict yj from xi, i.e. helping predict future tokens. It is these gradients that result in pre-caching, so it is these gradients that are zeroed!
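To make this concrete, here is a minimal PyTorch sketch of one way to block the cross-position gradients inside a causal attention layer (in a decoder-only transformer, attention is the only place gradients cross positions, so applying this at every layer removes all grad[i,j] with i < j). The module name, single-head setup and detach trick are my own illustration of the idea, not necessarily how the authors implement myopic descent.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyopicCausalSelfAttention(nn.Module):
    """Single-head causal self-attention with an optional 'myopic' mode.

    In myopic mode, keys and values coming from *other* positions are
    detached, so the loss at position j cannot send gradient back into
    the computation done at earlier positions i < j. This zeroes the
    cross-position grad[i,j] terms while leaving grad[i,i] intact.
    The forward values are unchanged; only the backward pass differs.
    (Illustrative sketch only.)
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.d = d_model
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor, myopic: bool = False) -> torch.Tensor:
        _, T, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)

        causal = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
        diag = torch.eye(T, dtype=torch.bool, device=x.device)

        if myopic:
            # Same-position (diagonal) scores keep gradient; cross-position
            # scores only see detached keys.
            scores = torch.where(
                diag,
                q @ k.transpose(-2, -1),
                q @ k.detach().transpose(-2, -1),
            ) / math.sqrt(self.d)
            attn = F.softmax(scores.masked_fill(~causal, float("-inf")), dim=-1)
            # The value contribution from the current position keeps gradient;
            # contributions from earlier positions are detached.
            out = (attn * diag) @ v + (attn * (causal & ~diag)) @ v.detach()
        else:
            scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d)
            attn = F.softmax(scores.masked_fill(~causal, float("-inf")), dim=-1)
            out = attn @ v

        return self.proj(out)
```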
Synthetic dataset experiment
They create a synthetic numerical dataset in which pre-caching is obviously useful. For simplicity, I present a special case of the data generating process:
The idea is that calculating sin(xi) is not useful for calculating yi, but it is useful for the next 10 y’s. Hence, if we find the transformer calculating sin(xi) at the i-th position, that means the transformer is pre-caching.
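To illustrate the kind of process they have in mind, here is a hypothetical data generator with exactly this property: the target at position i depends on sin() of the previous 10 inputs but not on sin(xi) itself. This is my own toy instantiation, not the paper’s exact data-generating process.

```python
import numpy as np

def make_sequence(length: int = 100, window: int = 10, seed: int = 0):
    """Generate one (x, y) sequence where y_i depends on sin(x_j) for the
    previous `window` positions j < i, but NOT on sin(x_i) itself.

    Hypothetical special case for illustration only; the paper's actual
    data-generating process may differ.
    """
    rng = np.random.default_rng(seed)
    x = rng.uniform(-np.pi, np.pi, size=length)
    y = np.empty(length)
    for i in range(length):
        past = x[max(0, i - window):i]
        # Target mixes the current input with sin() of the previous inputs.
        y[i] = x[i] + np.sin(past).sum()
    return x, y
```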
They train two transformers (with GPT2 architecture) on 30,000,000 sequences created from this process, one with vanilla training and the other with myopic descent. By doing simple investigations into the neurons (calculating correlations and using linear probes), they find strong evidence that the vanilla transformer was doing pre-caching and the myopic transformer was not. See Figures 2, 3 and 4 in the paper.
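As a rough idea of what such a linear probe looks like, the sketch below fits a ridge regression from the position-i hidden states to sin(xi) and reports held-out R²; a high score for the vanilla model but not the myopic one would be evidence of pre-caching. The exact probing setup in the paper may differ.

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score

def probe_for_sin(hidden: np.ndarray, x: np.ndarray, train_frac: float = 0.8) -> float:
    """Linear probe: can sin(x_i) be read off linearly from the hidden state
    at position i?

    hidden: (num_positions, d_model) hidden states collected from the model.
    x:      (num_positions,) corresponding input values.
    Returns the held-out R^2 of the probe.  (Illustrative sketch only.)
    """
    target = np.sin(x)
    n_train = int(train_frac * len(x))
    probe = Ridge(alpha=1.0).fit(hidden[:n_train], target[:n_train])
    return r2_score(target[n_train:], probe.predict(hidden[n_train:]))
```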
Integer multiplication experiment
They train two transformers (with GPT2 architecture) to do integer multiplication, one vanilla and one myopic. They use several tricks from Shen et al. (2023) to improve performance:
Hence, an example looks like:
3 7 0 0 * 5 4 0 0 = 5 8 2 3 0 0 0 0
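Reading the example: it encodes 73 × 45 = 3285 with each number’s digits written least-significant first, zero-padded to a fixed width, and tokenized one digit per token. A small helper that reproduces the example under these (assumed) formatting conventions; the padding widths here are chosen to match the example, not the paper’s full 8-digit setting:

```python
def format_example(a: int, b: int, operand_digits: int = 4, product_digits: int = 8) -> str:
    """Format a multiplication example the way the sample above suggests:
    digits reversed (least-significant first), zero-padded to a fixed
    length, one token per digit.  Formatting details are assumptions
    read off the example, not a quote of the paper's exact setup.
    """
    def rev_pad(n: int, width: int) -> str:
        digits = str(n)[::-1].ljust(width, "0")
        return " ".join(digits)

    return f"{rev_pad(a, operand_digits)} * {rev_pad(b, operand_digits)} = {rev_pad(a * b, product_digits)}"

# format_example(73, 45) -> "3 7 0 0 * 5 4 0 0 = 5 8 2 3 0 0 0 0"
```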
Both the vanilla and the myopic model are trained for one epoch on 10,000,000 examples, with at most 8 digits for each multiplicand. We see from the accuracy scores below that vanilla training performs better.
The authors hypothesize that the vanilla transformer can make use of filler tokens, as in Pfau et al. (2024), where it was found that adding ellipsis tokens ‘...’ improves performance. To test this hypothesis, they train vanilla and myopic transformers on each of two different datasets:
Looking at the accuracy scores below, we see that the vanilla transformer benefits from the padding whereas the myopic transformer suffers.
Quoting the authors:
GPT2 language experiment
I quote the paper (with redactions):
The cross entropy on a validation set for these three models is:
We see that the vanilla model does have a better score than the myopic model, but the gap is small compared to how much both improve on the naive bigram baseline. This suggests that pre-caching does provide some benefit but breadcrumbs are doing most of the work.
We get a more refined view of what is happening when we compute the loss on a per-position basis.
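Concretely, “per-position loss” just means keeping a separate cross-entropy value for every sequence position instead of averaging over the whole sequence. A minimal sketch (the function name is mine):

```python
import torch
import torch.nn.functional as F

def per_position_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Cross-entropy kept separate for every sequence position.

    logits:  (batch, seq_len, vocab_size)
    targets: (batch, seq_len)
    returns: (seq_len,) average loss at each position across the batch.
    """
    losses = F.cross_entropy(
        logits.transpose(1, 2),  # cross_entropy expects (batch, vocab, seq_len)
        targets,
        reduction="none",        # keep a loss per token instead of averaging
    )
    return losses.mean(dim=0)
```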
Again, just quoting the authors:
Pythia language experiments
What happens when we scale the experiments? Details of the training:
The results show that the gap in performance increases as you increase the model size. First, this is seen in the cross entropy loss:
Similar patterns are seen in the performance of the models on various benchmarks. Here are two examples:
A question from a reviewer
One of the reviewers of this post, Julian, asked whether the myopic model is disadvantaged by being trained with the same number of epochs / samples as the vanilla model. An author answered with:
I asked a follow-up about what would happen if the myopic model was trained for longer, to see whether the gaps close. Their response:
Why read the pre-print
Acknowledgements
Thanks to Nicky Pochinkov, Julian Schulz and one of the authors, Wilson Wu, for reviewing drafts of this post. Diagrams were created on bitpaper.io.