Edited to fix errors pointed out by @JoshEngels and @Adam Karvonen (mainly: different definition for explained variance, details here).
Summary: K-means explains 72 - 87% of the variance in the activations, comparable to vanilla SAEs but less than better SAEs. I think this (bug-fixed) result is neither evidence in favour of SAEs nor against; the Clustering & SAE numbers make a straight-ish line on a log plot.
Epistemic status: This is a weekend experiment I ran a while ago, and I figured I should write it up to share. I have taken decent care to check my code for silly mistakes and "shooting myself in the foot", but these results are not vetted to the standard of a top-level post / paper.
SAEs explain most of the variance in activations. Is this alone a sign that activations are structured in an SAE-friendly way, i.e. that activations are indeed a composition of sparse features like the superposition hypothesis suggests?
I'm asking myself this question because I initially considered this pretty solid evidence: SAEs do a pretty impressive job compressing 512 dimensions into ~100 latents, and this ought to mean something, right?
But maybe all SAEs are doing is "dataset clustering" (the data is cluster-y and SAEs exploit this)---then a different sensible clustering method should also be able to perform similarly well!
I took this[1] SAE graph from Neuronpedia, and added a K-means clustering baseline. Think of this as pretty equivalent to a top-k SAE (with k=1; in fact I added a point where I use the K-means centroids as features of a top-1 SAE which does slightly better than vanilla K-means with binary latents).
K-means clustering (which uses a single latent, L0=1) explains 72-87% of the variance. This is a good number to keep in mind when comparing to SAEs, though it is significantly lower than what SAEs achieve (often 90%+). To get a comparison that uses more latents I add a PCA + Clustering baseline, applying a PCA before doing the clustering. It does roughly as well as vanilla SAEs. The upcoming SAEBench paper also includes a PCA baseline, so I won't discuss PCA in detail here.
Here's the result for layers 3 and 4, and 4k and 16k latents. (These were the 4 SAEBench suites available on Neuronpedia.) There are two points each for the clustering results, corresponding to 100k and 1M training samples. Code here.
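For concreteness, here is a minimal sketch of the kind of baseline I mean (this is not the exact code linked above; it assumes activations are already collected into an array `acts` and uses scikit-learn's MiniBatchKMeans rather than the implementation I actually used):

```python
# Minimal sketch of the K-means baseline (not the exact code linked above).
# Assumes `acts` is an [n_samples, d_model] numpy array of residual-stream activations.
import numpy as np
from sklearn.cluster import MiniBatchKMeans

def kmeans_variance_explained(acts: np.ndarray, n_clusters: int = 4096) -> float:
    km = MiniBatchKMeans(n_clusters=n_clusters, batch_size=4096)
    labels = km.fit_predict(acts)
    recon = km.cluster_centers_[labels]                  # each activation replaced by its centroid (L0 = 1)
    resid = ((acts - recon) ** 2).sum()                  # residual sum of squares
    total = ((acts - acts.mean(axis=0)) ** 2).sum()      # total sum of squares (mean-subtracted)
    return 1.0 - resid / total                           # variance explained = 1 - FVU
```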
What about interpretability? Clusters seem "monosemantic" on a skim. In an informal investigation I looked at max-activating dataset examples, and they seem to correspond to related contexts / words like monosemantic SAE features tend to do. I haven't spent much time looking into this though.
Both my code and SAEBench/Neuronpedia use OpenWebText with 128 tokens context length. After the edit I've made sure to use the same Variance Explained definition for all points.
A final caveat I want to mention is that I think the SAEs I'm comparing here (the SAEBench suite for Pythia-70M) are maybe weak. They're only using 4k and 16k latents for 512 embedding dimensions, i.e. expansion ratios of 8 and 32, respectively (the best SAEs I could find for a ~100M model). But I also limit the number of clusters to the same numbers, so I don't necessarily expect the balance to change qualitatively at higher expansion ratios.
I want to thank @Adam Karvonen, @Lucius Bushnaq, @jake_mendel, and @Patrick Leask for feedback on early results, and @Johnny Lin for implementing an export feature on Neuronpedia for me! I also learned that @scasper proposed something similar here (though I didn't know about it), I'm excited for follow-ups implementing some of Stephen's advanced ideas (HAC, a probabilistic alg, ...).
I'm using the conventional definition of variance explained, rather than the one used by Neuronpedia, thus the numbers are slightly different. I'll include the alternative graph in a comment.
I was having trouble reproducing your results on Pythia, and was only able to get 60% variance explained. I may have tracked it down: I think you may be computing FVU incorrectly.
https://gist.github.com/Stefan-Heimersheim/ff1d3b92add92a29602b411b9cd76cec#file-clustering_pythia-py-L309
I think FVU is correctly computed by subtracting the mean from each dimension when computing the denominator. See the SAEBench impl here:
https://github.com/adamkarvonen/SAEBench/blob/5204b4822c66a838d9c9221640308e7c23eda00a/sae_bench/evals/core/main.py#L566
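Concretely (a rough sketch, not the SAEBench code itself), the difference is whether the denominator subtracts the per-dimension mean:

```python
# Rough sketch of the two FVU denominators (not the SAEBench code itself).
# x, x_hat: [n_samples, d_model] tensors of activations and reconstructions.
import torch

def fvu_no_mean(x, x_hat):
    return ((x - x_hat) ** 2).sum() / (x ** 2).sum()            # denominator: raw sum of squares

def fvu_mean_subtracted(x, x_hat):
    mu = x.mean(dim=0, keepdim=True)                            # mean of each dimension
    return ((x - x_hat) ** 2).sum() / ((x - mu) ** 2).sum()     # denominator: variance of the data
```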
When I used your FVU implementation, I got 72% variance explained; this is still less than your numbers, but much closer, so I think this might be what's causing the apparent improvement over the SAEBench numbers.
In general I think SAEs with low k should be at least as good as k means clustering, and if it's not I'm a little bit suspicious (when I tried this first on GPT-2 it seemed that a TopK SAE trained with k = 4 did about as well as k means clustering with the nonlinear argmax encoder).
Here's my clustering code: https://github.com/JoshEngels/CheckClustering/blob/main/clustering.py
You're right. I forgot to subtract the mean. Thanks a lot!!
I'm computing new numbers now, but indeed I expect this to explain my result! (Edit: Seems to not change too much)
After adding the mean subtraction, the numbers haven't changed too much actually -- but let me make sure I'm using the correct calculation. I'm gonna follow your and @Adam Karvonen's suggestion of using the SAEBench code and loading my clustering solution as an SAE (this code).
These logs show numbers with the original / corrected explained variance computation; the difference is in the 3-8% range.
v3 (KMeans): Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=4096, variance explained = 0.8887 / 0.8568
v3 (KMeans): Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=16384, variance explained = 0.9020 / 0.8740
v3 (KMeans): Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=4096, variance explained = 0.8044 / 0.7197
v3 (KMeans): Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=16384, variance explained = 0.8261 / 0.7509
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=4095, n_pca=1, variance explained = 0.8910 / 0.8599
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=16383, n_pca=1, variance explained = 0.9041 / 0.8766
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=4094, n_pca=2, variance explained = 0.8948 / 0.8647
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=16382, n_pca=2, variance explained = 0.9076 / 0.8812
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=4091, n_pca=5, variance explained = 0.9044 / 0.8770
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=16379, n_pca=5, variance explained = 0.9159 / 0.8919
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=4086, n_pca=10, variance explained = 0.9121 / 0.8870
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=16374, n_pca=10, variance explained = 0.9232 / 0.9012
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=4076, n_pca=20, variance explained = 0.9209 / 0.8983
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=16364, n_pca=20, variance explained = 0.9314 / 0.9118
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=4046, n_pca=50, variance explained = 0.9379 / 0.9202
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=16334, n_pca=50, variance explained = 0.9468 / 0.9315
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=3996, n_pca=100, variance explained = 0.9539 / 0.9407
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=16284, n_pca=100, variance explained = 0.9611 / 0.9499
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=3896, n_pca=200, variance explained = 0.9721 / 0.9641
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=16184, n_pca=200, variance explained = 0.9768 / 0.9702
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=3596, n_pca=500, variance explained = 0.9999 / 0.9998
PCA+Clustering: Layer blocks.3.hook_resid_post, n_tokens=1000000, n_clusters=15884, n_pca=500, variance explained = 0.9999 / 0.9999
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=4095, n_pca=1, variance explained = 0.8077 / 0.7245
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=16383, n_pca=1, variance explained = 0.8292 / 0.7554
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=4094, n_pca=2, variance explained = 0.8145 / 0.7342
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=16382, n_pca=2, variance explained = 0.8350 / 0.7636
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=4091, n_pca=5, variance explained = 0.8244 / 0.7484
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=16379, n_pca=5, variance explained = 0.8441 / 0.7767
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=4086, n_pca=10, variance explained = 0.8326 / 0.7602
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=16374, n_pca=10, variance explained = 0.8516 / 0.7875
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=4076, n_pca=20, variance explained = 0.8460 / 0.7794
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=16364, n_pca=20, variance explained = 0.8637 / 0.8048
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=4046, n_pca=50, variance explained = 0.8735 / 0.8188
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=16334, n_pca=50, variance explained = 0.8884 / 0.8401
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=3996, n_pca=100, variance explained = 0.9021 / 0.8598
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=16284, n_pca=100, variance explained = 0.9138 / 0.8765
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=3896, n_pca=200, variance explained = 0.9399 / 0.9139
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=16184, n_pca=200, variance explained = 0.9473 / 0.9246
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=3596, n_pca=500, variance explained = 0.9997 / 0.9996
PCA+Clustering: Layer blocks.4.hook_resid_post, n_tokens=1000000, n_clusters=15884, n_pca=500, variance explained = 0.9998 / 0.9997
this seems concerning.
I feel like my post appears overly dramatic; I'm not very surprised and don't consider this the strongest evidence against SAEs. It's an experiment I ran a while ago and it hasn't changed my (somewhat SAE-sceptic) stance much.
But this is me having seen a bunch of other weird SAE behaviours (pre-activation distributions are not the way you'd expect from the superposition hypothesis h/t @jake_mendel, if you feed SAE-reconstructed activations back into the encoder the SAE goes nuts, stuff mentioned in recent Apollo papers, ...).
Reasons this could be less concerning than it looks
I should really run a random Gaussian data baseline for this.
Tentatively I get similar results (70-85% variance explained) for random data -- I haven't checked that code at all though, don't trust this. Will double check this tomorrow.
(In that case SAE's performance would also be unsurprising I suppose)
I'm not sure what you mean by "K-means clustering baseline (with K=1)". I would think the K in K-means stands for the number of means you use, so with K=1, you're just taking the mean direction of the weights. I would expect this to explain maybe 50% of the variance (or less), not 90% of the variance.
But anyway, under my current model (roughly Why I'm bearish on mechanistic interpretability: the shards are not in the network + Binary encoding as a simple explicit construction for superposition) it seems about as natural to use K-means as it does to use SAEs, and not necessarily an issue if K-means outperforms SAEs. If we imagine that the meaning is given not by the dimensions of the space but rather by regions/points/volumes of the space, then K-means seems like a perfectly cromulent quantization for identifying these volumes. The major issue is where we go from here.
If we imagine that the meaning is given not by the dimensions of the space but rather by regions/points/volumes of the space
I think this is what I care about finding out. If you're right, this is indeed neither surprising nor an issue, but you being right would be a major departure from the current mainstream interpretability paradigm(?).
The question of regions vs compositionality is what I've been investigating with my mentees recently, and pretty keen on. I'll want to write up my current thoughts on this topic sometime soon.
I'm not sure what you mean by "K-means clustering baseline (with K=1)". I would think the K in K-means stands for the number of means you use, so with K=1, you're just taking the mean direction of the weights. I would expect this to explain maybe 50% of the variance (or less), not 90% of the variance.
Thanks for pointing this out! I confused nomenclature, will fix!
Edit: Fixed now. I had confused the K in K-means (the number of clusters) with L0=1 (each activation being assigned to a single centroid).
I think he messed up the lingo a bit, but looking at the code he seems to have done k-means with a number of clusters similar to the number of SAE latents, which seems fine.
I'm going to update the results in the top-level comment with the corrected data; I'm pasting the original figures here for posterity / understanding the past discussion. Summary of changes:
Old (no longer true) text:
It turns out that even clustering (essentially L_0=1) explains up to 90% of the variance in activations, being matched only by SAEs with L_0>100. This isn't an entirely fair comparison, since SAEs are optimised for the large-L_0 regime, while I haven't found a L_0>1 operationalisation of clustering that meaningfully improves over L_0=1. To have some comparison I'm adding a PCA + Clustering baseline where I apply a PCA before doing the clustering. It does roughly as well as expected, exceeding the SAE reconstruction for most L0 values. The SAEBench upcoming paper also does a PCA baseline so I won't discuss PCA in detail here.
[...]Here's the code used to get the clustering & PCA below; the SAE numbers are taken straight from Neuronpedia. Both my code and SAEBench/Neuronpedia use OpenWebText with 128 tokens context length so I hope the numbers are comparable, but there's a risk I missed something and we're comparing apples to oranges.
I think the relation between K-means and sparse dictionary learning (essentially K-means is equivalent to an L_0=1 constraint) is already well-known in the sparse coding literature? For example see this wiki article on K-SVD (a sparse dictionary learning algorithm) which first reviews this connection before getting into the nuances of K-SVD.
Were the SAEs for this comparison trained on multiple passes through the data, or just one pass/epoch? Because if for K-means you did multiple passes through the data but for SAEs just one then this feels like an unfair comparison.
What do you mean you’re encoding/decoding like normal but using the k means vectors? Shouldn’t the SAE training process for a top k SAE with k = 1 find these vectors then?
In general I’m a bit skeptical that clustering will work as well on larger models, my impression is that most small models have pretty token level features which might be pretty clusterable with k=1, but for larger models many activations may belong to multiple “clusters”, which you need dictionary learning for.
What do you mean you’re encoding/decoding like normal but using the k means vectors?
So I do something like
latents_tmp = torch.einsum("bd,nd->bn", data, centroids)  # dot product of each activation with each centroid
max_latent = latents_tmp.argmax(dim=-1)  # shape: [batch]
latents = torch.nn.functional.one_hot(max_latent, num_classes=centroids.shape[0]).float()  # binary top-1 latents
where the first line is essentially an SAE embedding (and centroids are the features), and the second/third line is a top-k. And for reconstruction do something like
recon = latents @ centroids  # [batch, d_model]; latents is [batch, n_clusters], centroids is [n_clusters, d_model]
which should also be equivalent.
Shouldn’t the SAE training process for a top k SAE with k = 1 find these vectors then?
Yes, I would expect an optimal k=1 top-k SAE to find exactly that solution. Confused why k=20 top-k SAEs do so badly then.
If this is a crux then a quick way to prove this would be for me to write down encoder/decoder weights and throw them into a standard SAE code. I haven't done this yet.
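Something like the following is what I have in mind (a hypothetical, untested sketch; the class name and exact SAE interface are made up here):

```python
# Hypothetical, untested sketch: wrap K-means centroids as a k=1 top-k SAE.
import torch

class CentroidTopKSAE(torch.nn.Module):
    def __init__(self, centroids: torch.Tensor, k: int = 1):
        super().__init__()
        self.W_enc = torch.nn.Parameter(centroids.clone())       # [n_latents, d_model]
        self.W_dec = torch.nn.Parameter(centroids.clone())       # [n_latents, d_model]
        self.k = k

    def forward(self, x: torch.Tensor) -> torch.Tensor:          # x: [batch, d_model]
        pre_acts = x @ self.W_enc.T                               # [batch, n_latents]
        top_vals, top_idx = pre_acts.topk(self.k, dim=-1)
        latents = torch.zeros_like(pre_acts).scatter_(-1, top_idx, torch.ones_like(top_vals))
        return latents @ self.W_dec                               # binary-latent reconstruction
```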
I just tried to replicate this on GPT-2 with expansion factor 4 (so total number of centroids = 768 * 4). I get that clustering recovers ~87% fraction of variance explained, while a k = 32 SAE gets more like 95% variance explained. I did the nonlinear version of finding nearest neighbors when using k means to give k means the biggest advantage possible, and did k-means clustering on points using the FAISS clustering library.
Definitely take this with a grain of salt, I'm going to look through my code and see if I can reproduce your results on pythia too, and if so try on a larger model too. Code: https://github.com/JoshEngels/CheckClustering/tree/main
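For reference, a stripped-down sketch of this kind of FAISS baseline (a paraphrase, not the code in the linked repo; the activation array is a placeholder):

```python
# Stripped-down sketch of a FAISS K-means baseline (paraphrase, not the linked repo's code).
import faiss
import numpy as np

acts = np.random.randn(100_000, 768).astype("float32")   # placeholder; real runs use model activations
d, k = acts.shape[1], 768 * 4                             # expansion factor 4

kmeans = faiss.Kmeans(d, k, niter=20, seed=0)
kmeans.train(acts)
_, assignment = kmeans.index.search(acts, 1)              # nearest centroid per activation
recon = kmeans.centroids[assignment[:, 0]]
fvu = ((acts - recon) ** 2).sum() / ((acts - acts.mean(0)) ** 2).sum()
print("variance explained:", 1 - fvu)
```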
Collection of some mech interp knowledge about transformers:
Writing up folk wisdom & recent results, mostly for mentees and as a link to send to people. Aimed at people who are already a bit familiar with mech interp. I've just quickly written down what came to my head, and may have missed or misrepresented some things. In particular, the last point is very brief and deserves a much more expanded comment at some point. The opinions expressed here are my own and do not necessarily reflect the views of Apollo Research.
Transformers take in a sequence of tokens, and return logprob predictions for the next token. We think it works like this:
This is a nice overview, thanks!
Lee Sharkey's CLDR arguments
I don't think I've seen the CLDR acronym before, are the arguments publicly written up somewhere?
Also, just wanted to flag that the links on 'this picture' and 'motivation image' don't currently work.
CLDR (Cross-layer distributed representation): I don't think Lee has written his up anywhere yet so I've removed this for now.
Also, just wanted to flag that the links on 'this picture' and 'motivation image' don't currently work.
Thanks for the flag! It's these two images; I realize now that they don't seem to have direct links.
Images taken from AMFTC and Crosscoders by Anthropic.
Thanks for the great writeup.
Superposition ("local codes") require sparsity, i.e. that only few features are active at a time.
Typo: I think you meant to write distributed, not local, codes. A local code is the opposite of superposition.
Thanks! You're right, I totally mixed up local and dense / distributed. Decided to just leave out that terminology.
We think it works like this
Who is "we"? Is it:
Also, this definitely deserves to be made into a high-level post, if you end up finding the time/energy/interest in making one.
Thanks for the comment!
I think this is what most mech interp researchers more or less think. Though I definitely expect many researchers would disagree with individual points, and it doesn't fairly weigh all views and aspects (it's very biased towards "people I talk to"). (Also this is in no way an Apollo / Apollo interp team statement, just my personal view.)
PSA: People use different definitions of "explained variance" / "fraction of variance unexplained" (FVU)
$$\mathrm{FVU}_A = \frac{\sum_i \lVert x_i - \hat{x}_i \rVert^2}{\sum_i \lVert x_i - \bar{x} \rVert^2}$$
is the formula I think is sensible; the bottom is simply the variance of the data, and the top is the variance of the residuals. The $\lVert \cdot \rVert$ indicates the norm over the dimension of the vector $x_i$. I believe it matches Wikipedia's definition of FVU and R squared.
$$\mathrm{FVU}_B = \frac{1}{N}\sum_i \frac{\lVert x_i - \hat{x}_i \rVert^2}{\lVert x_i - \bar{x} \rVert^2}$$
is the formula used by SAELens and SAEBench. It seems less principled; @Lucius Bushnaq and I couldn't think of a nice quantity it corresponds to. I think of it as giving more weight to samples that are close to the mean, kind of averaging the relative reduction in difference rather than the absolute one.
A third version (h/t @JoshEngels) computes the FVU for each dimension independently and then averages; that version is not used in the context we're discussing here.
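In code, the three versions look something like this (my own sketch, not the SAELens/SAEBench implementation):

```python
# Sketch of the three FVU variants discussed above (my own code, not SAELens/SAEBench).
# x, x_hat: [n_samples, d_model] tensors of activations and reconstructions.
import torch

def fvu_a(x, x_hat):
    mu = x.mean(dim=0, keepdim=True)
    return ((x - x_hat) ** 2).sum() / ((x - mu) ** 2).sum()           # ratio of sums

def fvu_b(x, x_hat):
    mu = x.mean(dim=0, keepdim=True)
    ratios = ((x - x_hat) ** 2).sum(-1) / ((x - mu) ** 2).sum(-1)     # per-sample ratios...
    return ratios.mean()                                              # ...then averaged

def fvu_c(x, x_hat):
    var = x.var(dim=0, unbiased=False)                                # per-dimension variance
    return (((x - x_hat) ** 2).mean(dim=0) / var).mean()              # per-dimension FVU, averaged
```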
In my recent comment I had computed my own FVU_A, and compared it to FVUs from SAEBench (which used FVU_B), and obtained nonsense results.
Curiously the two definitions seem to be approximately proportional (below I show the performance of a bunch of SAEs), though for different distributions (here: activations in layer 3 and 4) the ratio differs.[1] Still, this means using FVU_B instead of FVU_A to compare e.g. different SAEs doesn't make a big difference as long as one is consistent.
Thanks to @JoshEngels for pointing out the difference, and to @Lucius Bushnaq for helpful discussions.
If a predictor doesn't perform systematically better or worse at points closer to the mean then this makes sense. The denominator changes the relative weight of different samples but this doesn't have any effect beyond noise and a global scale, as long as there is no systematic performance difference.
I would be very surprised if this FVU_B is actually another definition and not a bug. It's not a fraction of the variance, and those denominators can easily be zero or very near zero.
https://github.com/jbloomAus/SAELens/blob/main/sae_lens/evals.py#L511 sums the numerator and denominator separately, if they aren't doing that in some other place probably just file a bug report?
I think this is the sum over the vector dimension, but not over the samples. The sum (mean) over samples is taken later in this line, which happens after the division:
metrics[f"{metric_name}"] = torch.cat(metric_values).mean().item()
Edit: And to clarify, my impression is that people think of these as alternative definitions of FVU and you get to pick one, rather than one being right and one being a bug.
Edit2: And I'm in touch with the SAEBench authors about making a PR to change this / add both options (and by extension probably doing the same in SAELens); though I won't mind if anyone else does it!
FVU_B doesn't make sense but I don't see where you're getting FVU_B from.
Here's the code I'm seeing:
resid_sum_of_squares = (
    (flattened_sae_input - flattened_sae_out).pow(2).sum(dim=-1)
)
total_sum_of_squares = (
    (flattened_sae_input - flattened_sae_input.mean(dim=0)).pow(2).sum(-1)
)
mse = resid_sum_of_squares / flattened_mask.sum()
explained_variance = 1 - resid_sum_of_squares / total_sum_of_squares
Explained variance = 1 - FVU = 1 - (residual sum of squares) / (total sum of squares)
The previous lines calculate the ratio (or 1 - ratio) stored in the "explained variance" key for every sample/batch. Then in that later quoted line, the list is averaged, i.e. we're taking the sample average over the ratio. That's the FVU_B formula.
Let me know if this clears it up or if we’re misunderstanding each other!
Why I'm not too worried about architecture-dependent mech interp methods:
I've heard people argue that we should develop mechanistic interpretability methods that can be applied to any architecture. While this is certainly a nice-to-have, and maybe a sign that a method is principled, I don't think this criterion itself is important.
I think that the biggest hurdle for interpretability is to understand any AI that produces advanced language (>=GPT2 level). We don't know how to write a non-ML program that speaks English, let alone reason, and we have no idea how GPT2 does it. I expect that doing this the first time is going to be significantly harder than doing it the second time. Kind of like how "understand an alien mind" is much harder than "understand the 2nd alien mind".
Edit: Understanding an image model (say Inception V1 CNN) does feel like a significant step down, in the sense that these models feel significantly less "smart" and capable than LLMs.
Agreed. I do value methods being architecture independent, but mostly just because of this:
and maybe a sign that a method is principled
At scale, different architectures trained on the same data seem to converge to learning similar algorithms to some extent. I care about decomposing and understanding these algorithms, independent of the architecture they happen to be implemented on. If a mech interp method is formulated in a mostly architecture independent manner, I take that as a weakly promising sign that it's actually finding the structure of the learned algorithm, instead of structure related to the implementation on one particular architecture.
I've heard people argue that we should develop mechanistic interpretability methods that can be applied to any architecture.
I think the usual reason this claim is made is because the person making the claim thinks it's very plausible LLMs aren't the paradigm that leads to AGI. If that's the case, then interpretability that's indexed heavily on them gets us understanding of something qualitatively weaker than we'd like. I agree that there'll be some transfer, but it seems better and not-very-hard to talk about how well different kinds of work transfer.
Agreed. A related thought is that we might only need to be able to interpret a single model at a particular capability level to unlock the safety benefits, as long as we can make a sufficient case that we should use that model. We don't care inherently about interpreting GPT-4, we care about there existing a GPT-4 level model that we can interpret.
List of some larger mech interp project ideas (see also: short and medium-sized ideas). Feel encouraged to leave thoughts in the replies below!
Edit: My mentoring doc has more-detailed write-ups of some projects. Let me know if you're interested!
What is going on with activation plateaus: Transformer activations space seems to be made up of discrete regions, each corresponding to a certain output distribution. Most activations within a region lead to the same output, and the output changes sharply when you move from one region to another. The boundaries seem to correspond to bunched-up ReLU boundaries as predicted by grokking work. This feels confusing. Are LLMs just classifiers with finitely many output states? How does this square with the linear representation hypothesis, the success of activation steering, logit lens etc.? It doesn't seem in obvious conflict, but it feels like we're missing the theory that explains everything. Concrete project ideas:
Use sensitive directions to find features: Can we use the sensitivity of directions as a way to find the "true features", some canonical basis of features? In a recent post we found current SAE features to look less special than expected, so I'm a bit cautious about this. But especially after working on some toy models about computation in superposition I'd be keen to explore the error correction predictions made here (paper, comment).
Test if we can fully sparsify a small model: Try the full pipeline of training SAEs everywhere, or training Transcoders & Attention SAEs, and doing all that such that connections between features are sparse (such that every feature only interacts with a few other features). The reason we want that is so that we can have simple computational graphs, and find simple circuits that explain model behaviour.
I expect that---absent SAE improvements that find the "true feature" basis---you'll need to train them all together with a penalty for the sparsity of interactions. To be concrete, an inefficient thing you could do is the following: Train SAEs on every residual stream layer, with a loss term that L1-penalises interactions between adjacent SAE features. This is hard/inefficient because the matrix of SAE interactions is huge, plus you probably need attributions to get these interactions, which are expensive to compute (at every training step!). I think the main question for this project is to figure out whether there is a way to do this efficiently. Talk to Logan Smith and Callum McDougall; I expect there are a couple more people who are trying something like this.
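To make the naive version concrete, the kind of penalty I have in mind could look roughly like this (a hypothetical sketch; it only captures the "virtual" linear interaction between adjacent SAEs and ignores the transformer block in between, which is exactly the part that makes the real thing expensive):

```python
# Hypothetical sketch of an L1 penalty on interactions between adjacent SAEs.
# Only the "virtual" linear interaction W_enc_next @ W_dec_prev^T is penalised here;
# a faithful version would need attributions through the transformer block in between.
import torch

def interaction_l1(W_dec_prev: torch.Tensor, W_enc_next: torch.Tensor) -> torch.Tensor:
    # W_dec_prev: [n_latents_prev, d_model], W_enc_next: [n_latents_next, d_model]
    interactions = W_enc_next @ W_dec_prev.T          # [n_latents_next, n_latents_prev]
    return interactions.abs().mean()

# total_loss = recon_losses + sparsity_penalties + lam * interaction_l1(W_dec_l, W_enc_lplus1)
```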
Are the features learned by the model the same as the features learned by SAEs?
TL;DR: I want model-features to be a property of the model weights, and to be recognizable without access to the full dataset. Toy models have that property. My "poor man's model-features" have it. I want to know whether SAE-features have this property too, or if SAE-features do not match the model-features.
Introduction: Neural networks likely encode features in superposition. That is, features are represented as directions in activation space, and the model likely tracks many more features than dimensions in activation space. Because features are sparse, it should still be possible for the model to recover and use individual feature values.[1]
Problem statement: The prevailing method for finding these features is Sparse Autoencoders (SAEs). SAEs are well-motivated because they do recover superposed features in toy models. However, I am not certain whether SAEs recover the features of LLMs. I am worried (though not confident) that SAEs recover features of the dataset rather than features of the model, and that we are thus overconfident in how much SAEs tell us.
SAE failure mode: SAEs are trained to achieve a certain compression[2] task: compress activations into a sparse overcomplete basis, and reconstruct the original activations from this compressed representation. The solution to this problem can be identical to what the neural network does (wanting to store & use information), but it need not be. In TMS, the network's only objective is to compress features, so it is natural that the SAE-features match the model-features. But LLMs solve a different task (well, we don't have a good idea what LLMs do), and training an SAE on a model's activations might yield a basis different from the model-features (see hypothetical Example 1 below).
Operationalisation of model-features (I'm tabooing "true features"): In the Toy Model of Superposition (TMS) the model's weights are clearly adjusted to the feature directions. We can tell a feature from looking at the model weights. I want this to be a property of true SAE-features as well. Then I would be confident that the features are a property of the model, and not (only) of the dataset distribution. Concrete operationalisation:
Why do I care? I expect that the model-features are, in some sense, the computational units of the model. I expect our understanding to be more accurate (and to generalize) if we understand what the model actually does internally (see hypothetical Example 2 below).
Is this possible? Toy models of computation in superposition seem to suggest that models give special treatment to feature directions (compared to arbitrary activation directions), for example the error correction described here. This may privilege the basis of model-features over other decompositions of activations. I discuss experiment proposals at the bottom.
Example 1: Imagine an LLM was trained on The Pile excluding Wikipedia. Now we train an SAE on the model's activations on a different dataset including Wikipedia. I expect that the SAE will find Wikipedia-related features: for example, a Wikipedia-citation-syntax feature on a low level, or a Wikipedia-style-objectivity feature on a high level. I would claim that this is not a feature of the model: during training the model never encountered these concepts, and it has not reserved a direction in its superposition arrangement (think geometric shapes in Toy Model of Superposition) for this feature.
Example 2: Maybe an SAE trained on an LLM playing Civilization and Risk finds a feature that corresponds to "strategic deception" on this dataset. But actually the model does not use a "strategic deception" feature (instead strategic deception originates from, say, a "power dynamics" feature), and it just happens that the instances of strategic deception in those games clustered into a specific direction. If we now take this direction to monitor for strategic deception, we will fail to notice other strategic deception originating from the same "power dynamics" feature.
Experiment proposals: I have explored the abnormal effect that “poor man’s model-features” (sampled as the difference between two independent model activations) have on model outputs, and their relation to theoretically predicted noise suppression in feature activations. Experiments in Gurnee (2024) and Lindsey (2024) suggest that SAE decoder errors and SAE-features also have an abnormal effect on the model. With the LASR Labs team I mentor I want to explore whether SAE-features match the theoretical predictions, and whether the SAE-feature effects match those expected from model-features.
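For reference, the "poor man's model-features" construction is as simple as it sounds (a sketch, assuming activations have been collected into an `acts` tensor):

```python
# Sketch of "poor man's model-features": directions sampled as differences between
# two independently sampled model activations. `acts`: [n_samples, d_model].
import torch

def poor_mans_features(acts: torch.Tensor, n_directions: int = 100) -> torch.Tensor:
    i = torch.randint(0, acts.shape[0], (n_directions,))
    j = torch.randint(0, acts.shape[0], (n_directions,))
    dirs = acts[i] - acts[j]
    return dirs / dirs.norm(dim=-1, keepdim=True)     # unit-norm candidate feature directions
```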
My core request is that I want (SAE-)features to be a property of the model, rather than the dataset.
Of course a concept being common in the model-training-data makes it likely (?) to be a concept the model uses, but I don’t think this is a 1:1 correspondence. (So just making the SAE training set equal to the model training set wouldn’t solve the issue.)
There is a view that SAE features are just a useful tool for describing activations (interpretable features) and manipulating activations (useful for steering and probing). That SAEs are just a particularly good method in a larger class of methods, but not uniquely principled. In that case I wouldn't expect this connection to model behaviour.
But the claim that we often make is that the model sees and understands the world as a set of model-features, and that we can see the same features by looking at SAE-features of the activations. And then I want to see the extra evidence.
List of some medium-sized mech interp project ideas (see also: shorter and longer ideas). Feel encouraged to leave thoughts in the replies below!
Edit: My mentoring doc has more-detailed write-ups of some projects. Let me know if you're interested!
Toy model of Computation in Superposition: The toy model of computation in superposition (CIS; Circuits-in-Sup, Comp-in-Sup post / paper) describes a way in which NNs could perform computation in superposition, rather than just storing information in superposition (TMS). It would be good to have some actually trained models that do this, in order (1) to check whether NNs learn this algorithm or a different one, and (2) to test whether decomposition methods handle this well.
This could be, in the simplest form, just some kind of non-trivial memorisation model, or AND-gate model. Just make sure that the task does in fact require computation, and cannot be solved without it. A more flashy version could be a network trained to do MNIST and FashionMNIST at the same time, though this would be more useful for goal (2).
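A minimal version of such a task could look like this (a hypothetical sketch; the pairing of features into ANDs is an arbitrary choice):

```python
# Hypothetical sketch of a minimal AND-gate task: sparse boolean inputs, targets are
# ANDs of feature pairs, so the task genuinely requires computation, not just storage.
import torch

def sample_and_task(batch: int, n_features: int, p_on: float = 0.05):
    x = (torch.rand(batch, n_features) < p_on).float()                  # sparse boolean features
    pairs = [(i, (i + 1) % n_features) for i in range(n_features)]      # arbitrary pairing
    y = torch.stack([x[:, a] * x[:, b] for a, b in pairs], dim=-1)      # AND of each pair
    return x, y
```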
Transcoder clustering: Transcoders are a sparse dictionary learning method that e.g. replaces an MLP with an SAE-like sparse computation (basically an SAE, but mapping activations not to themselves but to the next layer). If the above model of computation / circuits in superposition is correct (every computation using multiple ReLUs for redundancy) then the transcoder latents belonging to one computation should co-activate. Thus it should be possible to use clustering of transcoder activation patterns to find meaningful model components (circuits in the circuits-in-superposition model). (Idea suggested by @Lucius Bushnaq, mistakes are mine!) There are two ways to do this project:
Investigating / removing LayerNorm (LN): For GPT2-small I showed that you can remove LN layers gradually while fine-tuning without losing much model performance (workshop paper, code, model). There are three directions in which I want to follow up on this project.
List of some short mech interp project ideas (see also: medium-sized and longer ideas). Feel encouraged to leave thoughts in the replies below!
Edit: My mentoring doc has more-detailed write-ups of some projects. Let me know if you're interested!
Directly testing the linear representation hypothesis by making up a couple of prompts which contain a few concepts to various degrees and test
Mostly I expect this to come out positive, and not to be a big update, but seems cheap to check.
SAEs vs Clustering: How much better are SAEs than (other) clustering algorithms? Previously I worried that SAEs are "just" finding the data structure, rather than features of the model. I think we could try to rule out some "dataset clustering" hypotheses by testing how much structure there is in the dataset of activations that one can explain with generic clustering methods. Will we get 50%, 90%, 99% variance explained?
I think a second spin on this direction is to look at "interpretability" / "mono-semanticity" of such non-SAE clustering methods. Do clusters appear similarly interpretable? This would address the concern that many things look interpretable, and that we shouldn't be surprised by SAE directions looking interpretable. (Related: Szegedy et al., 2013 look at random directions in an MNIST network and find them to look interpretable.)
Activation steering vs prompting: I've heard the view that "activation steering is just fancy prompting" which I don't endorse in its strong form (e.g. I expect it to be much harder for the model to ignore activation steering than to ignore prompt instructions). However, it would be nice to have a prompting-baseline for e.g. "Golden Gate Claude". What if I insert a "<system> Remember, you're obsessed with the Golden Gate bridge" after every chat message? I think this project would work even without the steering comparison actually.
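A minimal sketch of that prompting baseline (hypothetical; the exact message format is just a choice):

```python
# Hypothetical sketch of the prompting baseline: insert a reminder after every user message.
def add_obsession_reminder(messages: list[dict]) -> list[dict]:
    reminder = {"role": "system",
                "content": "Remember, you're obsessed with the Golden Gate Bridge."}
    out = []
    for msg in messages:
        out.append(msg)
        if msg["role"] == "user":
            out.append(reminder)
    return out
```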
Why I'm not that hopeful about mech interp on TinyStories models:
Some of the TinyStories models are open source, and manage to output sensible language while being tiny (say 64dim embedding, 8 layers). Maybe it'd be great to try and thoroughly understand one of those?
I am worried that those models simply implement a bunch of bigrams and trigrams, and that all their performance can be explained by boring statistics & heuristics. Thus we would not learn much from fully understanding such a model. Evidence for this is that the 1-layer variant, which due to its size can only implement bigrams & trigram-ish things, achieves a better loss than many of the taller but smaller models (Figure 4). Thus it seems not implausible that most if not all of the performance of all the models could be explained by similarly simple mechanisms.
Folk wisdom is that the TinyStories dataset is just very formulaic and simple, and therefore models without any sophisticated methods can appear to produce sensible language. I haven't looked into this enough to understand whether e.g. TinyStories V2 (used by TinyModel) is sufficiently good to dispel this worry.
Has anyone tested whether feature splitting can be explained by composite (non-atomic) features?
There’s at least two hypotheses for what is going on.
Anthropic conjectures hypothesis 1 in Towards Monosemanticity. Demian Till argues for hypothesis 2 in this post. I find Demian's arguments compelling. The key idea is that an SAE can achieve lower loss by creating composite features for frequently co-occurring concepts: the composite feature fires instead of two (or more) atomic features, giving higher sparsity (lower sparsity penalty) at the cost of taking up another dictionary entry (worse reconstruction).
Do we have good evidence for one case or the other?
We observe that split features often have high cosine similarity, but this is explained by both hypotheses. (Anthropic says features are clustered together because they’re similar. Demian Till’s hypothesis would claim that multiple composite features contain the same atomic features, again explaining the similarity.)
A naive test may be to test whether features can be explained by a sparse linear combination of other features, though I’m not sure how easy this would be to test.
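A naive version of that test could look like this (my sketch; the Lasso penalty strength is an arbitrary choice):

```python
# Naive sketch of the test: can one decoder direction be written as a sparse linear
# combination of the other decoder directions? (Lasso alpha is an arbitrary choice.)
import numpy as np
from sklearn.linear_model import Lasso

def sparse_combination_r2(W_dec: np.ndarray, idx: int, alpha: float = 0.01) -> float:
    # W_dec: [n_latents, d_model] decoder directions; explain row `idx` from the rest
    target = W_dec[idx]
    others = np.delete(W_dec, idx, axis=0).T          # [d_model, n_latents - 1]
    fit = Lasso(alpha=alpha, fit_intercept=False).fit(others, target)
    resid = target - others @ fit.coef_
    return 1 - (resid ** 2).sum() / (target ** 2).sum()
```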
For reference, cosine similarity of SAE decoder directions in Joseph Bloom's GPT2-small SAEs (blocks.1.hook_resid_pre and blocks.10.hook_resid_pre), compared to random directions and random directions with the same covariance as typical activations.
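A sketch of how such a comparison can be computed (my own code, not the exact script behind the plots; the baselines are indicated in comments):

```python
# Sketch: pairwise cosine similarities of SAE decoder directions vs. random baselines.
import torch

def offdiag_cosine_sims(dirs: torch.Tensor) -> torch.Tensor:
    d = dirs / dirs.norm(dim=-1, keepdim=True)
    sims = d @ d.T
    return sims[~torch.eye(len(d), dtype=torch.bool)]        # off-diagonal entries only

# W_dec: [n_latents, d_model] decoder matrix; acts: [n_samples, d_model] activations
# random baseline:             torch.randn_like(W_dec)
# covariance-matched baseline: torch.randn_like(W_dec) @ torch.linalg.cholesky(torch.cov(acts.T)).T
```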
I like this recent post about atomic meta-SAE features, I think these are much closer (compared against normal SAEs) to what I expect atomic units to look like:
https://www.lesswrong.com/posts/TMAmHh4DdMr4nCSr5/showing-sae-latents-are-not-atomic-using-meta-saes
I think we should think more about computation in superposition. What does the model do with features? How do we go from “there are features” to “the model outputs sensible things”? How do MLPs retrieve knowledge (as commonly believed) in a way compatible with superposition (knowing more facts than number of neurons)?
This post (and paper) by @Kaarel, @jake_mendel, @Dmitry Vaintrob (and @LawrenceC) is the kind of thing I'm looking for, trying to lay out a model of how computation in superposition could work. It makes somewhat-concrete predictions about the number and property of model features.
Why? Because (a) these feature properties may help us find the features of a model, and (b) a model of computation may be necessary if features alone are insufficient to address AI Safety (on the interpretability side).