Summary:
- I argue we can probably build ML systems that are both high-performance and much more interpretable than current systems.
- I discuss some reasons why current ML interpretability lags behind what should be possible.
- I highlight some interesting interpretability work on current systems.
- I propose a metric for estimating the interpretability of current systems and discuss how we might improve interpretability.
Definitional note: for this essay, "interpretable" serves as a catch-all for things like "we can abstractly understand the algorithms a system implements" and "we can modify internal states to usefully influence behavior in a predictable manner". Systems are more interpretable when it's easier to do things like that to them.
1: Strong, interpretable systems are possible
Imagine an alien gave you a collection of AI agents. The agents vary greatly in capability, from about as smart as a rat to human level (but not beyond). You don't know any details about the AIs' internal algorithms, only that they use something sort of like neural nets. The alien follows completely inhuman design principles and has put no effort whatsoever into making the AIs interpretable.
You are part of a larger community studying the alien AIs, and you have a few decades to tackle the problem. However, you operate under several additional constraints:
- Computation can run both forwards and backwards through the AIs (or in loops). I.e., there's no neat sequential dependency of future layers on past layers.
- Each AI is unique, and was individually trained on different datasets (which are now unavailable).
- The AIs are always running. You can't copy, backup, pause or re-run them on different inputs.
- Each AI is physically instantiated in its own hardware, with no pre-existing digital interface to their internal states.
- If you want access to internal states, you'll need to physically open the AIs up and insert electrodes directly into their hardware. This is delicate, and you also risk permanently breaking or damaging the AI with each "operation".
- Your electrodes are much larger than the AI circuits, so you can only ever get or influence a sort of average activation of the surrounding neurons/circuits.
- You are limited to 10 or fewer probe electrodes per AI.
- You are limited to the technology of the 1960s. This includes total ignorance of all advances in machine learning and statistics made since then.
Please take a moment to predict how far your interpretability research would progress.
Would you predict any of the following?
- That you'd be able to find locations in the AI's hardware which robustly correspond to the AI's internal reward signal.
- That these locations would be fairly consistent across different AIs.
- That running a current through these regions would directly increase the AIs' perceived reward.
- That the location of the electrodes would influence the nature and intensity of the reward-seeking behavior produced.
If you base your predictions on how little progress we've made in ML interpretability, these advances seem very unlikely, especially given the constraints described.
And yet, in 1953, James Olds and Peter Milner discovered that running a current through certain regions in rats' brains was rewarding and influenced their behavior. In 1963, Robert Heath showed similar results in humans. Brain reward stimulation works in every vertebrate so far tested. Since the 1960s, we've had more advances in brain interpretability:
- A recent paper was able to reversibly disable conditioned fear responses in mice.
- We’re able to cure some instances of treatment-resistant depression by stimulating select neurons.
- We can (badly) reconstruct images from brain activity.
- fMRI-based lie detection techniques seem much more effective than I'd have expected (I'm very unsure about these results).
  - Some studies claim accuracies as high as 90% in laboratory conditions (which I don't really believe). However, my prior expectation was that such approaches were completely useless. Anything over 60% would have surprised me.
  - As far as I can tell, even relatively pessimistic studies report accuracies of ~75% (review paper 1, review paper 2).
- A study using CNNs on EEG data reports 82% accuracy.
The lie detection research is particularly interesting to me, given its relation to detecting deception/misalignment in ML systems. I know that polygraphs are useless, but that lots of studies claim to show otherwise. I'm not very familiar with lie detection based on richer brain activity data such as fMRI, EEG, direct brain recordings, etc. If anyone is more familiar with this field and able to help, I'd be grateful, especially if they're able to point out some issue in the research I've linked above that reveals it to be much less impressive than it looks.
2: Why does current interpretability lag?
These brain interpretability results seem much more impressive to me than I'd have expected given how hard ML interpretability has been and the many additional challenges involved with studying brains. I can think of four explanations for such progress:
- Larger or more robust networks tend to be more interpretable.
  - I think there’s some evidence for this in the “polysemantic” neurons OpenAI found in vision models. If you don’t have enough neurons to learn independent representations, you’ll need to compress different concepts into the same neurons/circuits, which makes things harder to interpret. This sort of compression may also cause problems for the network when an input contains both concepts at once. Thus, larger networks trained on a wider pool of data plausibly use fewer polysemantic neurons.
- The brain is more interpretable than an equivalently sized modern network.
  - I think dropout is actually very bad for interpretability. Ideally, an interpretable network has a single neuron per layer that uniquely represents each human concept (trees, rocks, mountains, human values, etc.). Networks trained with dropout can't allow this: at a dropout rate of 10%, they'd completely forget 10% of those concepts whenever they process an input. To have a >= 99.9% chance of retaining a concept, the network would need to distribute its representation across 3+ neurons at each layer, since the chance that all three are dropped at once is 0.1^3 = 0.001.
  - Brains have very sparse connections between neurons compared to deep models. I think this forces them to learn sparser internal representations. Imagine a neuron, n, in layer k that needs access to a tree feature detector from the previous layer. If layer k-1 distributes its representation of trees across multiple neurons, n can still recover a pure tree representation thanks to the dense feed-forward connections between the layers. Now imagine adding a regularization term to the training that forces neurons to make sparse connections with the previous layer. That would encourage the network to use a single neuron as its tree representation on layer k-1, because then neuron n could recover a pure tree representation with a single connection to layer k-1.
  - The brain pays metabolic costs for longer-running or overly widespread computations, so there's an incentive to keep computations concentrated in space/time. The brain also has fairly consistent regional specialization. Having a rough idea of what the neurons in a region are/aren't likely to be doing may help develop a more refined picture of the computations they're performing. I suspect neural architecture search may actually improve interpretability by allowing different network components to specialize in different types of problems.
  - Biological neurons are more computationally sophisticated than artificial neurons. Maybe single biological neurons can represent more complex concepts, whereas an artificial network would need to distribute such concepts across multiple neurons. Also, if you imagine each biological neuron as corresponding to a small, self-contained artificial neural network, that implies the brain is even more sparsely connected than we typically imagine.
- These brain interpretability results are much less impressive than they seem.
  - This is likely part of the explanation. Doubtless there's publication bias, cherry-picking and overselling of results. I'm also more familiar with the limitations of ML interpretability. Even so, if brain interpretability were as hard as ML interpretability, I think we'd see significantly less impressive results than we currently do.
  - The news media are much more interested in brain interpretability results. Maybe that makes me more likely to learn about successes here than for ML interpretability (though I think this is unlikely, given my research focuses on ML interpretability).
- Neuroscience as a field has advantages ML interpretability doesn't.
  - More total effort has gone into understanding the brain than into understanding deep networks. Maybe deep learning interpretability will improve as people become more familiar with the field?
  - Neuroscience is also more focused on understanding the brain than ML is on understanding deep networks. Neuroscientists may be more likely to actually try to understand the brain, rather than preemptively giving up (as often happens in ML).
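The dropout arithmetic above can be checked with a few lines of Python. This is a minimal sketch that assumes a 10% dropout rate and that each redundant copy of a concept is dropped independently, so the concept is lost only if every copy is dropped:

```python
def survival_probability(copies: int, drop_rate: float = 0.1) -> float:
    """Probability that at least one of `copies` redundant neurons survives dropout,
    assuming each neuron is dropped independently with probability `drop_rate`."""
    return 1.0 - drop_rate ** copies


def copies_needed(target: float, drop_rate: float = 0.1, tol: float = 1e-12) -> int:
    """Smallest number of redundant neurons whose survival probability reaches `target`.

    `tol` absorbs floating-point rounding in comparisons like 1 - 0.1**3 >= 0.999."""
    k = 1
    while survival_probability(k, drop_rate) < target - tol:
        k += 1
    return k


for k in range(1, 5):
    print(f"{k} copies -> P(concept survives) = {survival_probability(k):.4f}")
print("copies needed for >= 99.9%:", copies_needed(0.999))
```

Each extra copy cuts the failure probability by another factor of ten, so the required redundancy grows only logarithmically with the target reliability, but it never drops to the single dedicated neuron an ideal interpretable network would use.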
Evolution made no effort to ensure brains are interpretable. And yet, we've made startling progress in interpreting them. This suggests mind design space contains high-performance, very interpretable designs that are easily found. We should be able to do even better than evolution by explicitly aiming for interpretable systems.
3: Interpretability for current ML systems
There's a widely held preconception that deep learning systems are inherently uninterpretable. I think this belief is harmful because it provides people with an excuse for not doing the necessary interpretability legwork for new models and discourages people from investigating existing models.
I also think this belief is false (or at least exaggerated). It's important not to overstate the interpretability issues with current state-of-the-art models. Even though the field has put roughly zero effort into making such models interpretable, we've still made some impressive progress in many areas.
- Convolutional Neural Networks
  - I think many people here are already familiar with the circuits line of research at OpenAI. Though I think it's now mostly been abandoned, they made lots of interesting discoveries about the internal structures of CNN image models, such as where/how various concepts are represented, how different image features are identified and combined hierarchically, and the various algorithms implemented by model weights.
- Reinforcement Learning Agents
  - DeepMind recently published "Acquisition of Chess Knowledge in AlphaZero", studying how AlphaZero learns to play chess. They were able to identify where various human-interpretable chess concepts reside in the network as well as when the network discovers these concepts during its training.
- Transformer Language Models
  - Transformer Feed-Forward Layers Are Key-Value Memories is somewhat like "circuits for transformers". It shows that the first weight matrix of each feed-forward layer acts as a set of "keys" that detect syntactic or semantic patterns in the layer's input. Each key's activation then weights a corresponding "value" from the second weight matrix, and these values focus probability mass on tokens that tend to appear after the patterns their keys detect. The paper also explores how the different layers interact with each other and with the residual connections to collectively update the probability distribution over the predicted token.
  - Knowledge Neurons in Pretrained Transformers is able to identify particular neurons whose activations correspond to human-interpretable knowledge such as "Paris is the capital of France". The authors can partially suppress or amplify the influence such knowledge neurons have on the model's output by changing the activations of those neurons. Additionally, they can modify the knowledge in question by changing the values of the feed-forward layer, e.g., making a model think "London is the capital of France".
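The key-value reading of feed-forward layers can be sketched in a few lines. This toy follows the paper's framing, FF(x) = f(x·K^T)·V with biases omitted, but the choice of ReLU for f, the tiny dimensions, and treating each value vector directly as logits over a three-token vocabulary (instead of projecting it through an output embedding) are simplifying assumptions:

```python
import math
from typing import List, Tuple

Vector = List[float]


def relu(z: float) -> float:
    return max(0.0, z)


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def feed_forward_as_memory(x: Vector, keys: List[Vector], values: List[Vector]) -> Tuple[Vector, Vector]:
    """FF(x) = f(x . K^T) . V, written as a weighted sum of value vectors.

    keys[i] is the i-th row of K (a pattern detector); values[i] is the i-th row of V.
    Returns the memory coefficients f(x . k_i) and the output sum_i m_i * v_i."""
    coefficients = [relu(dot(x, k)) for k in keys]
    output = [0.0] * len(values[0])
    for m, v in zip(coefficients, values):
        for j, v_j in enumerate(v):
            output[j] += m * v_j
    return coefficients, output


def softmax(logits: Vector) -> Vector:
    top = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical toy: 2-d input, two memories, 3-token vocabulary.
x = [1.0, 0.0]                       # input matches the first key's pattern only
keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[2.0, 0.0, 0.0],           # memory 0 promotes token 0
          [0.0, 3.0, 0.0]]           # memory 1 promotes token 1
coefficients, logits = feed_forward_as_memory(x, keys, values)
print("memory coefficients:", coefficients)
print("next-token distribution:", softmax(logits))
```

Because the second key doesn't fire on this input, its value vector contributes nothing, and the layer's output is entirely the first memory's "vote" for token 0.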
4: Measuring/improving interpretability for current systems
Knowledge Neurons in Pretrained Transformers is particularly interesting because the authors present something like a metric for interpretability. Their Figure 3 shows the relative change in probability mass assigned to correct answers after using interpretability tools to suppress the associated knowledge (similarly, Figure 4 shows the effects of amplifying knowledge).
We could use the average of these suppression ratios as a measure of interpretability (more precisely, of amenability to current interpretability techniques). Then, we could repeatedly retrain the language model with different architectures, hyperparameters, regularizers, etc., and test their effects on the model's "suppressability".
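A minimal sketch of that averaged metric, assuming suppressability is defined as the mean relative drop in the probability of the correct answer after suppression (so higher means more suppressible). The probability pairs below are made-up placeholders, not results from the paper:

```python
from statistics import mean
from typing import Iterable, Tuple


def relative_change(p_before: float, p_after: float) -> float:
    """Relative change in the correct answer's probability; -0.5 means it halved."""
    return (p_after - p_before) / p_before


def suppressability(prob_pairs: Iterable[Tuple[float, float]]) -> float:
    """Mean relative *drop* over (p_before, p_after) pairs, one pair per fact.

    Higher values mean suppressing the attributed neurons removes more of the
    model's knowledge, i.e. the model is more amenable to this technique."""
    return -mean(relative_change(before, after) for before, after in prob_pairs)


# Hypothetical (before, after) probabilities for three facts; not real measurements.
pairs = [(0.62, 0.21), (0.48, 0.30), (0.75, 0.66)]
print(f"suppressability = {suppressability(pairs):.3f}")
```

A fair comparison across retrained variants would presumably use the same set of facts and the same number of suppressed neurons per fact for every variant.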
We can also test interventions aimed explicitly at improving interpretability. For example, we could compute the correlations between neuron activations, then include a term in the training loss that encourages sparse correlations. Ideally, this prompts the model to form small, sparsely connected internal circuits. However, it could also prompt the neurons to develop more complex, second-order relationships that the correlation term doesn't penalize.
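One way to compute such a correlation term is sketched below, assuming "sparse correlations" is operationalized as the mean absolute Pearson correlation between every pair of neurons over a batch. In real training this would be written in an autodiff framework and added to the task loss with some weight; this plain-Python version only evaluates it. As cautioned above, it only sees pairwise linear relationships:

```python
import itertools
from statistics import fmean, pstdev
from typing import List


def pearson(xs: List[float], ys: List[float]) -> float:
    """Population Pearson correlation; treats a constant neuron as uncorrelated."""
    sx, sy = pstdev(xs), pstdev(ys)
    if sx == 0.0 or sy == 0.0:
        return 0.0
    mx, my = fmean(xs), fmean(ys)
    cov = fmean([(x - mx) * (y - my) for x, y in zip(xs, ys)])
    return cov / (sx * sy)


def correlation_penalty(activations: List[List[float]]) -> float:
    """Mean |corr| over all neuron pairs (needs at least two neurons).

    activations[b][i] is neuron i's activation on example b of the batch.
    An L1-style penalty like this pushes most pairwise correlations toward zero."""
    n = len(activations[0])
    columns = [[row[i] for row in activations] for i in range(n)]
    pairs = list(itertools.combinations(range(n), 2))
    return sum(abs(pearson(columns[i], columns[j])) for i, j in pairs) / len(pairs)


# Hypothetical batch: 4 examples x 3 neurons; neurons 0 and 1 are nearly redundant.
batch = [[1.0, 2.1, 0.0], [2.0, 3.9, 1.0], [3.0, 6.2, 0.0], [4.0, 7.8, 1.0]]
print(f"correlation penalty = {correlation_penalty(batch):.3f}")
```

Using absolute values (rather than squares) mirrors L1 regularization: it tends to drive many correlations to exactly zero instead of shrinking all of them a little.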
Furthermore, suppressability is, as far as I can tell, a differentiable metric. It uses model gradients to attribute predictions to particular neurons, then suppresses the neurons with the highest attribution for a given prediction. We can use meta-gradients to get the gradients of the attributions, then train the model to maximize the difference between the suppressed and normal model outputs, using that difference as our suppressability training objective.
This will hopefully encourage the model to put maximum attribution on the neurons most responsible for particular pieces of knowledge, making the network more interpretable. In an ideal world, the network will also make more of its internal computation legible to the suppression technique because doing so allows the suppression technique to have a greater impact on model outputs. Of course, explicitly training for this interpretability metric may cause the network to Goodhart the metric. Ideally, we'd have additional, independent interpretability metrics to compare against.
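The attribute-then-suppress step can be sketched on a toy model. Knowledge Neurons attributes predictions to neurons with integrated gradients, and the sketch below uses that technique, but the model (a sigmoid over a weighted sum of three neuron activations), its coefficients, the Riemann-sum step count, and suppression by zeroing are hypothetical choices for illustration. The meta-gradient training objective itself would need an autodiff framework and isn't shown:

```python
import math
from typing import List, Tuple


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def prob_correct(acts: List[float], coeffs: List[float]) -> float:
    """Hypothetical toy model: P(correct answer) = sigmoid(sum_i coeffs[i] * acts[i])."""
    return sigmoid(sum(a * c for a, c in zip(acts, coeffs)))


def grad_wrt_neuron(acts: List[float], coeffs: List[float], i: int) -> float:
    """Analytic dP/d(acts[i]) for the toy model: P * (1 - P) * coeffs[i]."""
    p = prob_correct(acts, coeffs)
    return p * (1.0 - p) * coeffs[i]


def integrated_gradient(acts: List[float], coeffs: List[float], i: int, steps: int = 20) -> float:
    """Integrated-gradients attribution for neuron i.

    Scales only neuron i from 0 up to its observed value (other neurons fixed) and
    approximates acts[i] * integral_0^1 dP/d(acts[i]) d(alpha) with a right Riemann sum."""
    total = 0.0
    for step in range(1, steps + 1):
        scaled = list(acts)
        scaled[i] = acts[i] * step / steps
        total += grad_wrt_neuron(scaled, coeffs, i)
    return acts[i] * total / steps


def suppress_top(acts: List[float], coeffs: List[float], k: int = 1) -> Tuple[List[int], float]:
    """Zero the k most-attributed neurons; return their indices and the relative change in P."""
    attributions = [integrated_gradient(acts, coeffs, i) for i in range(len(acts))]
    top = sorted(range(len(acts)), key=lambda i: attributions[i], reverse=True)[:k]
    suppressed = [0.0 if i in top else a for i, a in enumerate(acts)]
    before, after = prob_correct(acts, coeffs), prob_correct(suppressed, coeffs)
    return top, (after - before) / before


acts = [2.0, 0.5, 1.0]      # hypothetical neuron activations for one prompt
coeffs = [1.5, 0.2, -0.3]   # hypothetical influence of each neuron on the answer
top, change = suppress_top(acts, coeffs, k=1)
print("suppressed neurons:", top, "relative change in P(correct):", round(change, 3))
```

With enough steps, this single-neuron path satisfies the completeness property of integrated gradients: the attribution approaches P(original) minus P(with that neuron zeroed), which is one reason attribution pairs naturally with suppression by zeroing.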
Running interpretability experiments would also allow us to explore possible tradeoffs between interpretability and performance. I think we'll find there's a lot of room to improve interpretability with minimal performance loss. Take dropout for example: "The Implicit and Explicit Regularization Effects of Dropout" disentangles the different regularization benefits dropout provides and shows we can recover dropout's contributions by adding a regularization term to the loss and noise to the gradient updates.
We currently put very little effort into making state-of-the-art systems interpretable. It's often possible to make large strides in under-researched areas. The surprising ease of brain interpretability suggests there are accessible ML approaches with high performance which are more interpretable than the current state of the art. Since we actually want interpretable systems, we should be able to do better still. Surely we can beat evolution at a goal it wasn't even trying to achieve?
I can only speculate, but the main researchers are now working on other things, e.g., at Anthropic. As to why they switched, I don't know. Maybe they weren't making progress fast enough, or Anthropic's mission seemed more important?
However, Chris Olah at least believes this is still a tractable and important direction; see his recent RFP for Open Phil.