Thanks for this! One thing I don't understand about influence functions is: why should I care about the proximal Bregman objective? To interpret a model, what I'm really interested in is the LOO (leave-one-out) retraining, right? Can we still say things like "it seems that the model relied on this training sample for producing this output" with the PBO interpretation?
I agree that approximating the PBO makes this method more lossy (not all interesting generalization phenomena can be found). However, I think we can still glean useful information about generalization by considering "retraining" from a point closer to the final model than random initialization. The downside is if, for example, some data was instrumental in causing a phase transition at some point in training, this will not be captured by the PBO approximation.
Indeed, the paper concedes:
Influence functions are approximating the sensitivity to the training set locally around the final weights and might not capture nonlinear training phenomena
Purely empirically, I think Anthropic's results indicate there are useful things that can be learnt, even via this local approximation:
One of the most consistent patterns we have observed is that the influential sequences reflect increasingly sophisticated patterns of generalization as the model scale increases. While the influential sequences for smaller models tend to have short overlapping sequences of tokens, the top sequences for larger models are related at a more abstract thematic level, and the influence patterns show increasing robustness to stylistic changes, including the language.
My intuition here is that even if we are not exactly measuring the counterfactual "what if this datum was not included in the training corpus?", we could be estimating "what type of useful information is the model extracting from training data that looks like this?".
I found that influential training digits were usually more sloppy / unclear compared to the average MNIST digit, and shared some resemblance with the query image.
It's pbzcnevat gb gur arnerfg cbvagf ba gur obhaqnel bs qvtvg-pyhfgref! Bonus points if you made your observation without that interpretation in mind. What if you do Jacobian regularization?
Thank you for this. How would you think about the pros/cons of influence functions vs activation patching or direct logit attribution in terms of localizing a behavior in the model?
Anthropic recently published the paper Studying Large Language Model Generalization with Influence Functions, which describes a scalable technique for measuring which training examples were most influential for a particular set of weights/outputs of a trained model. This can help us better understand model generalization, offering insights into the emergent properties of AI systems. For instance, influence functions could help us answer questions like "is the model asking not to be shut down because it has generalized that this is a generically good strategy for pursuing some goal, or simply because texts where AIs ask not to be shut down are commonly found in the training corpus?".
In this post, I aim to summarize the approximations used in the paper to calculate the influence of different training examples and outline how the approximations can be implemented in PyTorch to form the basis of further research on influence functions by the AI safety community.
(Note: most formulae are copied or adapted from the original paper, with a few additional derivation steps / simplified notation used for clarity in some places.)
Deriving the exact form of the influence function
Before we go into approximations, it is necessary to understand what specifically we are trying to measure.
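Given some element $z_m = (x_m, y_m)$ of a dataset $\mathcal{D} = \{z_i\}_{i=1}^N$, we define the response function as the optimal solution $\theta^*$ (weights that minimize expected loss $\mathcal{L}$) as a function of the weighting $\epsilon$ of this example.
$$\theta^*(\epsilon) = \arg\min_{\theta \in \mathbb{R}^D} \frac{1}{N}\sum_{i=1}^N \mathcal{L}(z_i, \theta) + \epsilon\,\mathcal{L}(z_m, \theta)$$
We define the influence $\mathcal{I}_{\theta^*}(z_m)$ of $z_m$ on $\theta^*$ using the first-order Taylor approximation to the response function at $\epsilon = 0$.
$$\Delta\theta = \theta^*(\epsilon) - \theta^*(0) \approx \left.\frac{\partial \theta^*(\epsilon)}{\partial \epsilon}\right|_{\epsilon=0} \cdot \epsilon = \mathcal{I}_{\theta^*}(z_m) \cdot \epsilon$$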
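We can get $\frac{\partial \theta^*(\epsilon)}{\partial \epsilon}$ the following way:
We know $\theta^*$ is a minimum of $\frac{1}{N}\sum_i \mathcal{L}(z_i, \theta) + \epsilon\,\mathcal{L}(z_m, \theta)$, and so the gradient with respect to $\theta$ is zero at that point:
$$\nabla_\theta\left(\frac{1}{N}\sum_{i=1}^N \mathcal{L}(z_i, \theta^*) + \epsilon\,\mathcal{L}(z_m, \theta^*)\right) = 0$$
Differentiating each side with respect to $\epsilon$:
(The LHS depends on $\epsilon$ both directly and indirectly via $\theta^*$, so we differentiate implicitly, as in the Implicit Function Theorem: if $u(x) = f(x, g(x)) = 0$, then $\frac{du}{dx} = \frac{\partial f}{\partial g}\frac{\partial g}{\partial x} + \frac{\partial f}{\partial x} = 0$.)
$$\nabla^2_\theta\left(\frac{1}{N}\sum_i \mathcal{L}(z_i, \theta^*) + \epsilon\,\mathcal{L}(z_m, \theta^*)\right) \cdot \frac{\partial \theta^*}{\partial \epsilon} + \frac{\partial}{\partial \epsilon}\nabla_\theta\left(\frac{1}{N}\sum_i \mathcal{L}(z_i, \theta^*) + \epsilon\,\mathcal{L}(z_m, \theta^*)\right) = 0$$
The second term (a partial derivative with $\theta^*$ held fixed) can be simplified:
$$\frac{\partial}{\partial \epsilon}\nabla_\theta\left(\frac{1}{N}\sum_i \mathcal{L}(z_i, \theta^*) + \epsilon\,\mathcal{L}(z_m, \theta^*)\right) = \nabla_\theta \mathcal{L}(z_m, \theta^*)$$
And so we can rearrange to get an expression for $\frac{\partial \theta^*}{\partial \epsilon}$:
$$\frac{\partial \theta^*}{\partial \epsilon} = -\left[\nabla^2_\theta\left(\frac{1}{N}\sum_i \mathcal{L}(z_i, \theta^*) + \epsilon\,\mathcal{L}(z_m, \theta^*)\right)\right]^{-1}\nabla_\theta \mathcal{L}(z_m, \theta^*)$$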
This tells us how the optimal parameters change with a perturbation ϵ to the weighting of an added data point zm. The change is proportional to the negative product of the inverse Hessian of the loss on all the data and the gradient of the loss on the data point in question with respect to the model parameters (both evaluated at the optimal parameters).
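For simplicity, as in the paper, we'll denote $\nabla^2_\theta\left(\frac{1}{N}\sum_i \mathcal{L}(z_i, \theta^*) + \epsilon\,\mathcal{L}(z_m, \theta^*)\right)$ as $\mathbf{H}$.
Therefore, $\mathcal{I}_{\theta^*}(z_m) \approx -\mathbf{H}^{-1}\nabla_\theta \mathcal{L}(z_m, \theta^*)$ (this corresponds to Equation 3 in the paper).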
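One way to sanity-check Equation 3 is a toy problem where the response function has a closed form. The numpy sketch below uses linear regression with squared loss, which is strictly convex, so $\theta^*(\epsilon)$ is unique. It compares $-\mathbf{H}^{-1}\nabla_\theta\mathcal{L}(z_m,\theta^*)$ with a finite-difference estimate of $\partial\theta^*/\partial\epsilon$. The data, the loss, and the helper `theta_star` are all hypothetical choices for illustration.

```python
import numpy as np

# Hypothetical toy setup: linear regression, L(z, theta) = 0.5 * (x^T theta - y)^2.
# The objective is strictly convex when X has full column rank.
rng = np.random.default_rng(0)
N, D = 50, 3
X = rng.normal(size=(N, D))
y = X @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.normal(size=N)
x_m, y_m = rng.normal(size=D), 1.0       # the up-weighted example z_m

def theta_star(eps):
    """Closed-form minimiser of (1/N) sum_i L(z_i, theta) + eps * L(z_m, theta)."""
    lhs = X.T @ X / N + eps * np.outer(x_m, x_m)
    rhs = X.T @ y / N + eps * x_m * y_m
    return np.linalg.solve(lhs, rhs)

theta0 = theta_star(0.0)
H = X.T @ X / N                          # Hessian of the objective at eps = 0
grad_m = x_m * (x_m @ theta0 - y_m)      # gradient of L(z_m, theta) at theta*
influence = -np.linalg.solve(H, grad_m)  # Equation 3

eps = 1e-6
finite_diff = (theta_star(eps) - theta0) / eps
print(np.max(np.abs(influence - finite_diff)))  # small: first-order error is O(eps)
```

For a strictly convex objective like this, the first-order approximation is essentially exact for small $\epsilon$. The problems discussed below arise for non-convex neural network objectives.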
Influence on some function of the model weights
So far, we have derived an expression for the influence of an added data point on the parameters θ. However, we are more interested in the influence of particular data points on some measurable properties of the model, such as the output logits or validation loss. We can see this as some function f(θ∗) of the trained parameters.
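By the chain rule, $\frac{\partial f(\theta^*(\epsilon))}{\partial \epsilon} = \nabla_\theta f(\theta^*)^\top \frac{\partial \theta^*(\epsilon)}{\partial \epsilon}$, and so
$$\mathcal{I}_f(z_m) \approx -\nabla_\theta f(\theta^*)^\top \mathbf{H}^{-1}\nabla_\theta \mathcal{L}(z_m, \theta^*)$$
(This corresponds to Equation 5 in the paper.)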
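Equation 5 can be checked the same way. In the sketch below, the property of interest $f(\theta)$ is assumed to be the squared loss on one held-out point. The influence is compared with a finite difference of $f(\theta^*(\epsilon))$. The toy regression setup and the names `theta_star` and `f` are hypothetical.

```python
import numpy as np

# Hypothetical toy setup: least-squares regression (strictly convex objective).
rng = np.random.default_rng(1)
N, D = 40, 4
X = rng.normal(size=(N, D))
y = rng.normal(size=N)
x_m, y_m = rng.normal(size=D), 0.5          # up-weighted training example z_m
x_test, y_test = rng.normal(size=D), -1.0   # f(theta) = loss on this held-out point

def theta_star(eps):
    lhs = X.T @ X / N + eps * np.outer(x_m, x_m)
    rhs = X.T @ y / N + eps * x_m * y_m
    return np.linalg.solve(lhs, rhs)

def f(theta):
    return 0.5 * (x_test @ theta - y_test) ** 2

theta0 = theta_star(0.0)
H = X.T @ X / N
grad_f = x_test * (x_test @ theta0 - y_test)         # gradient of f at theta*
grad_m = x_m * (x_m @ theta0 - y_m)                  # gradient of L(z_m, theta) at theta*
influence_f = -grad_f @ np.linalg.solve(H, grad_m)   # Equation 5

eps = 1e-6
finite_diff = (f(theta_star(eps)) - f(theta0)) / eps
print(influence_f, finite_diff)                      # agree to several decimal places
```

A positive value means up-weighting $z_m$ increases the held-out loss, and a negative value means it decreases it.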
Problems with this expression
The paper mentions that because of these problems, "past works have found influence functions to be inaccurate for modern neural networks."
How do we fix this?
One approach is to define a new objective that:
This is what the proximal Bregman objective (PBO) attempts to define.
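$$\theta^s(\epsilon) = \arg\min_{\theta \in \mathbb{R}^D} \frac{1}{N}\sum_{i=1}^N \left(\mathcal{L}(z_i, \theta) - \mathcal{L}(z_i, \theta^s) - \nabla_{y_i}\mathcal{L}(z_i, \theta^s)^\top (y_i - y_i^s)\right) + \epsilon\,\mathcal{L}(z_m, \theta) + \frac{\lambda}{2}\lVert\theta - \theta^s\rVert^2$$
($y_i$ here is the output of the model at the parameters $\theta$ on input $x_i$, and $y_i^s$ is the output of the model at parameters $\theta^s$ on input $x_i$.)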
The PBO introduces penalties for diverging too far from $\theta^s$, the trained parameters we start from. There are two of them: a proximity term $\frac{\lambda}{2}\lVert\theta - \theta^s\rVert^2$ in parameter space, and a Bregman divergence term (the summed term) in output space. So there is a well-defined optimum that balances staying close to $\theta^s$ against achieving good loss.
So we can redefine the gradients used in $\mathcal{I}_f$ in terms of this new objective, which considers both the loss on a newly up-weighted training data point and the divergence from the current parameters.
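To make the Bregman term less abstract, assume a squared-error loss $\mathcal{L} = \frac{1}{2}\lVert y - t\rVert^2$ with target $t$. This assumption is for illustration only; the paper's language models use cross-entropy. Under it, the per-example term $\mathcal{L}(y) - \mathcal{L}(y^s) - \nabla_y\mathcal{L}(y^s)^\top(y - y^s)$ reduces to $\frac{1}{2}\lVert y - y^s\rVert^2$. It therefore penalizes changes in the model's outputs, and the target drops out. The sketch below checks this numerically; `bregman_term` is a hypothetical helper.

```python
import numpy as np

def squared_loss(y, t):
    return 0.5 * np.sum((y - t) ** 2)

def bregman_term(y, y_s, t):
    """L(y) - L(y_s) - grad_y L(y_s)^T (y - y_s), for the squared loss (an assumption)."""
    grad_at_ys = y_s - t  # gradient of 0.5 * ||y - t||^2 w.r.t. y, evaluated at y_s
    return squared_loss(y, t) - squared_loss(y_s, t) - grad_at_ys @ (y - y_s)

rng = np.random.default_rng(0)
y, y_s, t = rng.normal(size=(3, 5))  # current outputs, outputs at theta^s, targets
print(bregman_term(y, y_s, t), 0.5 * np.sum((y - y_s) ** 2))  # equal up to rounding
```

For other losses the Bregman term does not simplify this neatly. It is still zero at $y = y^s$ and non-negative whenever the loss is convex in the outputs.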
From Bae et al.'s 2022 paper If Influence Functions are the Answer, Then What is the Question?:
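Applying the Implicit Function Theorem to the PBO, we can obtain an influence function with respect to the PBO objective (Equation 9 in the paper):
$$\mathcal{I}_{\theta^s}(z_m) = \frac{d\theta^s}{d\epsilon} = -(\mathbf{G} + \lambda\mathbf{I})^{-1}\nabla_\theta \mathcal{L}(z_m, \theta^s)$$
$$\mathcal{I}_f^{\theta^s}(z_m) = -\nabla_\theta f(\theta^s)^\top (\mathbf{G} + \lambda\mathbf{I})^{-1}\nabla_\theta \mathcal{L}(z_m, \theta^s)$$
Where $\mathbf{G}$ is the Gauss-Newton Hessian $\mathbf{G} = \mathbb{E}[\mathbf{J}^\top \mathbf{H}_y \mathbf{J}]$. $\mathbf{J}$ is the Jacobian (the first derivative of the network's outputs with respect to the parameters), and $\mathbf{H}_y$ is the Hessian of the loss with respect to the network's outputs.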
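Here is a minimal numpy sketch of the damped Gauss-Newton influence on a tiny non-linear model. It makes several assumptions. The model has a scalar output $y = \tanh(\theta^\top x)$ and uses squared loss, so $\mathbf{H}_y = 1$ and $\mathbf{G}$ is a mean of outer products of Jacobians. The property $f$ is taken to be the model's output on a query input. Random weights stand in for $\theta^s$. One useful property is visible here: $\mathbf{G}$ is positive semi-definite by construction, so $\mathbf{G} + \lambda\mathbf{I}$ is always invertible for $\lambda > 0$.

```python
import numpy as np

# Hypothetical toy model: scalar output y = tanh(theta^T x), squared loss 0.5 * (y - t)^2.
rng = np.random.default_rng(0)
N, D = 100, 5
X = rng.normal(size=(N, D))
T = rng.normal(size=N)
theta_s = 0.3 * rng.normal(size=D)   # stand-in for the trained weights theta^s
lam = 1e-2                           # damping strength lambda

def output_and_jacobian(theta, x):
    y = np.tanh(theta @ x)
    J = (1.0 - y ** 2) * x           # dy/dtheta, shape (D,)
    return y, J

# Gauss-Newton Hessian G = E[J^T H_y J]; for squared loss H_y = 1.
G = np.zeros((D, D))
for x in X:
    _, J = output_and_jacobian(theta_s, x)
    G += np.outer(J, J) / N

def loss_grad(theta, x, t):
    y, J = output_and_jacobian(theta, x)
    return (y - t) * J               # dL/dy * dy/dtheta

# Property of interest (hypothetical): the model output on a query input.
x_query = rng.normal(size=D)
_, grad_f = output_and_jacobian(theta_s, x_query)

# PBO influence of training example m on f: -grad_f^T (G + lam I)^{-1} grad L(z_m).
m = 0
influence = -grad_f @ np.linalg.solve(G + lam * np.eye(D), loss_grad(theta_s, X[m], T[m]))
print(np.linalg.eigvalsh(G).min() >= -1e-12, influence)
```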
Efficient calculation
So, we want to get $-\nabla_\theta f(\theta^s)^\top (\mathbf{G} + \lambda\mathbf{I})^{-1}\nabla_\theta \mathcal{L}(z_m, \theta^s)$.
Let's assume we have the following:
The key ingredients needed to calculate If are:
Notice that only key ingredient 1 depends on the property of interest. We can pre-compute ingredients 2 and 3 and then use this to test a bunch of different properties (for example, find the most influential training examples for a bunch of different model input-output pairs).
We can also calculate the influence as a batched operation over many training data points (batching $\nabla_\theta \mathcal{L}(z_m, \theta^s)$ over multiple $z_m$) to increase efficiency via vectorization.
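A sketch of the batching idea, under illustrative assumptions: a random positive semi-definite matrix stands in for $\mathbf{G}$, and random vectors stand in for the query and training gradients. Because $\mathbf{G} + \lambda\mathbf{I}$ is symmetric, the inverse can be applied to either side of the product. So one option is a single inverse-Hessian-vector product on the query gradient, followed by one matrix-vector product against all stacked training gradients. The shapes and names here are hypothetical.

```python
import numpy as np

# Hypothetical sizes: P parameters, M candidate training examples.
rng = np.random.default_rng(0)
P, M = 6, 10
B = rng.normal(size=(P, P))
G = B @ B.T / P                        # stand-in PSD curvature matrix
lam = 1e-3
grad_f = rng.normal(size=P)            # ingredient 1: gradient of the query property
train_grads = rng.normal(size=(M, P))  # ingredient 3: one training gradient per row

# grad_f^T (G + lam I)^{-1} g == ((G + lam I)^{-1} grad_f)^T g because G + lam I is symmetric,
# so one solve serves every training example.
ihvp = np.linalg.solve(G + lam * np.eye(P), grad_f)
influences = -train_grads @ ihvp       # shape (M,)

# Same numbers from the per-example formula, for comparison.
loop = np.array([-grad_f @ np.linalg.solve(G + lam * np.eye(P), g) for g in train_grads])
top = np.argsort(-influences)[:3]      # indices of the three most positively influential examples
```

Whether to precompute the inverse-HVP on the query side or the training side depends on which set is larger and which gets reused more often.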
This leaves the final key question: how do we compute $(\mathbf{G} + \lambda\mathbf{I})^{-1}\mathbf{v}$?
Kronecker-Factored Approximate Curvature (KFAC)
Originally introduced in the 2015 paper Optimizing Neural Networks with Kronecker-factored Approximate Curvature by Martens and Grosse, KFAC is an approximation to the Fisher information matrix (FIM), $\mathbf{F} = \mathbb{E}_{x\sim p(x),\, y\sim P(y\mid x;\theta)}\left[\nabla_\theta \log p(y\mid x;\theta)\,\nabla_\theta \log p(y\mid x;\theta)^\top\right]$, that can be inverted very efficiently. For many models where the loss is the negative log probability under a simple predictive distribution (such as a softmax output trained with cross-entropy), the FIM is equal to the Gauss-Newton Hessian.
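The FIM/Gauss-Newton equivalence can be checked numerically for a linear softmax classifier, which is a hypothetical example chosen for simplicity. The FIM is computed by taking the expectation over $y \sim p(y\mid x)$ exactly, by summing over classes. The Gauss-Newton Hessian is $\mathbf{J}^\top\mathbf{H}_z\mathbf{J}$, where $\mathbf{H}_z = \operatorname{diag}(p) - pp^\top$ is the Hessian of the negative log softmax with respect to the logits. Parameters are flattened row-major here, which is purely a convention of this sketch.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

rng = np.random.default_rng(0)
K, D = 4, 3                          # classes, input features (hypothetical sizes)
W = rng.normal(size=(K, D))          # linear model: logits z = W x; theta = W.ravel()
x = rng.normal(size=D)
p = softmax(W @ x)

# Fisher: expectation over y ~ p(y|x) of the outer product of grad log p(y|x).
fisher = np.zeros((K * D, K * D))
for y in range(K):
    g = np.outer(np.eye(K)[y] - p, x).ravel()  # d log p(y|x) / d theta (row-major)
    fisher += p[y] * np.outer(g, g)

# Gauss-Newton: J^T H_z J with J = dz/dtheta and H_z the Hessian of -log softmax in z.
J = np.kron(np.eye(K), x[None, :])   # shape (K, K*D)
H_z = np.diag(p) - np.outer(p, p)
gnh = J.T @ H_z @ J
print(np.allclose(fisher, gnh))      # True
```

The check covers a single input $x$. Averaging both sides over inputs preserves the equality.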
KFAC for MLP models involves the following:
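Given a fully connected model with $L$ layers, let's assume each layer $l$'s output is:
$$\mathbf{a}_l = \phi_l(\mathbf{W}_l \mathbf{a}_{l-1})$$
where $\mathbf{a}_{l-1} \in \mathbb{R}^M$, $\mathbf{W}_l \in \mathbb{R}^{P\times M}$, and $\phi_l$ is a nonlinear activation function.[2]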
When we backpropagate $\log p(y\mid x;\theta)$ to get $\nabla_\theta \log p(y\mid x;\theta)$, we need to calculate the derivative of $\log p(y\mid x;\theta)$ with respect to intermediate stages of the computation at each layer. So as we go backward through the computational graph, once we get to the output of $\mathbf{W}_l\mathbf{a}_{l-1}$, we'll have computed $\nabla_{\mathbf{W}_l\mathbf{a}_{l-1}} \log p(y\mid x;\theta)$.
By the chain rule, and using the fact that $\nabla_{\mathbf{W}_l} \mathbf{W}_l\mathbf{a}_{l-1} = \mathbf{a}_{l-1}$,
$$\nabla_{\mathbf{W}_l} \log p(y\mid x;\theta) = \nabla_{\mathbf{W}_l\mathbf{a}_{l-1}} \log p(y\mid x;\theta)\,\mathbf{a}_{l-1}^\top$$
This means we can decompose the gradient[3] of the log-likelihood loss on some data point $(x, y)$ with respect to the weight matrix $\mathbf{W}_l$ into two parts: the intermediate gradient of the loss with respect to the output of applying the weight matrix, and the activations prior to that layer.
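Working with gradients of weight matrices is inconvenient, though, as we end up with 3D tensors for the Jacobian. We can instead consider $\theta_l$, the unrolled (column-stacked) weight matrix for layer $l$.
Then, defining $\mathcal{D}v = \nabla_v \log p(y\mid x;\theta)$, $\mathbf{s}_l = \mathbf{W}_l\mathbf{a}_{l-1}$, and $\otimes$ as the Kronecker product:
$$\mathcal{D}\theta_l = \mathbf{a}_{l-1} \otimes \mathcal{D}\mathbf{s}_l$$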
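The Kronecker form is just the outer-product gradient $\mathcal{D}\mathbf{s}_l\,\mathbf{a}_{l-1}^\top$ with its columns stacked. The numpy check below uses arbitrary random vectors. Note that column-major (`order="F"`) flattening is what makes $\mathbf{a}_{l-1} \otimes \mathcal{D}\mathbf{s}_l$ the right ordering. With numpy's default row-major flattening, the factors would swap to $\mathcal{D}\mathbf{s}_l \otimes \mathbf{a}_{l-1}$.

```python
import numpy as np

rng = np.random.default_rng(0)
M, P = 3, 4                      # input and output widths of layer l
a_prev = rng.normal(size=M)      # a_{l-1}
Ds = rng.normal(size=P)          # D s_l, gradient w.r.t. the pre-activation s_l = W_l a_{l-1}

DW = np.outer(Ds, a_prev)        # D W_l = D s_l a_{l-1}^T, shape (P, M) like W_l
D_theta = DW.ravel(order="F")    # vec() stacks columns (column-major)
print(np.allclose(D_theta, np.kron(a_prev, Ds)))  # True
```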
So far, so exact... But now, time for approximations. KFAC makes things simpler by assuming:
This allows us to write down a simple block-diagonal approximation for G:
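$$\mathbf{G}_l = \mathbb{E}[\mathcal{D}\theta_l\,\mathcal{D}\theta_l^\top] = \mathbb{E}[\mathbf{a}_{l-1}\mathbf{a}_{l-1}^\top \otimes \mathcal{D}\mathbf{s}_l\,\mathcal{D}\mathbf{s}_l^\top] \approx \mathbb{E}[\mathbf{a}_{l-1}\mathbf{a}_{l-1}^\top] \otimes \mathbb{E}[\mathcal{D}\mathbf{s}_l\,\mathcal{D}\mathbf{s}_l^\top] = \mathbf{A}_{l-1} \otimes \mathbf{S}_l$$
Where $\mathbf{A}_{l-1}$ and $\mathbf{S}_l$ are the uncentered covariance matrices of the layer's input activations and pre-nonlinearity gradients, respectively.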
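The factorization step is exact precisely when activations and gradients are independent. A small numpy illustration follows; the data is random and `exact_block` is a hypothetical helper. Pairing every activation sample with every gradient sample gives a product distribution, and the empirical $\mathbf{G}_l$ then equals $\mathbf{A}_{l-1} \otimes \mathbf{S}_l$ exactly. When the gradients are tied to the activations, the two generally differ.

```python
import numpy as np

rng = np.random.default_rng(0)
M, P = 3, 2
acts = rng.normal(size=(5, M))    # samples of a_{l-1}
grads = rng.normal(size=(7, P))   # samples of D s_l

def exact_block(pairs):
    """E[D theta D theta^T] over (a, Ds) pairs, with D theta = a kron Ds."""
    dthetas = np.array([np.kron(a, ds) for a, ds in pairs])
    return dthetas.T @ dthetas / len(dthetas)

A = acts.T @ acts / len(acts)     # uncentered covariance E[a a^T]
S = grads.T @ grads / len(grads)  # uncentered covariance E[Ds Ds^T]

# Independent: every activation paired with every gradient, so K-FAC is exact.
independent = [(a, ds) for a in acts for ds in grads]
print(np.allclose(exact_block(independent), np.kron(A, S)))      # True

# Correlated (hypothetical): gradients are copies of part of the activations.
grads_c = acts[:, :P]
S_c = grads_c.T @ grads_c / len(grads_c)
correlated = list(zip(acts, grads_c))
print(np.allclose(exact_block(correlated), np.kron(A, S_c)))     # generally False
```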
This structure enables us to efficiently get the inverse (approximate) Gauss-Newton Hessian vector product:
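Let $\mathbf{V}_l$ denote the entries of $\mathbf{v}$ for layer $l$, reshaped to match $\mathbf{W}_l$, and let $\mathbf{v}_l = \operatorname{vec}(\mathbf{V}_l)$.
Using various Kronecker product identities, we can compute the inverse (approximate) Gauss-Newton Hessian vector product as:
$$\hat{\mathbf{G}}_l^{-1}\mathbf{v}_l = \operatorname{vec}(\mathbf{S}_l^{-1}\mathbf{V}_l\mathbf{A}_{l-1}^{-1})$$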
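The identity above can be checked against a brute-force solve. The sketch below builds random symmetric positive definite stand-ins for $\mathbf{A}_{l-1}$ and $\mathbf{S}_l$. It uses $(\mathbf{A}\otimes\mathbf{S})\operatorname{vec}(\mathbf{X}) = \operatorname{vec}(\mathbf{S}\mathbf{X}\mathbf{A}^\top)$ with column-major vec, plus the symmetry of $\mathbf{A}$. The payoff is cost: the K-FAC route only inverts an $M\times M$ and a $P\times P$ matrix, never the $PM \times PM$ block.

```python
import numpy as np

rng = np.random.default_rng(0)
M, P = 3, 4
Ba, Bs = rng.normal(size=(M, M)), rng.normal(size=(P, P))
A = Ba @ Ba.T + 0.1 * np.eye(M)   # stand-in for A_{l-1} (symmetric positive definite)
S = Bs @ Bs.T + 0.1 * np.eye(P)   # stand-in for S_l
V = rng.normal(size=(P, M))       # v for layer l, reshaped like W_l

def vec(X):
    return X.ravel(order="F")     # column-stacking vec, matching D theta = a kron Ds

# Naive: build the (PM x PM) Kronecker product and solve against it.
naive = np.linalg.solve(np.kron(A, S), vec(V))
# K-FAC: invert only the small factors, (A kron S)^{-1} vec(V) = vec(S^{-1} V A^{-1}).
kfac = vec(np.linalg.solve(S, V) @ np.linalg.inv(A))
print(np.allclose(naive, kfac))   # True
```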
Eigenvalue correction
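We made an approximation earlier when we went from $\mathbb{E}[\mathbf{a}_{l-1}\mathbf{a}_{l-1}^\top \otimes \mathcal{D}\mathbf{s}_l\,\mathcal{D}\mathbf{s}_l^\top]$ to $\mathbb{E}[\mathbf{a}_{l-1}\mathbf{a}_{l-1}^\top] \otimes \mathbb{E}[\mathcal{D}\mathbf{s}_l\,\mathcal{D}\mathbf{s}_l^\top]$.
Using the eigendecompositions of $\mathbf{A}$ and $\mathbf{S}$:
$$\mathbf{A} = \mathbf{Q}_A \Lambda_A \mathbf{Q}_A^\top$$
$$\mathbf{S} = \mathbf{Q}_S \Lambda_S \mathbf{Q}_S^\top$$
we can write a more accurate expression for $\mathbf{G}$:
$$\mathbf{G} \approx (\mathbf{Q}_A \otimes \mathbf{Q}_S)\,\Lambda\,(\mathbf{Q}_A \otimes \mathbf{Q}_S)^\top$$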
where the diagonal matrix Λ is defined as:
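$$\Lambda_{ii} = \mathbb{E}\left[\left((\mathbf{Q}_A \otimes \mathbf{Q}_S)^\top \mathcal{D}\theta\right)_i^2\right] = \mathbb{E}\left[\left(\operatorname{vec}\left(\mathbf{Q}_S^\top \operatorname{unvec}(\mathcal{D}\theta)\,\mathbf{Q}_A\right)\right)_i^2\right]$$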
which "captures the variances of the pseudo-gradient projected onto each eigenvector of the K-FAC approximation".
We can get the damped inverse Gauss-Newton Hessian vector product approximation by adding λ to the eigenvalues, obtaining:
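$$(\mathbf{G} + \lambda\mathbf{I})^{-1}\mathbf{v} \approx (\mathbf{Q}_A \otimes \mathbf{Q}_S)(\Lambda + \lambda\mathbf{I})^{-1}(\mathbf{Q}_A \otimes \mathbf{Q}_S)^\top \mathbf{v}$$
$$= \operatorname{vec}\left(\mathbf{Q}_S\left[(\mathbf{Q}_S^\top \mathbf{V} \mathbf{Q}_A) \oslash \operatorname{unvec}\left(\operatorname{diag}^{-1}(\Lambda + \lambda\mathbf{I})\right)\right]\mathbf{Q}_A^\top\right)$$
where $\oslash$ denotes elementwise division and $\operatorname{diag}^{-1}$ extracts the diagonal of a matrix as a vector.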
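A numpy sketch of the eigenvalue-corrected inverse-HVP for a single layer follows. Several things here are assumptions. The activation and gradient samples are random; the gradients are deliberately correlated with the activations so that the correction matters. `ekfac_ihvp` is a hypothetical helper, not code from the paper or the linked repository. It uses the same column-eigenvector convention as the eigendecompositions above and compares the efficient matrix form against the explicit $PM \times PM$ formula.

```python
import numpy as np

rng = np.random.default_rng(0)
M, P, n = 3, 2, 200
acts = rng.normal(size=(n, M))                 # samples of a_{l-1}
grads = rng.normal(size=(n, P)) * acts[:, :1]  # samples of D s_l, correlated with acts on purpose
lam = 1e-3

def vec(X):
    return X.ravel(order="F")

A = acts.T @ acts / n
S = grads.T @ grads / n
_, Q_A = np.linalg.eigh(A)   # A = Q_A diag(.) Q_A^T, eigenvectors in columns
_, Q_S = np.linalg.eigh(S)

# Eigenvalue correction: second moment of per-sample gradients projected onto the K-FAC eigenbasis,
# using (Q_A kron Q_S)^T vec(DW) = vec(Q_S^T DW Q_A) with DW = Ds a^T.
proj = np.array([Q_S.T @ np.outer(ds, a) @ Q_A for a, ds in zip(acts, grads)])
Lam = (proj ** 2).mean(axis=0)  # shape (P, M): diag(Lambda) reshaped like W_l

def ekfac_ihvp(V):
    """Damped inverse (G + lam I)^{-1} vec(V) for one layer, with V shaped like W_l."""
    return vec(Q_S @ ((Q_S.T @ V @ Q_A) / (Lam + lam)) @ Q_A.T)

# Check against the explicit form (Q_A kron Q_S)(Lambda + lam I)^{-1}(Q_A kron Q_S)^T v.
Q = np.kron(Q_A, Q_S)
V = rng.normal(size=(P, M))
explicit = Q @ np.diag(1.0 / (vec(Lam) + lam)) @ Q.T @ vec(V)
print(np.allclose(ekfac_ihvp(V), explicit))  # True
```

If `Lam` were replaced by the outer product of the eigenvalues of $\mathbf{A}$ and $\mathbf{S}$, this would reduce to the plain K-FAC inverse. Re-estimating it from projected gradients is what the eigenvalue correction adds.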
Influence functions for autoregressive models
A few details change when we want to calculate $\mathcal{I}_f^{\theta^s}(z_m) = -\nabla_\theta f(\theta^s)^\top (\mathbf{G} + \lambda\mathbf{I})^{-1}\nabla_\theta \mathcal{L}(z_m, \theta^s)$ for a Transformer language model trained with an autoregressive loss function.
In this case, the property of interest $f$ (the thing we are calculating the influence on) is the log-likelihood of a particular token string completion $z_c$, given a token string prompt $z_p$:[4]
$$\log p(z_c \mid z_p; \theta)$$
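As a concrete reading of $\log p(z_c \mid z_p;\theta)$, here is a numpy sketch. It assumes a model that outputs one row of logits per position, where `logits[t]` predicts token `t + 1`. Only the completion tokens contribute, and each is scored from the previous position. The function name and interface are hypothetical.

```python
import numpy as np

def log_softmax(logits):
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def completion_log_likelihood(logits, tokens, prompt_len):
    """log p(z_c | z_p) for an autoregressive model (hypothetical interface).

    logits: (T, V) array, where logits[t] is the prediction for tokens[t + 1].
    tokens: (T,) token ids, the prompt z_p followed by the completion z_c.
    prompt_len: number of prompt tokens (must be at least 1).
    """
    logp = log_softmax(logits)
    positions = np.arange(prompt_len, len(tokens))  # indices of completion tokens
    return logp[positions - 1, tokens[positions]].sum()

T, V = 6, 10
tokens = np.array([3, 1, 4, 1, 5, 9])
uniform = np.zeros((T, V))  # uniform predictions over a 10-token vocabulary
print(completion_log_likelihood(uniform, tokens, prompt_len=2))  # -4 * log(10), about -9.21
```

In a PyTorch setting, the gradient of this scalar with respect to the MLP weights would be ingredient 1 for the query.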
The paper only considers measuring the influence on a subset of the Transformer's weights - only the MLP layers - so the MLP G approximation derived above applies almost exactly.
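However, the parameter gradients are now summed over token indices:
$$\mathcal{D}\theta_l = \sum_{t=1}^T \mathcal{D}\theta_{l,t} = \sum_{t=1}^T \mathbf{a}_{l-1,t} \otimes \mathcal{D}\mathbf{s}_{l,t}$$
Each diagonal block of $\mathbf{G}$ is given by $\mathbb{E}[\mathcal{D}\theta_l\,\mathcal{D}\theta_l^\top]$. However, we want to take into account how this second moment is affected by inter-token correlations, so we cannot approximate it as accurately by directly using $\mathbb{E}[\mathbf{a}_{l-1}\mathbf{a}_{l-1}^\top] \otimes \mathbb{E}[\mathcal{D}\mathbf{s}_l\,\mathcal{D}\mathbf{s}_l^\top]$ as before.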
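The sketch below illustrates the two points for one hypothetical sequence of random activations and gradients. First, the token-summed gradient is a single matrix product of the stacked per-token gradients and activations. Second, its outer product contains cross-token terms ($t \neq t'$) that a per-token Kronecker factorization does not see.

```python
import numpy as np

rng = np.random.default_rng(0)
T, M, P = 5, 3, 2               # sequence length, input width, output width (hypothetical)
acts = rng.normal(size=(T, M))  # a_{l-1,t} for each token t
ds = rng.normal(size=(T, P))    # D s_{l,t} for each token t

def vec(X):
    return X.ravel(order="F")

# D theta_l = sum_t a_{l-1,t} kron D s_{l,t}, computed two ways.
summed = sum(np.kron(a, d) for a, d in zip(acts, ds))
as_matmul = vec(ds.T @ acts)    # sum_t D s_t a_t^T as one (P x T)(T x M) product

# Outer product of the summed gradient = per-token terms + cross-token terms.
per_token = [np.kron(a, d) for a, d in zip(acts, ds)]
diag_terms = sum(np.outer(g, g) for g in per_token)
cross_terms = np.outer(summed, summed) - diag_terms
print(np.allclose(summed, as_matmul), np.abs(cross_terms).max())
```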
The paper presents the following middle-ground between efficiency and accuracy:
Implementing in PyTorch
As described above, the key ingredients for If are:
We can get 1) and 3) by simply fetching `parameter.grad`[5] after performing a backward pass of the loss on some (input, target) pair. We can get 2a) using a forward hook that saves the input to a layer during the forward pass. We can get 2b) using a backward hook on the linear layer that saves the gradient with respect to the linear layer's output.
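The following is a minimal PyTorch sketch of the hook approach on a single `nn.Linear`. It is not the code from the linked repository. Instead of a module backward hook, the forward hook registers a tensor hook on the layer output (`Tensor.register_hook`) to capture the gradient with respect to $\mathbf{s}_l$. A ones column is appended to the saved inputs so that the bias is handled in homogeneous coordinates. Labels are sampled from the model's own predictive distribution to mimic pseudo-gradients, and the loss uses `reduction="sum"` so that each row of the saved gradient belongs to one example. These choices are assumptions of the sketch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
layer = nn.Linear(4, 3)
saved = {}

def forward_hook(module, inputs, output):
    # 2a) save the layer input a_{l-1}, with a column of ones appended for the bias term
    a = inputs[0].detach()
    saved["a"] = torch.cat([a, torch.ones(a.shape[0], 1)], dim=1)

    # 2b) have autograd hand us the gradient w.r.t. this layer's output s_l
    def save_output_grad(grad):
        saved["ds"] = grad.detach()

    output.register_hook(save_output_grad)

handle = layer.register_forward_hook(forward_hook)

x = torch.randn(8, 4)
logits = layer(x)
with torch.no_grad():  # pseudo-labels sampled from the model's predictive distribution
    sampled = torch.distributions.Categorical(logits=logits).sample()
loss = F.cross_entropy(logits, sampled, reduction="sum")
loss.backward()
handle.remove()

# 1) / 3): parameter gradients, with the .weight and .bias grads concatenated.
param_grad = torch.cat([layer.weight.grad, layer.bias.grad[:, None]], dim=1)

# The weight gradient is the sum over examples of the outer products Ds a^T.
reconstructed = saved["ds"].T @ saved["a"]
print(torch.allclose(param_grad, reconstructed, atol=1e-5))  # True

# K-FAC factors: uncentered covariances of (homogeneous) inputs and output gradients.
A = saved["a"].T @ saved["a"] / x.shape[0]
S = saved["ds"].T @ saved["ds"] / x.shape[0]
```

The sign convention (loss gradient vs. log-likelihood gradient) does not affect `A` or `S`, since both enter through outer products.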
You can find my implementation attempt on GitHub here.[6] It includes code applying influence function analysis to a vanilla MLP trained on MNIST and to a 2-layer transformer trained on a basic next-character prediction task.
Results of small experiment on MNIST
I trained an MLP on MNIST (with flattened images) and then used the influence function approximation code to extract influential training examples for particular predicted test set labels.
I found that influential training digits were usually more sloppy / unclear compared to the average MNIST digit, and shared some resemblance with the query image. Not all top influential images shared the same label as the query. I only searched a subset of the training corpus, for efficiency.
Here are some examples, filtered to cases where the influence was non-negligible (some queries returned ~0 for all sampled training datapoints). The first image on the left is the query, followed by the most influential training images according to the approximation:
The warm-start problem referenced by Bae et al. refers to the fact that for an objective that is not strictly convex, the influence of a training example in the neighborhood of a minimum θ∗ may be different from the influence at a different initialization point.
The paper uses homogeneous vector notation to account for biases / affine transformations - you can assume there is a 1 appended to the activations a and a bias vector appended to W to cover this case.
The paper refers to these as "pseudo-gradients" since they are sampled from the final output distribution and are distinct from gradients during training.
The zp, zc pair is referred to as the "query" in the paper, as we are "querying" which training examples were most influential for the model producing zc given zp.
Specifically, concatenate a linear layer's `.weight` and `.bias` `grad`s.
If you look through the code and find any bugs (quite possible) or performance improvements (definitely findable; e.g. more batching + splitting of GPU ops - WIP), I'd be super happy to merge PRs and/or hear from you! I hope to gradually improve this codebase and run larger experiments.