I expect to have more detailed thoughts worth sharing as I spend more time with this content, but one thing stands out brightly right away: this is, head and shoulders, the best language model interpretability work to date. I'm impressed by the thoroughness of the theory combined with the detailed real examples.
This also seems like a good motivation to go back and study layer reordering (à la Sandwich Transformers) as a treatment affecting the induced circuits of a model.
(h/t Kevin Wang for pointing out the sandwich transformer paper to me recently)
Thanks for the writeup! The first paper covers the first half of the video series, more or less. I've been working on a second paper which will focus primarily on the induction bump phenomenon (and other things described in the second half of the video series), so much more to come there!
This looks really interesting. Is there any intention to use these insights to design even more interpretable models than transformers? I've had the feeling that transformer models may be too general-purpose for their own good, in terms of training efficiency and interpretability. By that I mean that, just as fully connected neural networks technically have at least as much computational/representational power as convolutional neural networks yet are much harder to train for general image processing than their more constrained, translation-equivariant counterparts, transformer-type language models might not have enough constraints to be efficient enough for AGI.
In these models, some representation of every token is compared against a representation of every other token encountered so far, which gives quadratic complexity for every attention layer at runtime. This then leads to further transformation of the data after each attention block, creating what is effectively a new string of abstract tokens, each of which is some hard-to-interpret combination of the token representations in the level below. The only information added to the vector representation of each token, as far as I understand it, is some vector representing the relative position of the tokens within the string (which itself necessitates a special type of normalization step later on). Otherwise, it's up to the model to learn to assign implicit roles/functions to each token through the attention module. This hides away the information of what each token is doing, which a more constrained model could instead represent explicitly.
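For concreteness, here is a minimal single-head causal self-attention sketch in plain NumPy. The parameter names (W_q, W_k, W_v) and the simple additive positional encoding P are my own illustrative choices rather than anything from the paper; the point is just to show the quadratic token-by-token score matrix and how each output position ends up as an opaque mixture of the token representations below it:

```python
import numpy as np

def self_attention_sketch(X, W_q, W_k, W_v, P):
    """Single-head causal self-attention, stripped to the essentials (illustrative only).

    X : (n_tokens, d_model) token embeddings
    P : (n_tokens, d_model) positional encodings, simply added to the token vectors
    """
    H = X + P                                    # position info is mixed directly into each token vector
    Q, K, V = H @ W_q, H @ W_k, H @ W_v
    scores = Q @ K.T / np.sqrt(K.shape[-1])      # (n_tokens, n_tokens): every token scored against every other
    mask = np.triu(np.ones(scores.shape, dtype=bool), k=1)
    scores = np.where(mask, -1e9, scores)        # causal mask: attend only to current and earlier tokens
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V                           # each output row is a mixture of (transformed) earlier tokens
```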
It seems to me that we could do better. For instance, suppose we had a model that had "slots" (I'm thinking something like CPU registers) that it would fill in with token vectors as it went along. The LM would learn to assign functions like "subject", "verb", "direct object modifier", etc., with one part learning which tokens should get routed to which slots, another part learning to predict which slot (e.g., which part of speech) will be filled next based on what information has already been filled in and on the learned rules of syntax, and another part predicting what information should go into the empty slots (allowing it to "read between the lines"). That last part could also be hooked up to a long-term memory database of learned relations that it could fill in and update as it accumulates training data (something like what DeepMind published recently: https://deepmind.com/research/publications/2021/improving-language-models-by-retrieving-from-trillions-of-tokens).
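To make the routing idea slightly more concrete, here is a toy sketch of just the token-to-slot routing step. Everything in it (the name route_tokens_to_slots, W_route, N_SLOTS) is invented for illustration; it is not an existing architecture, just a picture of tokens being written into fixed, inspectable slots:

```python
import numpy as np

N_SLOTS, D = 8, 64   # hypothetical: a small, fixed bank of named "registers"

def route_tokens_to_slots(tokens, W_route):
    """Toy routing step for the slot idea sketched above (invented, not an existing model).

    tokens  : (n_tokens, D) token vectors
    W_route : (D, N_SLOTS) learned routing weights
    Returns a (N_SLOTS, D) slot bank: each token is written into the slot it scores
    highest on, so unused slots stay at zero and can be inspected directly.
    """
    slots = np.zeros((N_SLOTS, D))
    scores = tokens @ W_route                    # how strongly each token "wants" each slot
    for token, slot_idx in zip(tokens, scores.argmax(axis=-1)):
        slots[slot_idx] = token                  # collisions overwrite; a real design would need a merge rule
    return slots
```

The hard argmax here is of course non-differentiable and collisions are handled crudely; a trainable version would presumably need soft routing (attention over slots) and some learned gating or merging rule.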
Although the role of each slot will be arbitrary and assigned only through training, I think this type of architecture would make it easier to extract semantic roles for the tokens that it reads in, since these semantic roles have explicit locations where they can always be found. In other words, you can use the same method to find out what the LM thinks about the who, what, when, where, why, and how of what it reads or says (along with what it thinks about everything it doesn't read or say by looking into the "unused" slots). With transformers, this would be much more difficult, since semantic roles are assigned much more implicitly and a lot could be hiding in the weights.
That was just an idea, but I think that interpretability will come more easily the more we constrain the language model with both functional and representational modularity. Perhaps the work you do could help inform what sorts of constraints would be most effective to that end.
Chris Olah, Neel Nanda, Catherine Olsson, Nelson Elhage, and a bunch of other people at Anthropic just published “Transformer Circuits,” an application of the Circuits-style interpretability paradigm to transformer-based language models. From their very top-level summary:
They've chosen to release their work in an interestingly novel format, publishing their first paper, “A Mathematical Framework for Transformer Circuits,” alongside a set of YouTube videos that go into even more detail on their findings. I watched the full YouTube playlist and found it absolutely fascinating; I would highly recommend it as a way to engage with this research.
Some of my high-level takeaways:
Overall, I think this is clearly the most exciting progress in transparency and interpretability in general since Circuits—and I'm really happy to see it happening in language models, which, as I've previously emphasized, I think is important for us to focus on. One thing that I think really sets this sort of transparency and interpretability work apart from the rest is the authors' emphasis on understanding the mechanistic building blocks underlying their models—with the hope of eventually being able to reverse-engineer them—rather than just, for example, trying to give humans tools to predict what models are doing (without the output of those tools necessarily having any correspondence to what the models are actually doing).
Some possible future research directions related to understanding how induction heads compose:
One thing I'm still struggling to understand related to the induction bump is how it can be the case that, after the induction bump, large models don't do relatively better at meta-learning compared to small models. I found that extremely surprising, and I almost don't believe that can be the end of the story—qualitatively, we certainly observe much more interesting meta-learning in larger models, and it seems really strange for all of that to just be reflected in an overall loss decrease rather than an increase in the amount of meta-learning. At the very least, I feel like this fact deserves some sort of further explanation—perhaps there are other interesting phase changes hiding in later parts of the loss curve that might help explain what's going on here, or perhaps the claim that the whole phenomenon of “meta-learning” is just task recognition could help shed some light on this result.