Anthropic recently published the paper Studying Large Language Model Generalization with Influence Functions, which describes a scalable technique for measuring which training examples were most influential for a particular set of weights/outputs of a trained model. This can help us better understand model generalization, offering insights into the emergent properties of AI systems. For instance, influence functions could help us answer questions like "is the model asking not to be shut down because it has generalized that this is a generically good strategy for pursuing some goal, or simply because texts where AIs ask not to be shut down are commonly found in the training corpus?". 

In this post, I aim to summarize the approximations used in the paper to calculate the influence of different training examples and outline how the approximations can be implemented in PyTorch to form the basis of further research on influence functions by the AI safety community.  

(Note: most formulae are copied or adapted from the original paper, with a few additional derivation steps / simplified notation used for clarity in some places.)

Deriving the exact form of the influence function

Before we go into approximations, it is necessary to understand what specifically we are trying to measure. 

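Given some element $z_m$ of a dataset $\mathcal{D} = \{z_i\}_{i=1}^N$, we define the response function as the optimal solution $\theta^\star(\epsilon)$ (weights that minimize the expected loss $\mathcal{J}$) as a function of the weighting $\epsilon$ of this example:

$$\theta^\star(\epsilon) = \arg\min_{\theta} \mathcal{J}(\theta, \mathcal{D}_\epsilon) = \arg\min_{\theta} \frac{1}{N}\sum_{i=1}^{N} \mathcal{L}(z_i, \theta) + \epsilon\, \mathcal{L}(z_m, \theta)$$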

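We define the influence $\mathcal{I}_{\theta^\star}(z_m)$ of $z_m$ on $\theta^\star$ using the first-order Taylor approximation to the response function at $\epsilon = 0$:

$$\theta^\star(\epsilon) \approx \theta^\star + \epsilon \left.\frac{d\theta^\star}{d\epsilon}\right|_{\epsilon=0}, \qquad \mathcal{I}_{\theta^\star}(z_m) = \left.\frac{d\theta^\star}{d\epsilon}\right|_{\epsilon=0}$$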

We can get $\frac{d\theta^\star}{d\epsilon}$ the following way:

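We know $\theta^\star(\epsilon)$ is a minimum of $\mathcal{J}(\theta, \mathcal{D}_\epsilon) = \mathcal{J}(\theta, \mathcal{D}) + \epsilon\, \mathcal{L}(z_m, \theta)$, and so the gradient wrt $\theta$ is zero at that point:

$$\nabla_\theta \mathcal{J}(\theta^\star(\epsilon), \mathcal{D}) + \epsilon\, \nabla_\theta \mathcal{L}(z_m, \theta^\star(\epsilon)) = 0$$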

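Differentiating each side wrt $\epsilon$:

$$\nabla^2_\theta \mathcal{J}(\theta^\star(\epsilon), \mathcal{D})\, \frac{d\theta^\star}{d\epsilon} + \frac{d}{d\epsilon}\Big[\epsilon\, \nabla_\theta \mathcal{L}(z_m, \theta^\star(\epsilon))\Big] = 0$$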

(The LHS depends on $\epsilon$ both directly and indirectly via $\theta^\star(\epsilon)$, so we use the Implicit Function Theorem and the chain rule.)

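The second term can be simplified by evaluating at $\epsilon = 0$, where $\theta^\star(0) = \theta^\star$:

$$\left.\frac{d}{d\epsilon}\Big[\epsilon\, \nabla_\theta \mathcal{L}(z_m, \theta^\star(\epsilon))\Big]\right|_{\epsilon=0} = \Big[\nabla_\theta \mathcal{L}(z_m, \theta^\star(\epsilon)) + \epsilon\, \nabla^2_\theta \mathcal{L}(z_m, \theta^\star(\epsilon))\, \frac{d\theta^\star}{d\epsilon}\Big]_{\epsilon=0} = \nabla_\theta \mathcal{L}(z_m, \theta^\star)$$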

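And so we can rearrange to get an expression for $\left.\frac{d\theta^\star}{d\epsilon}\right|_{\epsilon=0}$:

$$\left.\frac{d\theta^\star}{d\epsilon}\right|_{\epsilon=0} = -\left(\nabla^2_\theta \mathcal{J}(\theta^\star, \mathcal{D})\right)^{-1} \nabla_\theta \mathcal{L}(z_m, \theta^\star)$$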

This tells us how the optimal parameters change with a perturbation $\epsilon$ to the weighting of an added data point $z_m$. To first order, the change is proportional to the negative product of the inverse Hessian of the loss on all the data and the gradient of the loss on the data point in question with respect to the model parameters (both evaluated at the optimal parameters).

For simplicity, as in the paper, we'll denote $\nabla^2_\theta \mathcal{J}(\theta^\star, \mathcal{D})$ as $\mathbf{H}$.

Therefore, $\mathcal{I}_{\theta^\star}(z_m) = -\mathbf{H}^{-1} \nabla_\theta \mathcal{L}(z_m, \theta^\star)$ (this corresponds to Equation 3 in the paper).
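As a quick sanity check of Equation 3, here is a minimal NumPy sketch on least-squares linear regression. This is my own toy setup, not from the paper or the linked repo. It assumes $\mathcal{L}(z, \theta) = \frac{1}{2}(x^\top\theta - y)^2$, so that $\theta^\star(\epsilon)$ has a closed form. The sketch compares $-\mathbf{H}^{-1}\nabla_\theta \mathcal{L}(z_m, \theta^\star)$ with a central finite difference of the re-solved optimum. The helper names `optimum` and `influence_on_params` are hypothetical.

```python
import numpy as np

def optimum(X, y, eps=0.0, x_m=None, y_m=None):
    # Closed-form argmin_theta of (1/N) sum_i 0.5 (x_i^T theta - y_i)^2 + eps * 0.5 (x_m^T theta - y_m)^2
    N = X.shape[0]
    H = X.T @ X / N
    b = X.T @ y / N
    if eps != 0.0:
        H = H + eps * np.outer(x_m, x_m)
        b = b + eps * y_m * x_m
    return np.linalg.solve(H, b)

def influence_on_params(X, y, x_m, y_m):
    theta = optimum(X, y)
    H = X.T @ X / X.shape[0]              # Hessian of J(theta, D) at theta*
    grad_m = (x_m @ theta - y_m) * x_m    # grad_theta L(z_m, theta*)
    return -np.linalg.solve(H, grad_m)    # -H^{-1} grad_theta L(z_m, theta*)

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.normal(size=50)
x_m, y_m = X[0], y[0]                     # upweight an existing training point

analytic = influence_on_params(X, y, x_m, y_m)
h = 1e-5
finite_diff = (optimum(X, y, h, x_m, y_m) - optimum(X, y, -h, x_m, y_m)) / (2 * h)
```

The two agree here because this loss is convex and has more data points than parameters, so $\mathbf{H}$ is invertible and $\theta^\star$ is an exact minimum.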

Influence on some function of the model weights

So far, we have derived an expression for the influence of an added data point on the parameters $\theta^\star$. However, we are more interested in the influence of particular data points on some measurable properties of the model, such as the output logits or validation loss. We can see this as some function $f(\theta)$ of the trained parameters.

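By the chain rule, $\left.\frac{d f(\theta^\star(\epsilon))}{d\epsilon}\right|_{\epsilon=0} = \nabla_\theta f(\theta^\star)^\top \left.\frac{d\theta^\star}{d\epsilon}\right|_{\epsilon=0}$, and so

$$\mathcal{I}_f(z_m) = \nabla_\theta f(\theta^\star)^\top \mathcal{I}_{\theta^\star}(z_m) = -\nabla_\theta f(\theta^\star)^\top \mathbf{H}^{-1} \nabla_\theta \mathcal{L}(z_m, \theta^\star)$$

(This corresponds to Equation 5 in the paper.)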
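The same kind of toy regression (again an illustrative assumption) can check Equation 5 for a scalar property. Here the property is taken to be a linear prediction $f(\theta) = x_q^\top \theta$ at a hypothetical query input $x_q$, so that $\nabla_\theta f = x_q$:

```python
import numpy as np

def optimum(X, y, eps=0.0, x_m=None, y_m=None):
    # Closed-form argmin_theta of (1/N) sum_i 0.5 (x_i^T theta - y_i)^2 + eps * 0.5 (x_m^T theta - y_m)^2
    N = X.shape[0]
    H = X.T @ X / N
    b = X.T @ y / N
    if eps != 0.0:
        H = H + eps * np.outer(x_m, x_m)
        b = b + eps * y_m * x_m
    return np.linalg.solve(H, b)

rng = np.random.default_rng(1)
X = rng.normal(size=(40, 3))
y = X @ np.array([0.3, 1.0, -1.0]) + 0.2 * rng.normal(size=40)
x_m, y_m = X[5], y[5]
x_q = np.array([1.0, 2.0, -0.5])             # hypothetical query input, f(theta) = x_q^T theta

theta = optimum(X, y)
H = X.T @ X / len(X)
grad_f = x_q                                 # grad_theta f(theta*) for a linear prediction
grad_m = (x_m @ theta - y_m) * x_m           # grad_theta L(z_m, theta*)
influence_f = -grad_f @ np.linalg.solve(H, grad_m)   # Equation 5

# Finite-difference estimate of d f(theta*(eps)) / d eps at eps = 0
h = 1e-5
fd = (x_q @ optimum(X, y, h, x_m, y_m) - x_q @ optimum(X, y, -h, x_m, y_m)) / (2 * h)
```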

Problems with this expression

  • The Hessian can have zero eigenvalues and so not be invertible (the optimal parameters can be underspecified by the loss function in the case of overparameterized models)
  • We often don't train to convergence, so the first derivative of the loss wrt the parameters is not necessarily zero, as previously assumed

The paper mentions that because of these problems, "past works have found influence functions to be inaccurate for modern neural networks."
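Here is a minimal sketch of the first problem, assuming an overparameterized linear regression with more parameters than data points. The Hessian of the mean squared error is rank-deficient, so $\mathbf{H}^{-1}$ does not exist. Adding a damping term $\lambda\mathbf{I}$ makes every eigenvalue at least $\lambda$; this is the same form that appears in the PBO-based influence function later in this post. The sizes and $\lambda$ are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 5, 20                          # fewer data points than parameters
X = rng.normal(size=(N, d))
H = X.T @ X / N                       # Hessian of (1/N) sum_i 0.5 (x_i^T theta - y_i)^2

eigs = np.linalg.eigvalsh(H)
rank = int(np.sum(eigs > 1e-10))      # at most N < d nonzero eigenvalues, so H is singular

lam = 1e-3
min_eig = np.linalg.eigvalsh(H + lam * np.eye(d)).min()   # damping shifts every eigenvalue up by lam
```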

How do we fix this?

One approach is to define a new objective that:

  • Has a single defined optimum in parameter space 
  • Is fully optimized when the model stops training

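This is what the proximal Bregman objective (PBO) attempts to define:

$$\theta^s(\epsilon) = \arg\min_{\theta} \frac{1}{N}\sum_{i=1}^{N} D_{\mathcal{L}_i}\big(h(\theta, x_i), h(\theta^s, x_i)\big) + \epsilon\, \mathcal{L}(z_m, \theta) + \frac{\lambda}{2}\|\theta - \theta^s\|^2$$

where $D_{\mathcal{L}_i}(y, y^s) = \mathcal{L}(y, t_i) - \mathcal{L}(y^s, t_i) - \nabla_y \mathcal{L}(y^s, t_i)^\top (y - y^s)$ is the Bregman divergence of the loss on example $i$ with target $t_i$.

($h(\theta, x_i)$ here is the output of the model at the parameters $\theta$ on input $x_i$, and $h(\theta^s, x_i)$ is the output of the model at the final trained parameters $\theta^s$ on input $x_i$.)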

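The PBO basically introduces a penalty for diverging too far from the parameters the re-optimization starts from (the trained parameters $\theta^s$). The penalty applies both in output space, via the Bregman divergence, and in parameter space, via the proximity term. So there is a defined optimum that balances moving too far from $\theta^s$ against achieving good loss.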

So we can redefine the gradients used in $\mathcal{I}_f$ in terms of this new objective. It considers both the loss on the upweighted data point and the divergence from the current parameters.

From Bae et al.'s 2022 paper If Influence Functions are the Answer, Then What is the Question?:

...while influence functions for neural networks are often a poor match to LOO [Leave One Out] retraining, they are a much better match to what we term the proximal Bregman response function (PBRF). Intuitively, the PBRF approximates the effect of removing a data point while trying to keep the predictions consistent with those of the (partially) trained model.

...

In addition, although the PBRF may not necessarily align with LOO retraining due to the warm-start[1], proximity, and non-convergence gaps, the motivating use cases for influence functions typically do not rely on exact LOO retraining. This means that the PBRF can be used in place of LOO retraining for many tasks such as identifying influential or mislabelled examples 

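Applying the Implicit Function Theorem to the PBO, we can obtain an influence function with respect to the PBO objective (Equation 9 in the paper):

$$\mathcal{I}_f(z_m) = -\nabla_\theta f(\theta^s)^\top (\mathbf{G} + \lambda \mathbf{I})^{-1} \nabla_\theta \mathcal{L}(z_m, \theta^s)$$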

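Here $\mathbf{G} = \mathbb{E}[\mathbf{J}_{y\theta}^\top \mathbf{H}_y \mathbf{J}_{y\theta}]$ is the Gauss-Newton Hessian. $\mathbf{J}_{y\theta}$ is the Jacobian (the first derivative of the network's outputs with respect to the parameters), $\mathbf{H}_y$ is the Hessian of the loss with respect to the network's outputs, and $\lambda$ is the damping strength coming from the proximity term of the PBO.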
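To make Equation 9 concrete, here is a small NumPy sketch under stated assumptions:

- The model is softmax regression, with logits $h(\theta, x) = \mathbf{W}x$ and $\theta$ the row-major flattening of $\mathbf{W}$.
- The loss is cross-entropy, and $f$ is the log-probability of a chosen class at a query input.
- $\mathbf{G}$ is averaged over the training inputs.

The weights are random rather than trained, which the PBO formulation permits since it does not assume convergence. The function names are hypothetical, and this is not the linked implementation.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def logit_jacobian(x, C):
    # d logits / d theta for logits = W @ x with theta = W.reshape(-1) (row-major); shape (C, C * d)
    return np.kron(np.eye(C), x[None, :])

def gauss_newton(W, X):
    # G = (1/N) sum_i J_i^T H_y,i J_i, with H_y = diag(p) - p p^T for softmax cross-entropy
    C, d = W.shape
    G = np.zeros((C * d, C * d))
    for x in X:
        p = softmax(W @ x)
        J = logit_jacobian(x, C)
        G += J.T @ (np.diag(p) - np.outer(p, p)) @ J
    return G / len(X)

def pbo_influence(W, X, x_m, t_m, x_q, c_q, lam):
    # Equation 9 with f(theta) = log p(c_q | x_q) and L = cross-entropy on z_m = (x_m, t_m)
    C, d = W.shape
    eye = np.eye(C)
    grad_loss = logit_jacobian(x_m, C).T @ (softmax(W @ x_m) - eye[t_m])
    grad_f = logit_jacobian(x_q, C).T @ (eye[c_q] - softmax(W @ x_q))
    G = gauss_newton(W, X)
    return -grad_f @ np.linalg.solve(G + lam * np.eye(C * d), grad_loss)

rng = np.random.default_rng(0)
C, d = 3, 4
W_s = rng.normal(size=(C, d))     # stand-in for theta^s (random, not trained)
X = rng.normal(size=(100, d))
G = gauss_newton(W_s, X)
score = pbo_influence(W_s, X, X[0], 1, X[1], 2, lam=1e-2)
```

The dense $\mathbf{G}$ here has $(Cd)^2$ entries. Storing and inverting a matrix like that is infeasible for real networks, which is what motivates KFAC.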

Efficient calculation

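So, we want to get $\mathcal{I}_f(z_m) = -\nabla_\theta f(\theta^s)^\top (\mathbf{G} + \lambda \mathbf{I})^{-1} \nabla_\theta \mathcal{L}(z_m, \theta^s)$.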

Let's assume we have the following:

  1. A trained network with parameters $\theta^s$
  2. An observable property of the network, $f(\theta)$, for instance, its output logits $h(\theta, x)$ for some chosen input $x$
  3. The training dataset $\mathcal{D}$
  4. The loss function $\mathcal{L}$ the model was trained on

The key ingredients needed to calculate $\mathcal{I}_f(z_m)$ are:

  1. The gradient of the property $f$ of interest with respect to the parameters $\theta$, evaluated at $\theta^s$
  2. A way of getting the inverse damped Gauss-Newton Hessian vector product $(\mathbf{G} + \lambda \mathbf{I})^{-1} \mathbf{v}$
  3. The gradient of the loss $\mathcal{L}$ on the training data points (which we want to calculate the influence of) with respect to the parameters $\theta$, evaluated at $\theta^s$

Notice that only key ingredient 1 depends on the property of interest. We can pre-compute ingredients 2 and 3 and then use this to test a bunch of different properties (for example, find the most influential training examples for a bunch of different model input-output pairs). 

We can also calculate the influence as a batched operation over many training data points (batching $\nabla_\theta \mathcal{L}(z_m, \theta^s)$ over multiple $z_m$'s) to increase efficiency via vectorization.
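Here is a sketch of the batched computation, using random stand-in gradients and a stand-in PSD curvature matrix (both assumptions for illustration). It precomputes the damped inverse-HVP of every training gradient once; after that, each new query costs only one gradient and one matrix product.

Because $\mathbf{G} + \lambda\mathbf{I}$ is symmetric, one could equally apply the inverse to the query gradients instead. Storing one parameter-sized vector per training point is expensive for large models, so which side to precompute depends on how many queries and training points you have.

```python
import numpy as np

def precompute_train_ihvps(G, lam, train_grads):
    # Ingredients 2 + 3: row m is (G + lam I)^{-1} grad_theta L(z_m, theta^s).
    # np.linalg.solve treats each column of the right-hand side as a separate system.
    damped = G + lam * np.eye(G.shape[0])
    return np.linalg.solve(damped, train_grads.T).T

def influence_matrix(query_grads, train_ihvps):
    # Ingredient 1 per query; one matrix product scores every (query, training point) pair.
    return -query_grads @ train_ihvps.T      # shape (num_queries, num_train)

rng = np.random.default_rng(0)
P = 6                                        # number of parameters (toy)
J = rng.normal(size=(20, P))
G = J.T @ J / 20                             # stand-in PSD curvature matrix
train_grads = rng.normal(size=(100, P))      # stand-ins for grad_theta L(z_m, theta^s)
query_grads = rng.normal(size=(3, P))        # stand-ins for grad_theta f(theta^s), one row per query

ihvps = precompute_train_ihvps(G, 1e-2, train_grads)
scores = influence_matrix(query_grads, ihvps)
top_k = np.argsort(-scores, axis=1)[:, :5]   # indices of the 5 highest-scoring z_m per query
```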

That leaves the final key question: how do we compute $(\mathbf{G} + \lambda \mathbf{I})^{-1} \mathbf{v}$?

Kronecker-Factored Approximate Curvature (KFAC)

Originally introduced in the 2015 paper Optimizing Neural Networks with Kronecker-factored Approximate Curvature by Martens and Grosse, KFAC is an approximation to the Fisher information matrix (FIM) $\mathbf{F}$ that can be inverted very efficiently. For many models where the loss is the negative log probability under a simple predictive distribution (e.g. softmax cross-entropy), the FIM is equal to the Gauss-Newton Hessian $\mathbf{G}$.
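The FIM/Gauss-Newton equivalence can be checked numerically for softmax cross-entropy. In that case the expectation over labels sampled from the model is a finite sum over classes. The sketch below uses a toy softmax-regression setup I chose for illustration. It computes both $\mathbf{G} = \frac{1}{N}\sum_i \mathbf{J}_i^\top \mathbf{H}_{y,i}\mathbf{J}_i$ and $\mathbf{F} = \frac{1}{N}\sum_i \mathbb{E}_{y\sim p(\cdot\mid x_i)}[g\,g^\top]$ exactly:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

rng = np.random.default_rng(0)
C, d, N = 3, 4, 50
W = rng.normal(size=(C, d))                   # softmax-regression weights, theta = W.reshape(-1)
X = rng.normal(size=(N, d))

G = np.zeros((C * d, C * d))
F = np.zeros((C * d, C * d))
for x in X:
    p = softmax(W @ x)
    J = np.kron(np.eye(C), x[None, :])        # d logits / d theta, shape (C, C * d)
    G += J.T @ (np.diag(p) - np.outer(p, p)) @ J
    for c in range(C):                        # exact expectation over y ~ p(y | x)
        g = J.T @ (p - np.eye(C)[c])          # gradient of -log p(c | x) w.r.t. theta
        F += p[c] * np.outer(g, g)
G /= N
F /= N
```

The match is exact because $\sum_c p_c (p - e_c)(p - e_c)^\top = \mathrm{diag}(p) - pp^\top = \mathbf{H}_y$.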

KFAC for MLP models involves the following:

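Given a fully connected model with $L$ layers, let's assume each layer $\ell$'s output is:

$$\mathbf{a}_\ell = \phi_\ell(\mathbf{s}_\ell)$$

where $\mathbf{s}_\ell = \mathbf{W}_\ell \bar{\mathbf{a}}_{\ell-1}$ and $\phi_\ell$ is a nonlinear activation function. [2]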

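When we backpropagate the loss $\mathcal{L}(z, \theta)$ to get $\nabla_\theta \mathcal{L}$, we need to calculate the derivative of the loss with respect to intermediate stages of the computation at each layer. So as we go backward through the computational graph, once we get to the output of $\mathbf{W}_\ell \bar{\mathbf{a}}_{\ell-1}$, we'll have computed $\mathrm{D}\mathbf{s}_\ell = \frac{\partial \mathcal{L}}{\partial \mathbf{s}_\ell}$.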

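By the chain rule, and using the fact that $\mathbf{s}_\ell = \mathbf{W}_\ell \bar{\mathbf{a}}_{\ell-1}$:

$$\mathrm{D}\mathbf{W}_\ell = \mathrm{D}\mathbf{s}_\ell\, \bar{\mathbf{a}}_{\ell-1}^\top$$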

This means we can decompose the gradient[3] of the log-likelihood loss on some data point $z$ with respect to the weight matrix $\mathbf{W}_\ell$ into two parts: the intermediate gradients of the loss with respect to the output of applying the weight matrix, and the activations prior to that layer.

Working with gradients of weight matrices is inconvenient, though, as we end up with 3D tensors for the Jacobian. We can instead consider $\theta_\ell = \mathrm{vec}(\mathbf{W}_\ell)$, the unrolled weight matrix for layer $\ell$.

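Then, defining $\mathrm{D}\theta_\ell = \mathrm{vec}(\mathrm{D}\mathbf{W}_\ell)$ (with $\mathrm{vec}$ stacking columns) and $\otimes$ as the Kronecker product:

$$\mathrm{D}\theta_\ell = \bar{\mathbf{a}}_{\ell-1} \otimes \mathrm{D}\mathbf{s}_\ell$$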
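Here is a quick numerical check of both identities for a single linear layer, assuming a toy squared-error loss on $\mathbf{s}_\ell$ (any differentiable loss would do). The outer product $\mathrm{D}\mathbf{s}_\ell\bar{\mathbf{a}}_{\ell-1}^\top$ matches finite differences. Column-stacking it (NumPy's `order="F"`) then gives $\bar{\mathbf{a}}_{\ell-1}\otimes\mathrm{D}\mathbf{s}_\ell$.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out = 4, 3
a_bar = rng.normal(size=d_in)          # layer input a_bar_{l-1} (a 1 could be appended for the bias)
W = rng.normal(size=(d_out, d_in))     # W_l
target = rng.normal(size=d_out)

def loss(W):
    # toy loss on the pre-activation s_l = W_l a_bar_{l-1}: 0.5 * ||s_l - target||^2
    return 0.5 * np.sum((W @ a_bar - target) ** 2)

Ds = W @ a_bar - target                # Ds_l = dL/ds_l for this toy loss
DW = np.outer(Ds, a_bar)               # DW_l = Ds_l a_bar_{l-1}^T

# central finite differences of the loss, one weight at a time
h = 1e-6
DW_fd = np.zeros_like(W)
for i in range(d_out):
    for j in range(d_in):
        E = np.zeros_like(W)
        E[i, j] = h
        DW_fd[i, j] = (loss(W + E) - loss(W - E)) / (2 * h)

# vec stacks columns (Fortran order), giving D_theta_l = a_bar_{l-1} (kron) Ds_l
D_theta = DW.flatten(order="F")
```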

So far, so exact... But now, time for approximations. KFAC makes things simpler by assuming:

  • Gradients $\mathrm{D}\theta_\ell$ are uncorrelated between different layers
  • Activations $\bar{\mathbf{a}}_{\ell-1}$ are independent of pre-activation gradients $\mathrm{D}\mathbf{s}_\ell$

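This allows us to write down a simple block-diagonal approximation for $\mathbf{G}$:

$$\mathbf{G} \approx \mathrm{diag}(\mathbf{G}_1, \dots, \mathbf{G}_L), \qquad \mathbf{G}_\ell = \mathbb{E}[\mathrm{D}\theta_\ell \mathrm{D}\theta_\ell^\top] = \mathbb{E}[\bar{\mathbf{a}}_{\ell-1}\bar{\mathbf{a}}_{\ell-1}^\top \otimes \mathrm{D}\mathbf{s}_\ell \mathrm{D}\mathbf{s}_\ell^\top] \approx \mathbf{A}_{\ell-1} \otimes \mathbf{S}_\ell$$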

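Here $\mathbf{A}_{\ell-1} = \mathbb{E}[\bar{\mathbf{a}}_{\ell-1}\bar{\mathbf{a}}_{\ell-1}^\top]$ and $\mathbf{S}_\ell = \mathbb{E}[\mathrm{D}\mathbf{s}_\ell \mathrm{D}\mathbf{s}_\ell^\top]$ are uncentered covariance matrices for the layer's input activations and pre-nonlinearity gradients, respectively.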

This structure enables us to efficiently get the inverse (approximate) Gauss-Newton Hessian vector product:

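Let $\mathbf{V}_\ell$ denote the entries of $\mathbf{v}$ for layer $\ell$, reshaped to match $\mathbf{W}_\ell$, and let $\mathbf{v}_\ell = \mathrm{vec}(\mathbf{V}_\ell)$.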

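We can use the Kronecker product identities $(\mathbf{A} \otimes \mathbf{S})^{-1} = \mathbf{A}^{-1} \otimes \mathbf{S}^{-1}$ and $(\mathbf{B}^\top \otimes \mathbf{C})\,\mathrm{vec}(\mathbf{X}) = \mathrm{vec}(\mathbf{C}\mathbf{X}\mathbf{B})$, together with the symmetry of $\mathbf{A}_{\ell-1}$. With these, we can compute the inverse (approximate) Gauss-Newton Hessian vector product as:

$$\mathbf{G}_\ell^{-1}\mathbf{v}_\ell \approx (\mathbf{A}_{\ell-1} \otimes \mathbf{S}_\ell)^{-1} \mathbf{v}_\ell = \mathrm{vec}\big(\mathbf{S}_\ell^{-1} \mathbf{V}_\ell \mathbf{A}_{\ell-1}^{-1}\big)$$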
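The payoff of this identity is that only the small factors need inverting. The sketch below compares it against solving with the full Kronecker product. It uses random symmetric positive-definite stand-ins for $\mathbf{A}_{\ell-1}$ and $\mathbf{S}_\ell$, which is an assumption; in practice they are estimated from activations and pseudo-gradients.

```python
import numpy as np

def random_spd(n, rng):
    # random symmetric positive-definite matrix
    M = rng.normal(size=(n, n))
    return M @ M.T + n * np.eye(n)

rng = np.random.default_rng(0)
d_in, d_out = 4, 3
A = random_spd(d_in, rng)              # stand-in for A_{l-1} = E[a_bar a_bar^T]
S = random_spd(d_out, rng)             # stand-in for S_l = E[Ds Ds^T]
V = rng.normal(size=(d_out, d_in))     # V_l, shaped like W_l
v = V.flatten(order="F")               # v_l = vec(V_l), column-stacked

# Dense: solve with the (d_in*d_out) x (d_in*d_out) Kronecker product
dense = np.linalg.solve(np.kron(A, S), v)
# Factored: S^{-1} V A^{-1}, which only inverts a d_out x d_out and a d_in x d_in matrix
factored = np.linalg.solve(S, V) @ np.linalg.inv(A)
```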

Eigenvalue correction

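We made an approximation earlier when we went from $\mathbb{E}[\bar{\mathbf{a}}_{\ell-1}\bar{\mathbf{a}}_{\ell-1}^\top \otimes \mathrm{D}\mathbf{s}_\ell \mathrm{D}\mathbf{s}_\ell^\top]$ to $\mathbf{A}_{\ell-1} \otimes \mathbf{S}_\ell$.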

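Using the eigendecompositions of $\mathbf{A}_{\ell-1}$ and $\mathbf{S}_\ell$:

$$\mathbf{A}_{\ell-1} = \mathbf{Q}_A \boldsymbol{\Lambda}_A \mathbf{Q}_A^\top, \qquad \mathbf{S}_\ell = \mathbf{Q}_S \boldsymbol{\Lambda}_S \mathbf{Q}_S^\top$$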

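we can write a more accurate expression for $\mathbf{G}_\ell$:

$$\mathbf{G}_\ell \approx (\mathbf{Q}_A \otimes \mathbf{Q}_S)\, \boldsymbol{\Lambda}\, (\mathbf{Q}_A \otimes \mathbf{Q}_S)^\top$$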

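where the diagonal matrix $\boldsymbol{\Lambda}$ is defined as:

$$\boldsymbol{\Lambda}_{ii} = \mathbb{E}\Big[\big((\mathbf{Q}_A \otimes \mathbf{Q}_S)^\top \mathrm{D}\theta_\ell\big)_i^2\Big]$$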

which "captures the variances of the pseudo-gradient projected onto each eigenvector of the K-FAC approximation".

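We can get the damped inverse Gauss-Newton Hessian vector product approximation by adding $\lambda$ to the eigenvalues, obtaining:

$$(\mathbf{G}_\ell + \lambda\mathbf{I})^{-1}\mathbf{v}_\ell \approx (\mathbf{Q}_A \otimes \mathbf{Q}_S)(\boldsymbol{\Lambda} + \lambda\mathbf{I})^{-1}(\mathbf{Q}_A \otimes \mathbf{Q}_S)^\top \mathbf{v}_\ell = \mathrm{vec}\Big(\mathbf{Q}_S\big[(\mathbf{Q}_S^\top \mathbf{V}_\ell \mathbf{Q}_A) \oslash \big(\mathrm{unvec}(\mathrm{diag}(\boldsymbol{\Lambda})) + \lambda\big)\big]\mathbf{Q}_A^\top\Big)$$

where $\oslash$ denotes elementwise division.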
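Here is a minimal NumPy sketch of the eigenvalue-corrected step for one layer. It assumes synthetic activations and pseudo-gradients, made deliberately dependent so that the K-FAC independence assumption is violated. The sketch:

1. fits $\mathbf{A}_{\ell-1}$ and $\mathbf{S}_\ell$ and takes their eigenvectors,
2. estimates $\mathrm{diag}(\boldsymbol{\Lambda})$ from per-example gradients rotated into that basis,
3. applies the damped inverse without ever forming a $(d_\text{in}d_\text{out})^2$ matrix.

`ekfac_ihvp` is a hypothetical name.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_in, d_out, lam = 500, 4, 3, 1e-2
a_bars = rng.normal(size=(n, d_in))                              # sampled a_bar_{l-1}, one row per example
Ds = rng.normal(size=(n, d_out)) * (1 + a_bars[:, :1] ** 2)      # pseudo-gradients, dependent on a_bar on purpose

A = a_bars.T @ a_bars / n            # A_{l-1} = E[a_bar a_bar^T]
S = Ds.T @ Ds / n                    # S_l = E[Ds Ds^T]
_, Q_A = np.linalg.eigh(A)
_, Q_S = np.linalg.eigh(S)

# Per-example DW = Ds a_bar^T, rotated into the Kronecker eigenbasis:
# (Q_A kron Q_S)^T vec(DW) = vec(Q_S^T DW Q_A)
DW = Ds[:, :, None] * a_bars[:, None, :]      # shape (n, d_out, d_in)
rotated = Q_S.T @ DW @ Q_A                    # batched matmul, shape (n, d_out, d_in)
Lambda = (rotated ** 2).mean(axis=0)          # diag(Lambda), reshaped like W_l

def ekfac_ihvp(V):
    # (G_l + lam I)^{-1} vec(V) ~= vec(Q_S [(Q_S^T V Q_A) / (Lambda + lam)] Q_A^T)
    return Q_S @ ((Q_S.T @ V @ Q_A) / (Lambda + lam)) @ Q_A.T
```

If activations and pseudo-gradients really were independent, `Lambda` would approach the outer product of the two factors' eigenvalues, which is plain K-FAC. The correction matters exactly when they are not independent.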

Influence functions for autoregressive models

A few details change when we want to calculate  for a Transformer language model trained with an autoregressive loss function. 

In this case, the property of interest $f$ considered (the thing we are calculating the influence on) is the log-likelihood of a particular token string completion $z_c$, given a token string prompt $z_p$[4]:

$$f(\theta) = \log p(z_c \mid z_p; \theta)$$
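In code, $f$ is a masked sum of next-token log-probabilities. The sketch below assumes that `logits[t]` holds the model's prediction for token `t + 1` (the usual causal-LM convention) and that the prompt has at least one token. `completion_log_likelihood` is a hypothetical helper.

```python
import numpy as np

def completion_log_likelihood(logits, tokens, prompt_len):
    """log p(z_c | z_p) where tokens = z_p followed by z_c.

    logits: array of shape (len(tokens), vocab_size); row t scores the token at position t + 1.
    """
    # numerically stable log-softmax over the vocabulary
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # completion token at position t (t >= prompt_len) is predicted by row t - 1
    positions = np.arange(prompt_len, len(tokens))
    return log_probs[positions - 1, tokens[positions]].sum()
```

Only the completion tokens contribute, so the gradient of this scalar with respect to the MLP weights gives ingredient 1 for a language-model query.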

The paper only considers measuring the influence on a subset of the Transformer's weights (only the MLP layers), so the MLP $\mathbf{G}$ approximation derived above applies almost exactly.

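However, the parameter gradients are now summed over token indices $t$:

$$\mathrm{D}\mathbf{W}_\ell = \sum_{t=1}^{T} \mathrm{D}\mathbf{s}_{\ell,t}\, \bar{\mathbf{a}}_{\ell-1,t}^\top$$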

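Each diagonal block of $\mathbf{G}$ is given by $\mathbf{G}_\ell = \mathbb{E}[\mathrm{D}\theta_\ell \mathrm{D}\theta_\ell^\top]$. Because $\mathrm{D}\theta_\ell = \sum_t \bar{\mathbf{a}}_{\ell-1,t} \otimes \mathrm{D}\mathbf{s}_{\ell,t}$, this second moment now contains cross terms between different tokens. These inter-token correlations mean we cannot approximate it as accurately with $\mathbf{A}_{\ell-1} \otimes \mathbf{S}_\ell$ as before.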

The paper presents the following middle-ground between efficiency and accuracy:

We first fit the covariance factors $\mathbf{A}_{\ell-1}$ and $\mathbf{S}_\ell$ as if the tokens were fully independent, and compute their respective eigendecompositions. Then, when fitting the diagonal matrix $\boldsymbol{\Lambda}$, we use the exact pseudo-gradients $\mathrm{D}\theta_\ell$ which are summed over tokens. This way, at least the estimated diagonal entries of the moments in the Kronecker eigenbasis are unbiased.
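The recipe can be sketched for one MLP layer with synthetic per-token data. The data is an assumption for illustration: a shared per-sequence term makes the tokens within a sequence correlated. Only the eigenvectors of $\mathbf{A}_{\ell-1}$ and $\mathbf{S}_\ell$ are used, so their overall normalization does not affect the result.

```python
import numpy as np

rng = np.random.default_rng(0)
n_seq, T, d_in, d_out = 200, 5, 4, 3
a_bar = rng.normal(size=(n_seq, T, d_in))          # a_bar_{l-1,t}: MLP input at token t
shared = rng.normal(size=(n_seq, 1, d_out))        # per-sequence term that correlates tokens
Ds = shared + rng.normal(size=(n_seq, T, d_out))   # Ds_{l,t}: per-token pseudo-gradients

# Step 1: fit A and S as if tokens were independent (pool every token), then eigendecompose
tok_a = a_bar.reshape(-1, d_in)
tok_s = Ds.reshape(-1, d_out)
_, Q_A = np.linalg.eigh(tok_a.T @ tok_a / len(tok_a))
_, Q_S = np.linalg.eigh(tok_s.T @ tok_s / len(tok_s))

# Step 2: exact pseudo-gradients summed over tokens, DW_l = sum_t Ds_{l,t} a_bar_{l-1,t}^T
DW = np.einsum("nto,nti->noi", Ds, a_bar)          # shape (n_seq, d_out, d_in)

# Step 3: fit diag(Lambda) in the Kronecker eigenbasis from the token-summed gradients
Lambda = ((Q_S.T @ DW @ Q_A) ** 2).mean(axis=0)    # reshaped like W_l
```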

Implementing in PyTorch

As described above, the key ingredients for $\mathcal{I}_f(z_m)$ are:

  1. The gradient of the property $f$ of interest with respect to the parameters $\theta$, evaluated at $\theta^s$
  2. The inverse damped Gauss-Newton Hessian, which we can calculate from the expectations of the following quantities:
    1. $\bar{\mathbf{a}}_{\ell-1}$: the MLP layer inputs
    2. $\mathrm{D}\mathbf{s}_\ell$: the MLP pre-nonlinearity gradients (gradients of the loss wrt the output of the linear transformation $\mathbf{W}_\ell \bar{\mathbf{a}}_{\ell-1}$)
  3. The gradient of the loss $\mathcal{L}$ on the training data points (which we want to calculate the influence of) with respect to the parameters $\theta$, evaluated at $\theta^s$

We can get 1) and 3) by simply fetching parameter.grad[5] after performing a backward pass. For 1), this is a backward pass of $f$ on the query; for 3), it is a backward pass of the training loss on some input-target pair.

We can get 2a) using a forward hook that saves the input to a layer during the forward pass. We can get 2b) using a backward hook on the linear layer that saves the gradient wrt the linear layer's output. 
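Here is a minimal PyTorch sketch of the two hooks on a single `nn.Linear`, using the standard `register_forward_hook` and `register_full_backward_hook` APIs. The toy loss and sizes are assumptions, and this is not the linked repo's code. It builds $\sum \mathrm{D}\mathbf{s}\,\bar{\mathbf{a}}^\top$ from the hooked tensors, with a 1 appended to each input for the bias as in footnote 2. It then checks the result against autograd's concatenated `.weight` and `.bias` gradients.

```python
import torch
from torch import nn

torch.manual_seed(0)
layer = nn.Linear(4, 3)
saved = {}

def save_input(module, inputs, output):
    # forward hook: layer input a_bar_{l-1}, shape (batch, in_features)
    saved["a"] = inputs[0].detach()

def save_output_grad(module, grad_input, grad_output):
    # full backward hook: dL/ds_l, the gradient w.r.t. the linear layer's output, shape (batch, out_features)
    saved["Ds"] = grad_output[0].detach()

h_fwd = layer.register_forward_hook(save_input)
h_bwd = layer.register_full_backward_hook(save_output_grad)

x = torch.randn(8, 4, requires_grad=True)   # mimics an inner layer, whose input requires grad
loss = layer(x).pow(2).sum()                # toy loss on s_l
loss.backward()

# Homogeneous form: append a 1 to each input so the bias becomes an extra weight column
a_hom = torch.cat([saved["a"], torch.ones(len(x), 1)], dim=1)
dW_from_hooks = saved["Ds"].T @ a_hom       # sum over the batch of Ds a_bar^T
dW_autograd = torch.cat([layer.weight.grad, layer.bias.grad[:, None]], dim=1)

h_fwd.remove()
h_bwd.remove()
```

The same per-example `a` and `Ds` tensors are what would be accumulated into $\mathbf{A}_{\ell-1}$ and $\mathbf{S}_\ell$ when fitting the curvature factors.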

You can find my implementation attempt on GitHub here[6]. It includes code applying influence function analysis to a vanilla MLP trained on MNIST and a 2-layer transformer trained on a basic next-character prediction task.

Results of small experiment on MNIST

I trained an MLP on MNIST (with flattened images) and then used the influence function approximation code to extract influential training examples for particular predicted test set labels.

I found that influential training digits were usually more sloppy / unclear compared to the average MNIST digit, and shared some resemblance with the query image. Not all top influential images shared the same label as the query. I only searched a subset of the training corpus, for efficiency. 

Here are some examples, filtered to cases where the influence was non-negligible (some queries returned ~0 for all sampled training datapoints). In each, the first image on the left is the query, followed by the most influential training images given by the approximation:

  1. The warm-start problem referenced by Bae et al. refers to the fact that, for an objective that is not strictly convex, the influence of a training example in the neighborhood of a given minimum may differ from its influence when starting from a different initialization point.

  2. The paper uses homogeneous vector notation to account for biases / affine transformations: you can assume there is a 1 appended to the activations $\bar{\mathbf{a}}_{\ell-1}$ and a bias vector appended to $\mathbf{W}_\ell$ to cover this case.

  3. The paper refers to these as "pseudo-gradients" since they are computed with targets sampled from the model's final output distribution, and so are distinct from the gradients used during training.

  4. The $(z_p, z_c)$ pair is referred to as the "query" in the paper, as we are "querying" which training examples were most influential for the model producing $z_c$ given $z_p$.

  5. Specifically, concatenate a linear layer's .weight and .bias grads.

  6. If you look through the code and find any bugs (quite possible) or performance improvements (definitely findable; e.g. more batching + splitting of GPU ops, WIP), I'd be super happy to merge PRs and/or hear from you! I hope to gradually improve this codebase and run larger experiments.

Comments
Troof:

Thanks for this! One thing I don't understand about influence functions is: why should I care about the proximal Bregman objective? To interpret a model, I'm really interested in the LOO retraining, right? Can we still say things like "it seems that the model relied on this training sample for producing this output" with the PBO interpretation?

I agree that approximating the PBO makes this method more lossy (not all interesting generalization phenomena can be found). However, I think we can still glean useful information about generalization by considering "retraining" from a point closer to the final model than random initialization. The downside is if, for example, some data was instrumental in causing a phase transition at some point in training, this will not be captured by the PBO approximation. 

Indeed, the paper concedes:

Influence functions are approximating the sensitivity to the training set locally around the final weights and might not capture nonlinear training phenomena 

Purely empirically, I think Anthropic's results indicate there are useful things that can be learnt, even via this local approximation:

One of the most consistent patterns we have observed is that the influential sequences reflect increasingly sophisticated patterns of generalization as the model scale increases. While the influential sequences for smaller models tend to have short overlapping sequences of tokens, the top sequences for larger models are related at a more abstract thematic level, and the influence patterns show increasing robustness to stylistic changes, including the language.

My intuition here is that even if we are not exactly measuring the counterfactual "what if this datum was not included in the training corpus?", we could be estimating "what type of useful information is the model extracting from training data that looks like this?". 

I found that influential training digits were usually more sloppy / unclear compared to the average MNIST digit, and shared some resemblance with the query image.

It's pbzcnevat gb gur arnerfg cbvagf ba gur obhaqnel bs qvtvg-pyhfgref! Bonus points if you made your observation without that interpretation in mind. What if you do Jacobian regularization?

How do you know?

It's the same training datums I would look at to resolve an ambiguous case.

Thank you for this. How would you think about the pros/cons of influence functions vs activation patching or direct logit attribution in terms of localizing a behavior in the model?