It’s great to see that these techniques basically work at scale, but not so much to hear that things remain messy. Do you have any intuition for whether things would start to clean up if the model was trained until the loss curve flattened out? Maybe Chinchilla-optimality even has some interesting bearing on this!
My guess is that messiness is actually a pretty inherent part of the whole thing? Models have no inherent reason to solve the problem with a single clean solution: if they can simultaneously use the features "nth item in the list", "labelled A" and even "has two incorrect answers before it", why not?
During parts of the project I had the hunch, based on their attention patterns, that some letter-specialized heads are more like proto-correct-letter-heads (see paper for details). We never investigated this, and I think it could go either way. The "it becomes cleaner" intuition basically relies on things like the grokking work and other work showing representations being refined late in training (Thisby et al., I believe, and maybe other work). However, some of this would probably require randomising, e.g., the labels the model sees during training. See e.g. Cammarata et al., Understanding RL Vision: if you only ever see the second choice labeled with B, you have no incentive to distinguish between "look for B" and "look for the second choice". Lastly, even in the limit of infinite training data you still have limited model capacity, so the model will likely use a distributed representation in some way, but maybe you could at least get human-interpretable features even if they are distributed.
Cross-posting a paper from the Google DeepMind mech interp team, by: Tom Lieberum, Matthew Rahtz, János Kramár, Neel Nanda, Geoffrey Irving, Rohin Shah, Vladimir Mikulik
Informal TLDR
See Tom's and my Twitter summaries for more. Note that I (Neel) am cross-posting this on behalf of the team, and neither a main research contributor nor main advisor for the project.
Key Figures
An overview of the weird kinds of heads found, like the "attend to B if it is correct" head!
The losses under different mutations of the letters: experiments to track down exactly which features were used. E.g., replacing the labels with random letters or numbers preserves the "nth item in the list" feature, while shuffling ABCD lets us track the "line labelled B" feature.
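A minimal sketch of the kinds of label mutations described above. The prompt format, the function names (`format_prompt`, `random_letter_labels`, `number_labels`, `shuffled_abcd_labels`, `mutate`), the example question, and the choice to draw replacement letters from outside A–D are all hypothetical illustrations, not the paper's actual setup. The point is only that each mutation keeps the correct option at the same position while changing which label the model must output.

```python
import random
import string

ABCD = ["A", "B", "C", "D"]


def format_prompt(question, options, labels):
    """Hypothetical multiple-choice prompt format: one labelled option per line."""
    lines = [question]
    lines += [f"{label}. {option}" for label, option in zip(labels, options)]
    lines.append("Answer:")
    return "\n".join(lines)


def random_letter_labels(rng, n=4):
    # Fresh letters (assumed drawn from outside A-D): "the line labelled B"
    # no longer exists, but "the nth item in the list" is unchanged.
    pool = [c for c in string.ascii_uppercase if c not in ABCD]
    return rng.sample(pool, n)


def number_labels(n=4):
    # Numeric labels "1".."n": again only positional information survives.
    return [str(i + 1) for i in range(n)]


def shuffled_abcd_labels(rng):
    # Same label set in a new order, decorrelating position from label
    # identity (e.g. the second option may now be labelled "D").
    labels = ABCD[:]
    rng.shuffle(labels)
    return labels


def mutate(question, options, correct_idx, labels):
    """Return the mutated prompt and the label the model should now output."""
    return format_prompt(question, options, labels), labels[correct_idx]


rng = random.Random(0)
question = "Which planet is closest to the Sun?"
options = ["Venus", "Mercury", "Mars", "Earth"]
for labels in (ABCD, random_letter_labels(rng), number_labels(), shuffled_abcd_labels(rng)):
    prompt, target = mutate(question, options, 1, labels)
    print(prompt, "->", target, "\n")
```

Comparing a model's loss on the original prompts with its loss on each mutated set is what separates the features: a mutation that removes a feature the model relies on should hurt, while one that preserves it should not.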
The queries and keys of a crucial correct letter head: it's so linearly separable! We can near-losslessly compress it to just 3 dimensions and interpret just those three dimensions. See an interactive 3D plot here.
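A rough sketch of what "compressing to 3 dimensions" can look like, assuming (this is an assumption, not necessarily the paper's method) that the compression is an uncentred PCA of the keys, with both queries and keys then projected onto the top 3 directions. The data is synthetic: a hypothetical head dimension of 8 whose keys lie close to a 3-dimensional subspace. The check is that attention scores computed from the 3 compressed coordinates match the full scores almost exactly.

```python
import math
import random


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def top_k_directions(vectors, k, iters=200, seed=0):
    """Top-k eigenvectors of sum_i v_i v_i^T (uncentred PCA), found by power
    iteration with deflation. Assumes the matrix has rank >= k."""
    d = len(vectors[0])
    m = [[sum(v[i] * v[j] for v in vectors) for j in range(d)] for i in range(d)]
    rng = random.Random(seed)
    basis = []
    for _ in range(k):
        x = [rng.gauss(0.0, 1.0) for _ in range(d)]
        for _ in range(iters):
            x = [dot(row, x) for row in m]
            for b in basis:  # deflation: stay orthogonal to earlier directions
                c = dot(x, b)
                x = [xi - c * bi for xi, bi in zip(x, b)]
            norm = math.sqrt(dot(x, x))
            x = [xi / norm for xi in x]
        basis.append(x)
    return basis


def compress(vec, basis):
    """Coordinates of vec in the k-dimensional subspace spanned by basis."""
    return [dot(vec, b) for b in basis]


# Hypothetical toy data: d_head = 8, keys = 3 latent directions + small noise.
rng = random.Random(42)
d, k, n = 8, 3, 200
directions = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(k)]
keys = []
for _ in range(n):
    coeffs = [rng.gauss(0, 1) for _ in range(k)]
    keys.append([sum(c * dirn[i] for c, dirn in zip(coeffs, directions)) + rng.gauss(0, 0.01)
                 for i in range(d)])
queries = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(20)]

basis = top_k_directions(keys, k)
max_err = 0.0
max_score = 0.0
for q in queries:
    qc = compress(q, basis)
    for key in keys[:20]:
        full = dot(q, key)                       # score in the full 8-dim space
        approx = dot(qc, compress(key, basis))   # score from 3 coordinates only
        max_err = max(max_err, abs(full - approx))
        max_score = max(max_score, abs(full))
print(f"max |score| = {max_score:.2f}, max error after 3-d compression = {max_err:.4f}")
```

Because the basis is orthonormal, the compressed score equals the full score minus the part of the keys lying outside the 3-dimensional subspace. When the keys really are (near) 3-dimensional, that remainder is tiny, which is what makes a lossless-in-practice 3D plot possible.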
Abstract
Read the full paper here: https://arxiv.org/abs/2307.09458