This post provides background, motivation, and a nontechnical summary of the purely mathematical https://arxiv.org/abs/2310.06686.
Coauthors (alphabetical): Chris MacLeod, Jenny Nitishinskaya, Buck Shlegeris. Work done mostly while at Redwood Research. Thanks to Joe Benton and Ryan Greenblatt for some math done previously. Thanks to Neel Nanda, Fabien Roger, Nix Goldowsky-Dill, and Jacob Hilton for feedback on various parts of this work.
Intro
In interpretability (and more generally in model understanding or model neuroscience), people care about measuring the effect of multiple inputs or components[1] (such as heads) on the model’s behavior, and identifying which ones are important. This is called attribution.
Suppose we’ve done attribution to two different parts of the model. Intuitively, something very different is going on if these two parts are also importantly interacting than if they aren’t! In this post we consider the question: what is a principled interpretability framework for attributing to the interaction between inputs or components?
Summary
We can decompose a function into a sum of all the input interaction terms of various orders: the mean of the function, plus the individual contributions of each input, plus the second-order interaction of every pair of inputs, etc. This is the Generalized [Cumulant/Wick] Product Decomposition (G[C/W]PD).
Attribution to one input at a time is, in general, not enough to explain a function’s behavior.
If you aren’t measuring interactions, notice that you are assuming they are 0!
Background: attribution via interventions
Recall that we have a way to do attribution to model inputs (or components): tweak one part of the input while keeping the others the same. For example, to see how much a token in the input mattered, we can ablate that token and see how the model’s output changes.
In this post we are going to be talking in terms of resample ablation and taking expectations of the output over some distribution of your choice: for more info on why, see the causal scrubbing writeup.
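To make this concrete, here is a minimal Python sketch of resample ablation on a toy function. The function f, the input values, and the resampling distribution below are all made up for illustration; this is a sketch, not code from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)

def f(x, y):
    # Stand-in for the model behavior we care about, as a function of two input parts.
    return x * y + x

# The input we want to explain, and a dataset supplying the resampling distribution.
x0, y0 = 2.0, 3.0
X_data = rng.normal(size=10_000)

# Resample-ablate the x part: swap x0 for draws from the dataset while keeping y0 fixed,
# and measure how much the expected output changes.
attribution_to_x0 = f(x0, y0) - f(X_data, y0).mean()
print(attribution_to_x0)
```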
FAQ
What’s the relevance for alignment?
NB this section is brief and elides a lot of detail, but this post felt incomplete without mentioning why the authors were interested in this direction.
We expect that when aligning superhuman systems, we may need models to do some amount of generalization and handle some distribution shifts, while some anomalies/shifts may be unsafe (we’re including thinking about mechanistic anomalies as well as input/output anomalies).
Suppose we performed attribution to the inputs of our model on a thoroughly-supervised distribution. Now we are running it in deployment, and want to tell if the model’s behavior on new data is safe. We again run our attribution calculation, and find that it is a bit different (perhaps a bit higher to one input and a bit lower to another). How do we tell whether this is okay? One way might be to measure whether there is novel interaction: this would be something qualitatively different from what we had seen before.
One example where interaction specifically may be important is for detecting collusion: we’re often thinking of the ELK setup, where the model intentionally deceiving the oversight process manifests as a surprising interaction between components of the model reasoning about the outcomes we will oversee.
What’s the relevance for interpretability more broadly?
A precise, principled framework for the terms that might be important for a model behavior is useful for crisply stating interpretability hypotheses. This is great for comparing the results of different experiments, as well as for automated hypothesis search, since it reduces the need for experimenter judgement about which additional attributions should be measured.
We think this framework is also just healthy to have in mind. When investigating your model behavior of interest, it’s important to remember that “the action” doesn’t have to flow through any one input you are considering. This is most obvious if you search over all [heads, input tokens, etc.] and don’t find a responsible one. In other cases, you might find some effect but miss some important pieces. If you are not already thinking of interactions as potentially important, it can be harder to notice what you are missing.
Why should I care if I don’t think interactions are likely to matter, on priors?
There is generally[2] some interaction between the multiple “bits” you are considering. If you think interactions don’t matter and you want to be thorough, then you should check that these are small.
Do interactions actually show up in practice?
Example of interaction: redundancy
It’s obvious that multiple inputs to a computation may be redundant for the task at hand: resample-ablating (or masking out, or any other ablation) any one of them has no effect, but ablating all of them would have a large effect! Consider classifying whether a sentence is in French: if you replace just one word with a word from an English sentence, the classifier will still say it is in French with very high probability (assuming the sentence is long enough). The different inputs (tokens) are redundant for the model behavior.
Model components can also be redundant: small transformers exhibit multiple heads that substantially copy the previous token into a different subspace, and the Interpretability in the Wild paper (from here on referred to as the IOI paper) showed redundant “name mover heads”. In both cases, many heads at least partially do the same job.
Example of interaction: qualitatively different behavior
In other cases, the interaction between inputs may be more complex, where the response to one is conditional on another. A basic example is XOR: if the input is (1, 0), then the attribution to y is positive (changing y would decrease the output) while if the input is (0, 0) then the attribution to y is negative!
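Here is the XOR example as a toy sketch, assuming resample ablation with y drawn uniformly from {0, 1} (the setup is illustrative only):

```python
def attribution_to_y(x0, y0):
    # Resample-ablate y uniformly over {0, 1} and compare to the original output.
    expected_over_y = sum(x0 ^ y for y in (0, 1)) / 2
    return (x0 ^ y0) - expected_over_y

print(attribution_to_y(1, 0))  #  0.5: keeping y0 = 0 holds the output up
print(attribution_to_y(0, 0))  # -0.5: keeping y0 = 0 holds the output down
```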
In LLMs, one example is backup name mover heads from the IOI paper: these seem to perform the name-mover task only when the “primary” name mover heads are not performing it!
There are so many interactions that measuring all of them would be really expensive. Can I cheaply check for interactions without being quite so rigorous?
It’s sometimes possible to estimate interaction without computing it explicitly. For example, suppose you identified previous-token heads by e.g. examining attention patterns. You could ablate a set of these heads and see if the resulting change in the output is equal to the sum of the changes when ablating a single one at a time. If it is, then either there are no interactions between them, or all the interactions (approximately) cancel out. If it’s not, then there is some interaction between the heads, though you don’t know which ones.
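A minimal sketch of this check is below. The runner is hypothetical scaffolding: it assumes you already have some way to evaluate a scalar metric with a chosen set of heads resample-ablated.

```python
def interaction_residual(run_with_ablation, heads):
    # run_with_ablation(heads) is assumed to return a scalar metric (e.g. an expected
    # logit difference) with the given set of heads resample-ablated.
    baseline = run_with_ablation(set())
    joint_effect = run_with_ablation(set(heads)) - baseline
    sum_of_single_effects = sum(run_with_ablation({h}) - baseline for h in heads)
    # Approximately zero => no net interaction among these heads (or the interactions cancel);
    # clearly nonzero => some interaction exists, though this doesn't say which heads.
    return joint_effect - sum_of_single_effects

# Toy stand-in for the model: the metric is 1.0 as long as at least one of heads 0 and 1
# is intact, i.e. the two heads are redundant.
def fake_runner(ablated_heads):
    return 0.0 if {0, 1} <= ablated_heads else 1.0

print(interaction_residual(fake_runner, [0, 1]))  # -1.0: nonzero, so the heads interact
```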
In the IOI paper, the authors didn’t measure the backup-name-mover/name-mover interaction explicitly: instead they performed some experiments[3] that showed that there was some interaction.
We’re excited about the principled framework we present and its applications, but if you don’t wish to adopt it, we hope you are still aware of interaction effects and know how to estimate them.
Intuition and overview of attribution framework
Let’s review how we attribute the output of a function f to a particular input x0 (in expectation over performing resample ablation). We can think of it as follows:
The amount that x0 matters = however much of the value of f was not explained by taking into account the general behavior of f on the input distribution X
Let’s consider some simple cases. If f is a constant function, the attribution to x0 is 0. If it’s the identity, then the attribution to x0 is just how extremal x0 is with respect to the distribution: x0−μ(X).
Now suppose f is a function of two variables, x and y. We have two inputs, x0 and y0, which happen to be redundant for a computation (such as two words in a French sentence that f is classifying the language of). The experiment to do here is obvious—ablate both of them and see how the output of f changes—but how do we quantify the irreducible amount the interaction matters?
The amount the interaction of x0 and y0 matters = however much of the value of f was not explained by taking into account:
the general behavior of f on the input distribution (X,Y)
the general behavior of f conditional on x0: what would you expect the output of f to be, knowing nothing about y0?
the general behavior of f conditional on y0
Again, if f is a constant function, any attribution (including to the interaction of x0 and y0) is 0. If it’s linear, e.g. f(x,y)=x+y, then we expect this attribution should be 0 as well: there is nothing interesting about the combination of x0 and y0.
A worked example
We’ll work out the math for the two-variable function f(x, y) = xy. Recall that we have inputs x0, y0 and want to attribute to parts of this input.
We could resample-ablate the entire input to contextualize it in the dataset:
$$f(x_0, y_0) - \mathbb{E}_{(X,Y)}[f(x, y)] = x_0 y_0 - \mathbb{E}_{(X,Y)}[xy]$$
This is just like the single-input attribution in the previous section: we’re measuring how extremal the value of f on this input is, compared to its average value.
We could resample-ablate just x to see how much x0 mattered:
$$f(x_0, y_0) - \mathbb{E}_X[f(x, y_0)] = x_0 y_0 - \mathbb{E}_X[x]\, y_0$$
Note that if y0 is 0, the above expression is 0. This makes sense: at that y0, x does not matter at all.
We could also ask the above for the average y0:
$$\mathbb{E}_Y[f(x_0, y) - \mathbb{E}_X f(x, y)] = x_0 \mathbb{E}_Y[y] - \mathbb{E}_X[x]\, \mathbb{E}_Y[y]$$
What about the interaction of x0 and y0? Recall we said that the amount the interaction matters is:
however much of the value of f was not explained by:
the baseline average of f over all inputs
how much x0 matters for the average y
how much y0 matters for the average x
i.e.:
$$f(x_0, y_0) - \mathbb{E}_{(X,Y)}[f] - \left(x_0 \mathbb{E}_Y[y] - \mathbb{E}_X[x]\,\mathbb{E}_Y[y]\right) - \left(\mathbb{E}_X[x]\, y_0 - \mathbb{E}_X[x]\,\mathbb{E}_Y[y]\right) = (x_0 - \mathbb{E}_X[x])(y_0 - \mathbb{E}_Y[y]) - \mathrm{Cov}(X, Y)$$
This is known as the Wick product. The last form we’ve written this in is quite intuitive: how much are x0 and y0 “more together” than you would expect from the covariance of the underlying distributions?
But nothing here depended on f computing the product! We can compute the same thing for any f with two (or more!) inputs.
We can see that if f is linear, this attribution is 0. This is what we intuitively expected!
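Here is the worked example as a toy Monte Carlo sketch (illustrative only, with a made-up correlated Gaussian distribution and reference input; not the paper’s sample code):

```python
import numpy as np

rng = np.random.default_rng(0)

def f(x, y):
    return x * y

# A correlated joint distribution for (X, Y) and a reference input (x0, y0).
n = 2_000
X = rng.normal(size=n)
Y = 0.5 * X + rng.normal(size=n)            # Cov(X, Y) is about 0.5
x0, y0 = 2.0, 1.5

E_joint = f(X, Y).mean()                     # E_(X,Y)[f]: x and y resampled together
E_indep = f(X[:, None], Y[None, :]).mean()   # E_X E_Y[f]: x and y resampled independently

whole_input = f(x0, y0) - E_joint            # resample-ablate the entire input
x0_at_y0    = f(x0, y0) - f(X, y0).mean()    # how much x0 matters at this y0
x0_at_avg_y = f(x0, Y).mean() - E_indep      # how much x0 matters for the average y
y0_at_avg_x = f(X, y0).mean() - E_indep      # how much y0 matters for the average x

# The interaction term: what's left of the whole-input attribution after the two
# single-input (average) attributions are accounted for.
interaction = whole_input - x0_at_avg_y - y0_at_avg_x
wick_form   = (x0 - X.mean()) * (y0 - Y.mean()) - np.cov(X, Y)[0, 1]
print(interaction, wick_form)                # approximately equal
```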
Attribution as function approximation[4]
In the example above, the interaction was the missing piece needed to fully describe the behavior of f. That is, if we denote an attribution term with ωf,[5] then
$$f = \omega_{f,\{X,Y\}} + \mathbb{E}_Y[\omega_{f,\{X\}}] + \mathbb{E}_X[\omega_{f,\{Y\}}] + \mathbb{E}_{X,Y} f.$$
We can think of attributing to x0 as a term in a decomposition of f, and a hypothesis that some inputs or interactions between them don’t matter as a statement that neglecting them is a good approximation to f.
From this expression we can clearly see that the terms corresponding to separate inputs, or even all the inputs together, are not all the terms needed to describe the model’s behavior.
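To check these claims numerically, here is a small sketch over a made-up discrete distribution (so all expectations are exact sums): the interaction term from the decomposition above is zero for a linear f and nonzero for a nonlinear one. The distribution, reference input, and function choices are illustrative.

```python
import math

# Uniform distribution over a few (x, y) pairs, so every expectation is an exact sum.
support = [(0.0, 0.0), (0.0, 1.0), (1.0, 1.0), (2.0, 1.0)]
x0, y0 = 2.0, 1.0

def interaction(f):
    # omega_{f,{X,Y}}(x0, y0): f(x0, y0) minus the joint baseline and the two
    # single-input terms from the decomposition above.
    n = len(support)
    E_joint = sum(f(x, y) for x, y in support) / n
    E_indep = sum(f(x, yy) for x, _ in support for _, yy in support) / n**2
    omega_x = sum(f(x0, yy) for _, yy in support) / n - E_indep   # E_Y[omega_{f,{X}}]
    omega_y = sum(f(xx, y0) for xx, _ in support) / n - E_indep   # E_X[omega_{f,{Y}}]
    return f(x0, y0) - E_joint - omega_x - omega_y

print(interaction(lambda x, y: 3 * x - y + 1))              # linear: 0 (up to float error)
print(interaction(lambda x, y: math.exp(x) * math.sin(y)))  # nonlinear: nonzero
```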
Our contribution
We’ve argued that interaction terms can be important and shown how to measure them. How would you use this in practice?
In the linked arXiv paper, we have:
defined the general formula for this attribution for arbitrary numbers of inputs and interaction orders
provided additional intuition
proven some nice properties
provided some sample code[6] for those who prefer that over formulas
Appendix
Future post: cumulant propagation
We can translate ARC’s cumulant propagation algorithm on arithmetic circuits into computing a set of attributions. Maybe we’ll write this up.
Average interaction
In this post, we talked about attribution at a fixed reference input (x0, y0). In the linked writeup, we also cover measuring the interaction of X and Y on average:
$$K_f(X, Y) := \mathbb{E}_{(X,Y)} f - \mathbb{E}_X \mathbb{E}_Y f$$
Note this is a generalization of the notion of covariance: if f is just the product function, this is the covariance between X and Y. We call Kf the Generalized Cumulant Product (GCP). We can write the expectation of f as a sum of GCPs, and this form is the Generalized Cumulant Product Decomposition (GCPD):
$$\mathbb{E}_{(X,Y)} f = K_f(X, Y) + K_f(X|Y)$$
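A toy Monte Carlo sketch of this quantity (illustrative names and distribution; not the paper’s code): for the product function it recovers the usual covariance, and it vanishes for a linear f.

```python
import numpy as np

rng = np.random.default_rng(0)

def generalized_cumulant(f, X, Y):
    # K_f(X, Y) := E_(X,Y)[f] - E_X E_Y[f], estimated by Monte Carlo.
    joint = f(X, Y).mean()
    indep = f(X, rng.permutation(Y)).mean()  # break the pairing to approximate E_X E_Y
    return joint - indep

n = 100_000
X = rng.normal(size=n)
Y = 0.5 * X + rng.normal(size=n)             # Cov(X, Y) is about 0.5

print(generalized_cumulant(lambda x, y: x * y, X, Y))  # close to Cov(X, Y), about 0.5
print(generalized_cumulant(lambda x, y: x + y, X, Y))  # about 0: a linear f has no interaction
```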
Footnotes
[1] Note that components can be seen as just multiple inputs to a treefied model, e.g. as in the IOI paper. We’ll mostly talk about attributing to inputs for ease of language.
[2] The interaction is always 0 if your model is completely linear, or otherwise has no cross-terms.
[3] In the notation of our paper, they computed something like E[EY(f−EXf)−(f−EXf)]=Kf(X,Y). They found this was large, i.e. the generalized-covariance between X (the input to the name-mover heads) and Y (the input to the backup-name-mover heads) is large. Though, they performed resample ablation on X and knockout on Y.
[4] A Mathematical Framework for Transformer Circuits and Formalizing the presumption of independence similarly break up functions into a sum of terms.
[5] We call this the Generalized Wick Product (GWP) and the form of f the Generalized Wick Product Decomposition (GWPD) (though technically the terms are expectations of GWPs).
[6] Probably not performant for high-order terms, for which memoization would be helpful. But the second-order interaction is easy to compute and you can probably do it today.