A short summary of the paper is presented below.
This work was produced by Apollo Research in collaboration with Jordan Taylor (MATS + University of Queensland).
TL;DR: We propose end-to-end (e2e) sparse dictionary learning, a method for training SAEs that ensures the features learned are functionally important by minimizing the KL divergence between the output distributions of the original model and the model with SAE activations inserted. Compared to standard SAEs, e2e SAEs offer a Pareto improvement: They explain more network performance, require fewer total features, and require fewer simultaneously active features per datapoint, all with no cost to interpretability. We explore geometric and qualitative differences between e2e SAE features and standard SAE features.
Introduction
Current SAEs focus on the wrong goal: They are trained to minimize mean squared reconstruction error (MSE) of activations (in addition to minimizing their sparsity penalty). The issue is that the importance of a feature as measured by its effect on MSE may not strongly correlate with how important the feature is for explaining the network's performance.
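To make the standard objective concrete, here is a minimal sketch of a reconstruction MSE plus a sparsity penalty for a tiny SAE. The function names, the ReLU encoder, the L1 form of the penalty, and the `sparsity_coeff` parameter are illustrative assumptions, not the paper's implementation.

```python
def relu(v):
    return [max(0.0, x) for x in v]


def matvec(m, v):
    # Multiply a matrix (list of rows) by a vector.
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def sae_forward(x, w_enc, b_enc, w_dec, b_dec):
    """Return (feature activations, reconstruction) for one activation vector x.

    Assumed shapes: w_enc is n_features x d, w_dec is d x n_features.
    """
    feats = relu([a + b for a, b in zip(matvec(w_enc, x), b_enc)])
    recon = [a + b for a, b in zip(matvec(w_dec, feats), b_dec)]
    return feats, recon


def local_loss(x, feats, recon, sparsity_coeff):
    """Reconstruction MSE plus an (assumed) L1 sparsity penalty."""
    mse = sum((r - a) ** 2 for r, a in zip(recon, x)) / len(x)
    l1 = sum(abs(f) for f in feats)
    return mse + sparsity_coeff * l1


# Toy example: 2-dimensional activations, 3 dictionary elements.
w_enc = [[1, 0], [0, 1], [-1, 0]]
w_dec = [[1, 0, -1], [0, 1, 0]]
feats, recon = sae_forward([2.0, -1.0], w_enc, [0, 0, 0], w_dec, [0, 0])
loss = local_loss([2.0, -1.0], feats, recon, sparsity_coeff=0.1)
```

Note that nothing in this objective refers to the network's output, which is the gap the rest of the post addresses.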
This would not be a problem if the network's activations used a small, finite set of ground truth features -- the SAE would simply identify those features, and thus optimizing MSE would have led the SAE to learn the functionally important features. In practice, however, Bricken et al. observed the phenomenon of feature splitting, where increasing dictionary size while increasing sparsity allows SAEs to split a feature into multiple, more specific features, representing smaller and smaller portions of the dataset. In the limit of large dictionary size, it would be possible to represent each individual datapoint as its own dictionary element.
Since minimizing MSE does not explicitly prioritize learning features based on how important they are for explaining the network's performance, an SAE may waste much of its fixed capacity on learning less important features. This is perhaps responsible for the observation that, when measuring the causal effects of some features on network performance, a significant amount is mediated by the reconstruction residual errors (i.e. everything not explained by the SAE) and not mediated by SAE features (Marks et al.).
Given these issues, it is natural to ask how we can identify the functionally important features used by the network. We say a feature is functionally important if it is important for explaining the network's behavior on the training distribution. If we prioritize learning functionally important features, we should be able to maintain strong performance with fewer features used by the SAE per datapoint as well as fewer overall features.
To optimize SAEs for these properties, we introduce a new training method. We still train SAEs using a sparsity penalty on the feature activations (to reduce the number of features used on each datapoint), but we no longer optimize activation reconstruction. Instead, we replace the original activations with the SAE output and optimize the KL divergence between the original output logits and the output logits when passing the SAE output through the rest of the network, thus training the SAE end-to-end (e2e).
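A minimal sketch of the end-to-end objective for a single token position is shown below. The direction of the KL divergence (original distribution first), the L1 form of the sparsity penalty, and the names `kl_divergence`, `e2e_loss`, and `sparsity_coeff` are assumptions made for illustration; the library linked under Extras is the reference implementation.

```python
import math


def log_softmax(logits):
    # Numerically stable log-softmax.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    return [l - log_z for l in logits]


def kl_divergence(orig_logits, sae_logits):
    """KL(p_orig || p_sae) computed from two logit vectors (assumed direction)."""
    log_p = log_softmax(orig_logits)
    log_q = log_softmax(sae_logits)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def e2e_loss(orig_logits, sae_logits, feats, sparsity_coeff):
    """KL term plus an (assumed) L1 sparsity penalty on feature activations."""
    l1 = sum(abs(f) for f in feats)
    return kl_divergence(orig_logits, sae_logits) + sparsity_coeff * l1
```

Here `sae_logits` stands for the logits obtained after replacing the original activations with the SAE output and running the rest of the network. No activation reconstruction term appears in this loss.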
One risk with this method is that it may be possible for the outputs of SAE_e2e to take a different computational pathway through subsequent layers of the network (compared with the original activations) while nevertheless producing a similar output distribution. For example, it might learn a new feature that exploits a particular transformation in a downstream layer that is unused by the regular network or that is used for other purposes. To reduce this likelihood, we also add terms to the loss for the reconstruction error between the original model and the model with the SAE at downstream layers in the network.
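The extra downstream terms can be sketched as follows. This assumes the KL term has already been computed, that the per-layer reconstruction errors are plain MSEs summed over layers, and that a single hypothetical `downstream_coeff` weights them; the actual weighting used in the paper may differ.

```python
def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def e2e_downstream_loss(kl, feats, orig_acts, sae_acts,
                        sparsity_coeff, downstream_coeff):
    """Combine KL, sparsity, and downstream reconstruction terms.

    kl: precomputed KL divergence between output distributions.
    orig_acts / sae_acts: one activation vector per downstream layer, from the
    original model and from the model with the SAE inserted.
    """
    downstream = sum(mse(o, s) for o, s in zip(orig_acts, sae_acts))
    l1 = sum(abs(f) for f in feats)
    return kl + sparsity_coeff * l1 + downstream_coeff * downstream
```

Setting `downstream_coeff` to zero recovers the pure end-to-end objective, so the coefficient trades off output fidelity against staying close to the original computational pathway.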
It's reasonable to ask whether our approach runs afoul of Goodhart's law ("When a measure becomes a target, it ceases to be a good measure"). We contend that mechanistic interpretability should prefer explanations of networks (and the components of those explanations, such as features) that explain more network performance over other explanations. Therefore, optimizing directly for quantitative proxies of performance explained (such as CE loss difference, KL divergence, and downstream reconstruction error) is preferred.
Key Results
We train three SAE types on language models (GPT2-small and Tinystories-1M): SAE_local (the standard SAE, trained on activation reconstruction), SAE_e2e (trained on KL divergence), and SAE_e2e+downstream, abbreviated SAE_e2e+ds (trained on KL divergence plus downstream reconstruction error). We present three key findings (Figure 1):
- For the same level of performance explained, SAE_local requires activating more than twice as many features per datapoint compared to SAE_e2e+ds and SAE_e2e.
- SAE_e2e+ds performs as well as SAE_e2e in terms of the number of features activated per datapoint, yet its activations take pathways through the network that are much more similar to those of SAE_local.
- SAE_local requires more features in total over the dataset to explain the same amount of network performance compared with SAE_e2e and SAE_e2e+ds.
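The two efficiency quantities in these findings, features active per datapoint and features used in total, can be computed roughly as below. Treating any nonzero activation as "active" and any feature that fires at least once as "alive" are assumptions for this sketch; the paper's exact thresholds may differ.

```python
def mean_l0(feature_acts):
    """Average number of nonzero features per datapoint."""
    counts = [sum(1 for f in row if f != 0) for row in feature_acts]
    return sum(counts) / len(counts)


def alive_features(feature_acts):
    """Indices of dictionary elements that fire on at least one datapoint."""
    n_features = len(feature_acts[0])
    return [j for j in range(n_features)
            if any(row[j] != 0 for row in feature_acts)]


# Toy example: 3 datapoints, 4 dictionary elements.
acts = [
    [0.5, 0.0, 0.0, 1.2],
    [0.0, 0.0, 0.0, 0.3],
    [0.1, 0.0, 2.0, 0.0],
]
l0 = mean_l0(acts)            # (2 + 1 + 2) / 3
alive = alive_features(acts)  # feature 1 never fires
```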
Moreover, our automated interpretability and qualitative analyses reveal that SAE_e2e+ds features are at least as interpretable as SAE_local features, demonstrating that the improvements in efficiency do not come at the cost of interpretability. These gains nevertheless come at the cost of longer wall-clock time to train (see article for further details).
When comparing the reconstruction errors at each downstream layer after the SAE is inserted (Figure 2 below), we find that, even though SAE_e2es explain more performance per feature than SAE_locals, they have much worse reconstruction error of the original activations at each subsequent layer. This indicates that the activations following the insertion of SAE_e2e take a different path through the network than in the original model, and therefore potentially permit the model to achieve its performance using different computations from the original model. This possibility motivated the training of SAE_e2e+ds, whose reconstruction errors are extremely similar to those of SAE_local. SAE_e2e+ds therefore combines two desirable properties: it learns features that explain approximately as much network performance as SAE_e2e (Figure 1), and its reconstruction errors are much closer to those of SAE_local.
We measure the cosine similarity between each SAE dictionary feature and its nearest neighbor in the same dictionary. While this does not account for potential semantic differences between directions with high cosine similarities, it serves as a useful proxy for feature splitting, since split features tend to be highly similar directions. We find that SAE_local has features that are more tightly clustered, suggesting higher feature splitting (Figure 3 below). The mean cosine similarity of SAE_local is 0.04 higher than that of SAE_e2e+ds (bootstrapped 95% CI [0.037, 0.043]) and 0.166 higher than that of SAE_e2e (95% CI [0.163, 0.168]). We measure this for all runs in our Pareto frontiers in Appendix A.7 (Figure 7), and find that this difference is not explained by SAE_local having more alive dictionary elements than e2e SAEs.
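The within-dictionary measurement can be sketched as follows: for each dictionary direction, take the maximum cosine similarity to any other direction in the same dictionary. The function names are illustrative, and the quadratic loop is only practical for small toy dictionaries.

```python
import math


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def nearest_neighbor_sims(dictionary):
    """For each direction, the max cosine similarity to any *other* direction."""
    return [
        max(cosine(u, v) for j, v in enumerate(dictionary) if j != i)
        for i, u in enumerate(dictionary)
    ]


# Toy dictionary: two orthogonal directions plus their diagonal.
sims = nearest_neighbor_sims([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
mean_sim = sum(sims) / len(sims)
```

A dictionary with heavy feature splitting would contain many near-duplicate directions and therefore a higher mean under this measure.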
In the paper, we also explore some qualitative differences between SAE_local and SAE_e2e+ds.
Acknowledgements
Johnny Lin and Joseph Bloom for supporting our SAEs on https://www.neuronpedia.org/gpt2sm-apollojt and Johnny Lin for providing tooling for automated interpretability, which made the qualitative analysis much easier. Lucius Bushnaq, Stefan Heimersheim and Jake Mendel for helpful discussions throughout. Jake Mendel for many of the ideas related to the geometric analysis. Tom McGrath, Bilal Chughtai, Stefan Heimersheim, Lucius Bushnaq, and Marius Hobbhahn for comments on earlier drafts. Center for AI Safety for providing much of the compute used in the experiments.
Extras
- Library for training e2e (and vanilla) SAEs and reproducing our analysis (https://github.com/ApolloResearch/e2e_sae). All SAEs in the article can be loaded using this library, and we also provide raw SAE weights for many of our runs at https://huggingface.co/apollo-research/e2e-saes-gpt2.
- Weights and Biases report that links to training metrics for all runs https://api.wandb.ai/links/sparsify/evnqx8t6
- Neuronpedia page (h/t @Johnny Lin @Joseph Bloom) for interactively exploring many of the SAEs presented in the article (https://www.neuronpedia.org/gpt2sm-apollojt)
The e2e having different feature directions across seeds was quite the bummer, but then I thought "are the encoder directions different though?"
Intuitively, the encoder directions affect which datapoints each feature activates on, and the decoder is the causal downstream effect. For e2e, we would expect widely different decoder directions because there are many free parameters (from some other work that showed SVD of gradients had many zero singular values, meaning moving in most directions doesn't affect the downstream loss), but not necessarily encoder directions.
If the encoder directions are similar across seeds, I'd trust them to inform relevant features for the model output (in cases where we don't care about connections w/ downstream layers).
However, I was not able to find the SAEs for various seeds.
Trying to replicate Cos-sim Plots
I downloaded the SAEs with similar CE at layer 6 for all three SAE types & took their cos-sim (last column in figure 3).
I think your cos-sim metric gives different results if you take the max over the first or 2nd dimension (or equivalently swap the order of the decoders multiplied by each other). If so, I think this is because you might double-count or something? Regardless, I ended up doing some Hungarian algorithm to take the overall max (but don't double-count), but it's on cpu, so I only did the first 10k/40k features. Below are results for both encoder & decoder, which do replicate the directional results.
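A toy illustration of that asymmetry: given a cross-dictionary similarity matrix, the mean of row-wise maxima and the mean of column-wise maxima can differ, because several rows may pick the same column. The one-to-one matching below uses brute force over permutations as a stand-in for the Hungarian algorithm, so it only works for tiny square matrices; the matrix values are made up for the example.

```python
from itertools import permutations


def row_max_mean(sim):
    # Each row picks its best column; columns may be reused.
    return sum(max(row) for row in sim) / len(sim)


def col_max_mean(sim):
    # Each column picks its best row; rows may be reused.
    cols = list(zip(*sim))
    return sum(max(c) for c in cols) / len(cols)


def best_one_to_one_mean(sim):
    """Best total similarity when each row is matched to a distinct column.

    Brute force over permutations, feasible only for very small matrices.
    """
    n = len(sim)
    best = max(sum(sim[i][p[i]] for i in range(n))
               for p in permutations(range(n)))
    return best / n


# Made-up similarity matrix: every row is most similar to column 0.
sim = [
    [0.9, 0.1, 0.0],
    [0.8, 0.2, 0.1],
    [0.7, 0.3, 0.6],
]
```

For real dictionary sizes an assignment solver such as SciPy's `linear_sum_assignment` would replace the permutation search.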
Nonzero Features
Additionally, I thought that some results came from counting nonzero features, which for the encoder are some of the high-cos-sim features and for the decoder are the low-cos-sim features, weirdly enough.
Would appreciate if y'all upload any repeated seeds!
My code is temporarily hosted (for a few weeks maybe?) here.
I finally checked!
Here is the Jaccard similarity (i.e. similarity of input-token activations) across seeds:
The e2e ones do indeed have a much lower Jaccard sim (there normally is a spike at 1.0, but this is removed when you remove features that activate fewer than 10 times).
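For reference, a minimal sketch of this kind of comparison: represent each feature by the set of token positions it activates on, drop features with too few activations, and compute the Jaccard similarity between sets from two seeds. The data layout (a dict from feature id to a list of token indices) and the helper names are assumptions for illustration, not the code used for the plots.

```python
def jaccard(a, b):
    """Jaccard similarity |A & B| / |A | B| between two collections."""
    a, b = set(a), set(b)
    union = a | b
    return len(a & b) / len(union) if union else 0.0


def filter_dead(feature_to_tokens, min_activations=10):
    """Keep only features that activate on at least `min_activations` tokens."""
    return {k: v for k, v in feature_to_tokens.items()
            if len(v) >= min_activations}


# Toy example: token indices on which one feature fires in two seeds.
seed_a_tokens = [1, 2, 3]
seed_b_tokens = [2, 3, 4]
sim = jaccard(seed_a_tokens, seed_b_tokens)  # 2 shared out of 4 total
```

A feature that fires on only one or two tokens can trivially reach a similarity of 1.0, which is consistent with the spike disappearing once rarely active features are filtered out.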
I also (mostly) replicated the decoder similarity chart:
And calculated the encoder sim:
[I, again, needed to remove dead features (< 10 activations) to get the graphs here.]
So yes, I believe the original paper's claim that e2e features learn quite different features across seed...