Note: The second figure in this post originally contained a bug pointed out by @LawrenceC, which has since been fixed.

Summary

  • Sparse Autoencoders (SAEs) reveal interpretable features in the activation spaces of language models, but SAEs don’t reconstruct activations perfectly. We lack good metrics for evaluating which parts of model activations SAEs fail to reconstruct, which makes it hard to evaluate SAEs themselves. In this post, we argue that SAE reconstructions should be tested using well-established benchmarks to help determine what kinds of tasks they degrade model performance on.
  • We stress-test a recently released set of SAEs for each layer of the gpt2-small residual stream using randomly sampled tokens from Open WebText and the Lambada benchmark where the model must predict a specific next token. 
  • The SAEs perform well on prompts with context sizes up to the training context size, but their performance degrades on longer prompts. 
  • In contexts shorter than or equal to the training context, the SAEs that we study generally perform well. We find that the performance of our late-layer SAEs is worse than early-layer SAEs, but since the SAEs all have the same width, this may just be because there are more features to resolve in later layers and our SAEs don’t resolve them.
  • In contexts longer than the training context, SAE performance is poor in general, but it is poorest in earlier layers and best in later layers. 

Introduction

Last year, Anthropic and EleutherAI/Lee Sharkey's MATS stream showed that sparse autoencoders (SAEs) can decompose language model activations into human-interpretable features. This has led to a significant uptick in the number of people training SAEs and analyzing models with them. However, SAEs are not perfect autoencoders and we still lack a thorough understanding of where and how they miss information. But how do we know if an SAE is “good” other than the fact that it has features we can understand?

SAEs try to reconstruct activations in language models – but they don’t do this perfectly. Imperfect activation reconstruction can lead to substantial downstream cross-entropy (CE) loss increases. Generally, “good” SAEs recover 80-99% of the CE loss (relative to a generous baseline of zero ablation), but recovering only 80% of the CE loss is enough to substantially degrade the performance of a model to that of a much smaller model (per scaling laws).
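To make the “loss recovered” framing concrete, here is a minimal sketch (not the code used for this post) of how such a fraction is usually computed, assuming you have already measured the CE loss of the clean model, of the model with the SAE output spliced in, and of the model with the layer zero-ablated:

# Fraction of the CE-loss gap (zero ablation vs. clean) that the SAE closes:
# 1.0 means the reconstruction is as good as the original activations,
# 0.0 means it is no better than zeroing the activations out entirely.
def ce_loss_recovered(ce_clean: float, ce_sae: float, ce_zero_abl: float) -> float:
    return (ce_zero_abl - ce_sae) / (ce_zero_abl - ce_clean)

# e.g. ce_loss_recovered(3.2, 3.5, 6.0) ~= 0.89, i.e. "89% of CE loss recovered"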

The second basic metric often used in SAE evaluation is the average per-token $L_0$ norm of the hidden layer of the autoencoder. Generally this is something in the range of ~10-60 in a “good” autoencoder, which means that the encoder is sparse. Since we don’t know how many features are active per token in natural language, it’s useful to at least ask how changes in $L_0$ relate to changes in SAE loss values. If high-loss data have a drastically different $L_0$ from the SAE’s average performance during training, that can be evidence of either off-distribution data (compared to the training data) or some kind of data with more complex information.
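As a quick illustration (names are ours, not from any particular SAE codebase), the per-token $L_0$ is just a count of nonzero feature activations:

import torch

def mean_l0(feature_acts: torch.Tensor) -> torch.Tensor:
    # feature_acts: [batch, position, n_features] SAE hidden-layer activations
    active = feature_acts > 0              # boolean mask of firing features
    l0_per_token = active.sum(dim=-1)      # number of active features per token
    return l0_per_token.float().mean()     # average over batch and position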

The imperfect performance of SAEs on these metrics could be explained in a couple of ways:

  • The fundamental assumptions of SAEs are mostly right, but we’re bad at training SAEs. Perhaps if we learn to train better SAEs, these problems will become less bad.
    • Perhaps we need to accept higher $L_0$ norms (more features active per token). This would not be ideal for interpretability, though.
    • Perhaps there's part of the signal which is dense or hard for an SAE to learn and so we are systematically missing some kind of information. Maybe a more sophisticated sparsity enforcement could help with this.
  • The fundamental assumptions of SAEs are wrong on some level.

It’s important to determine if we can understand where SAE errors come from. Distinguishing between the above worlds will help us guide our future research directions. Namely:

  • Should we stop using SAEs? (the problems are not solvable)
  • Should we train SAEs differently? 
  • Should we try to solve very specific technical problems associated with SAE training? (to do so, we must identify those problems)

One way to get traction on this is to ask whether the errors induced by SAEs are random. If not, perhaps they are correlated with factors like:

  • The specific task or datapoint being predicted. (Perhaps the task is out of distribution compared to the dataset used to train the SAE)
  • Which layer is being reconstructed. (Different layers perform different tasks, and perhaps SAEs are better or worse at some of these) 
  • Training hyperparameters. (e.g., a fixed, short context length used in SAE training).

If we can find results that suggest that the errors aren’t random, we can leverage correlations found in the errors. These errors can at least help us move in a direction of understanding how robust SAEs are. Perhaps these errors teach us to train better SAEs, and move us into a world where we can interpret features which reconstruct model activations with high fidelity. 

In this post, we’re going to stress test the gpt2-small residual stream SAEs released by Joseph Bloom. We’ll run randomly sampled Open WebText tokens through the autoencoders, and we’ll also test them using the Lambada benchmark. Our biggest finding echoes Sam Marks' comments on the original post, where he preliminarily found that SAE performance degrades significantly when a context longer than the training context is used.

Experiment Overview

Open WebText next-token prediction (randomly sampled data)

Our first experiment asks how context position affects SAE performance. We pass 100k tokens from Open WebText through gpt2-small using a context length of 1024. We cache residual stream activations, then reconstruct them using Joseph’s SAEs, which were trained on Open WebText using a context length of 128. We also run separate forward passes of the full gpt2-small model while intervening on a single layer in the residual stream by replacing its activations with the SAE reconstruction activations.
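As a concrete (and deliberately simplified) sketch of this setup, the snippet below caches one layer’s residual stream with TransformerLens and measures the per-token reconstruction error. The random `sae` stand-in is not one of Joseph’s autoencoders (a real run loads the released weights instead), but it has the standard ReLU-autoencoder shape, and the hook point used here is the pre-block residual stream:

import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")    # gpt2-small
layer = 6                                            # which residual stream layer to study

# Random stand-in for a pretrained 24k-feature SAE at this layer.
d_model, d_sae = model.cfg.d_model, 24576
W_enc = torch.randn(d_model, d_sae) / d_model**0.5
W_dec = torch.randn(d_sae, d_model) / d_sae**0.5
b_enc, b_dec = torch.zeros(d_sae), torch.zeros(d_model)

def sae(acts):
    feats = torch.relu((acts - b_dec) @ W_enc + b_enc)    # sparse feature activations
    return feats @ W_dec + b_dec                          # reconstruction of the input

tokens = model.to_tokens("A stand-in for a batch of Open WebText text.")
with torch.no_grad():
    _, cache = model.run_with_cache(tokens)
    resid = cache[f"blocks.{layer}.hook_resid_pre"]       # [batch, pos, d_model]
    recon = sae(resid)
    mse_per_token = (recon - resid).pow(2).mean(dim=-1)   # reconstruction error per token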

We measure the following quantities: the $L_0$ norm of the SAE hidden layer, the MSE of the SAE reconstruction, and the $L_1$ norm of the SAE feature activations. For each of these, we take means over the batch dimension to give a per-token value.

We also examine error propagation from an SAE in an earlier layer to an SAE in a later layer. Here, we only examine context positions $\leq 100$ (shorter than the training context). We replace the activations at an early layer during a forward pass of the model. We cache the downstream activations resulting from that intervention, then reconstruct those intervened-on activations with later-layer SAEs.
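Roughly, that procedure looks like the following, reusing the names from the sketch above (`saes` here just reuses the random stand-in at every layer; a real run loads each layer’s released SAE):

early, late = 2, 9
saes = {l: sae for l in range(model.cfg.n_layers)}   # stand-in per-layer SAEs
patch_early = lambda resid, hook: saes[early](resid)

with torch.no_grad():
    # Splice the early layer, caching everything downstream of the intervention.
    with model.hooks(fwd_hooks=[(f"blocks.{early}.hook_resid_pre", patch_early)]):
        _, patched_cache = model.run_with_cache(tokens)
    downstream = patched_cache[f"blocks.{late}.hook_resid_pre"]   # already perturbed upstream
    recon = saes[late](downstream)
    mse_with_upstream_patch = (recon - downstream).pow(2).mean()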

Benchmarks (correct answer prediction)

In our second set of experiments, we take inspiration from the original gpt-2 paper. We examine the performance of gpt2-small as well as gpt2-small with SAE residual stream interventions. We examine the test split of Lambada (paper, HuggingFace), and in an appendix we examine the Common Noun (CN) partition of the Children’s Book Test (CBT; paper, HuggingFace). In both of these datasets, a prompt is provided which corresponds to a correct answer token, so here we can measure how the probability of the correct answer changes. As in the previous experiments, we also measure $L_0$, MSE, and $\Delta\mathrm{CE}$ – but here we average over context rather than looking per-token. We also measure $\Delta \log p(\mathrm{correct})$, the change in the log-probability of the model predicting the correct token.
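A sketch of the $\Delta \log p(\mathrm{correct})$ measurement (the helper and its argument names are illustrative, not the exact code in our repo):

import torch
import torch.nn.functional as F

def delta_logp_correct(model, prompt_tokens, answer_token_id, fwd_hooks):
    # Log-probability of the known answer token at the final context position,
    # with vs. without the SAE splice-in hooks.
    with torch.no_grad():
        clean_logits = model(prompt_tokens, return_type="logits")
        patched_logits = model.run_with_hooks(
            prompt_tokens, return_type="logits", fwd_hooks=fwd_hooks
        )
    clean_lp = F.log_softmax(clean_logits[0, -1], dim=-1)[answer_token_id]
    patched_lp = F.log_softmax(patched_logits[0, -1], dim=-1)[answer_token_id]
    return (patched_lp - clean_lp).item()

# e.g., reusing `model`, `sae`, and `layer` from the Open WebText sketch:
# splice = [(f"blocks.{layer}.hook_resid_pre", lambda resid, hook: sae(resid))]
# delta_logp_correct(model, prompt_tokens, answer_token_id, splice)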

Furthermore, we are interested in seeing whether SAE features were consistently “helpful” or “unhelpful” to the model in predicting the correct answer. For each feature, we find the prompts on which it is active, and compute the mean $\Delta \log p(\mathrm{correct})$ over all of the contexts where the feature is active. We also measure a “score” for each feature: for every prompt where the feature is active and $\Delta \log p(\mathrm{correct}) > 0$, it gets a +1; for every prompt where it is active and $\Delta \log p(\mathrm{correct}) < 0$, it gets a -1. We then sum this over all prompts where a feature is active to get the feature score.
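A small sketch of both “feature goodness” measurements, assuming a boolean prompt-by-feature activity matrix and a per-prompt vector of $\Delta \log p(\mathrm{correct})$ values (illustrative names):

import numpy as np

def feature_scores(feature_active, delta_logp):
    # feature_active: [n_prompts, n_features] boolean; delta_logp: [n_prompts]
    sign = np.sign(delta_logp)                            # +1 helpful, -1 unhelpful, 0 neutral
    return (feature_active * sign[:, None]).sum(axis=0)   # summed +/-1 over prompts where active

def mean_delta_logp_when_active(feature_active, delta_logp):
    counts = feature_active.sum(axis=0)
    totals = (feature_active * delta_logp[:, None]).sum(axis=0)
    return np.where(counts > 0, totals / np.maximum(counts, 1), np.nan)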

How does context length affect SAE performance on randomly sampled data?

In the plot below, we show the $L_0$ norm (left panels), the average MSE reconstruction error from the SAEs (middle panels), and the SAE feature activation $L_1$ norm (right panels). The top panels are plotted vs. context position (colored by layer) and the bottom panels are plotted vs. layer (at a few context positions, colored by position at indices [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1022] – we do not examine the 0th or last index).

The top row of plots shows that all metrics follow reasonable trends for context positions $\leq 128$ (the context size the SAEs were trained on) and then explode at later context positions. The bottom row shows us that early layers in the model (~layers 1-4) have much worse degradation of performance at late positions than later layers do.

Below we see the same sort of behavior in the cross entropy (CE) loss (plotted in the same colors as the figure above). We run a forward pass and replace the chosen layer’s residual stream activations with the corresponding SAE reconstruction of those activations. We find the CE loss of this modified forward pass and subtract off the CE loss of a standard forward pass to get $\Delta\mathrm{CE}$.
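Concretely (reusing `model`, `tokens`, `sae`, and `layer` from the sketch in the experiment overview):

def splice_in_sae(resid, hook):
    # Replace the residual stream at this hook point with its SAE reconstruction.
    return sae(resid)

with torch.no_grad():
    loss_clean = model(tokens, return_type="loss")
    loss_patched = model.run_with_hooks(
        tokens,
        return_type="loss",
        fwd_hooks=[(f"blocks.{layer}.hook_resid_pre", splice_in_sae)],
    )
delta_ce = (loss_patched - loss_clean).item()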

This downstream measurement is consistent with the local measurements from the previous plot. The SAEs themselves perform poorly at context positions later than the training context, meaning that they’re not sparse and they don’t reconstruct the activations. As a result, when we replace residual stream activations with SAE outputs, the model’s ability to predict the next token is also degraded. It’s quite interesting that earlier layers of the model are increasingly susceptible to this effect!

It’s possible that the reason the SAEs struggle to reconstruct the residual stream at later context positions is that the residual stream activations themselves come from a very different distribution there. One way to examine this is to look at the average $L_2$ norm of the residual stream activations being recovered. We plot $L_2$ vs. layer in the left panel, and vs. MSE in the right two panels (the right two panels show the same data with different coloring).

In the left panel, we see that later layers have larger activations (as expected through residual stream accumulation). Later positions in the context have very slightly smaller norms than earlier positions (yellow points sit slightly below the darker ones they overlap). Note that the BOS token has larger activations than later tokens and must be excluded from averages and analyses. In the right panels, we see again that context position is the primary culprit of poor SAE performance (MSE), while there’s no obvious dominant trend with layer or $L_2$ norm.

Hopefully increasing the training context length will fix the poor long-context behavior we outline above.

SAE Downstream Error Propagation

We also briefly examine the performance of an SAE when it tries to reconstruct activations that have already been affected by an upstream SAE intervention. We replace the activations in an early layer denoted by the x-axis in the figure below (so in the first column of the plots below, the layer 0 activations are replaced with their SAE reconstructions). We cache the downstream activations resulting from that intervention at each subsequent layer, then reconstruct the activations of each subsequent layer with an SAE (denoted by the y-axis). We specifically look at fractional changes compared to the case where there is no upstream intervention, e.g., the left plot below shows the quantity [(reconstruction MSE with upstream intervention) / (reconstruction MSE on the standard residual stream) - 1]. Similar fractional changes are shown for the SAE feature $L_0$ in the middle panel. The right-hand panel shows how the first intervention changes the downstream residual stream $L_2$ norm. Red colors signify larger quantities while blue signifies smaller quantities, and a value of 1 is an increase of 100%. We do not include tokens late in the context due to poor SAE performance, so this plot is an average of fractional change over the first 100 (non-BOS) tokens.
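Written out (with a symbol we introduce just for this equation), the left panel plots

$$\delta_{\mathrm{MSE}} = \frac{\mathrm{MSE}\left(\text{late-layer SAE, with upstream splice-in}\right)}{\mathrm{MSE}\left(\text{late-layer SAE, standard forward pass}\right)} - 1,$$

and the middle and right panels show the analogous fractional changes in the late-layer SAE feature $L_0$ and in the downstream residual stream $L_2$ norm.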

Most interventions modestly decrease downstream activation magnitudes (right panel), with the exceptions that layer 0 interventions increase all downstream activation magnitudes and that layer 11 typically has larger activation magnitudes after any intervention. Intervening at an early layer also tends to increase MSE loss downstream (this is a big problem for layer 9 regardless of which upstream layer is replaced, and is particularly bad for layer 3 when layer 2 is replaced). Most of these interventions also increase the $L_0$ of downstream SAEs, and perhaps this lack of sparsity provides a hint that could help explain their poorer performance. We don’t want to dwell on these results because we don’t robustly understand what they tell us, but finding places where SAEs compose poorly (e.g., layer 9) could provide hints into the sorts of features that our SAEs are failing to learn. Maybe layer 9 performs some essential function that upstream SAEs are all failing to capture.

How does using SAE output in place of activations affect model performance on the Lambada Benchmark?

Next we examined the Lambada benchmark, which is a dataset of prompts with known correct answers. In an appendix, we examine the Children’s Book Test, which has long prompts where we expect the SAEs to perform poorly. We examined 1000 questions from the test split of Lambada. A sample prompt looks like this (the model must predict the final word, “hatch”):

in the diagram , he pinched to the left of him , where the ground would be , and pulled up a trail of yellow that formed itself into a seat . argus watched as a small hatch opened up from the floor . an outline of the chair was imprinted in light , the seat and the back drawn into space by blue lasers . a tiny black cloud made of nanomachines streamed out of the hatch

Baseline gpt2-small accuracy on these 1000 questions is 26.6%. Below we plot the downstream accuracy and average CE loss across the 1000 questions when we replace each residual stream layer with the SAE activation reconstructions.

So we see that early layers do not degrade performance much, but late layers degrade performance appreciably. Interestingly, we find slightly different trends here than Joseph found for CE loss on Open WebText in his original post – for example, there CE loss degraded monotonically with layer until layer 11, which improved slightly on layer 10. Here, layer 4 has a lower CE loss than layer 3, and while layers 10 and 11 have about the same CE loss, layer 11 performs worse on Lambada than layer 10. So this sort of benchmark does seem to provide a different type of information than plain CE loss on random tokens from the training text.

Next, we examine Lambada prompts in terms of how the downstream CE loss is changed by replacing residual stream activations with SAE reconstructions. In the left panel, we plot autoencoder MSE vs. $\Delta\mathrm{CE}$, and in the right panel we plot $\Delta \log p(\mathrm{correct})$ vs. $\Delta\mathrm{CE}$. Data points are colored according to whether the baseline model and SAE-intervened model answered correctly. Green points are where both are correct, red are where both are incorrect, and orange and blue are where only one is correct (see legend). The upper and right-hand sub-panels show the marginal probability distributions over each class of points.

All four datapoint classes seem to share roughly the same distribution in $\Delta\mathrm{CE}$ and in the MSE of activation reconstruction. This again demonstrates that CE loss is not a perfect proxy for model performance on the Lambada task. Unsurprisingly, the four classes do have different distributions in terms of $\Delta \log p(\mathrm{correct})$, which we expect because $\Delta \log p(\mathrm{correct})$ relates directly to the correctness of the answer. It is interesting that there are some blue points where $\Delta \log p(\mathrm{correct}) < 0$, meaning the correct answer’s probability decreased but the intervened model answered correctly anyway. We interpret these points as prompts where other answers which the model had preferred suffered larger decreases in probability than the correct answer did, so despite a decrease in probability, the correct answer became the most probable overall.

In the plot below, we show a layer-by-layer and a point-by-point breakdown of errors. The y-axis is $\Delta \log p(\mathrm{correct})$, the change in log-probability of the correct answer caused by intervening. The left panel is a strip plot showing the distribution of $\Delta \log p(\mathrm{correct})$ layer-by-layer. The right panel is a parallel coordinate plot showing how the autoencoders respond to each datapoint, and how $\Delta \log p(\mathrm{correct})$ for each datapoint changes as it travels through the layers of the model. In the right panel, individual prompts are colored by their y-axis value at layer 0.

In the strip plot, we see that the mean performance trends downwards layer by layer, but the layers with the most variability are those a few from the beginning (layer 3) and a few from the end (layer 9). In the right panel, we see that there is no strong prompt-level effect. SAEs introduce quasi-random errors which can increase or decrease $\Delta \log p(\mathrm{correct})$ depending on the layer.

Below we plot similar plots (strip, parallel-coordinate) but for the residual stream activations which are the inputs to the SAEs.

Here we see a clear layer and prompt effect. Later layers naturally have larger activations due to accumulation of attention/MLP outputs layer-by-layer. Prompts with large activations after embedding (see colorbar) tend to have large activations throughout all layers of the model and vice versa.

We can ask how metrics of interest vary with the magnitude of the input activations. To avoid conflating layer effects, we normalize the mean $L_2$ norms of the residual stream activations at each layer by the minimum $L_2$ value in the dataset within that layer. Below are scatterplots of MSE (upper left) and the SAE $L_0$ norm (lower left), along with $\Delta\mathrm{CE}$ and $\Delta \log p(\mathrm{correct})$ in the right-hand panels, for each datapoint and at each layer.
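The normalization itself is simple (a sketch with stand-in data; `mean_l2` would really be the per-prompt, per-layer mean residual stream norms):

import numpy as np

mean_l2 = 50.0 + 20.0 * np.random.rand(1000, 12)              # stand-in [n_prompts, n_layers] data
normalized_l2 = mean_l2 / mean_l2.min(axis=0, keepdims=True)  # each layer's x-values start at 1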

We first note the x-axis: each layer experiences little variance among the activations (a maximum change of ~40% from the minimum value). We see no strong trends. There do seem to be weak trends in the left-hand panels: as residual stream activations grow, so too does the $L_0$ norm of the SAE, and in turn the MSE may shrink somewhat. However, there is no strong evidence in our dataset that the performance of the SAE depends strongly on the magnitude of the residual stream inputs.

Finally we examine the performance of each of the 24k features in the SAEs at a high level. For each datapoint, over the last 32 tokens of the context, we determine which features are active (nonzero activation). For each feature, we find all of the prompts where that feature is active, and we take the two measurements of “feature goodness” described in the experiment overview. In the top panel, we plot a histogram of the mean $\Delta \log p(\mathrm{correct})$ achieved by each feature when it is active. In the bottom panel, we plot the binary “score” for each feature, which captures whether a feature is active more often on prompts with positive $\Delta \log p(\mathrm{correct})$ or negative ones.
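For reference, the activity determination is a simple thresholding over those final tokens (a sketch with stand-in data; the real runs use the 24k-feature SAE activations):

import numpy as np

# [n_prompts, 32 tokens, n_features]; mostly zeros, like a sparse SAE
feature_acts = np.maximum(np.random.randn(100, 32, 2048) - 2.5, 0.0)
feature_active = (feature_acts > 0).any(axis=1)    # [n_prompts, n_features] boolean

This boolean matrix, together with the per-prompt $\Delta \log p(\mathrm{correct})$ values, is exactly what feeds the feature-score sketch in the experiment overview.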

We see the same story in the top panel as in our first strip plot: the mean of the distribution shifts to the left (worse performance) at later layers in the model. But, interestingly, all SAEs at all layers have a small tail of features to the right of the black line – which means that, on average, they are active when $\Delta \log p(\mathrm{correct})$ is positive!

If we look at the bottom histogram, we see a much smaller family of features with a positive score. This discrepancy between the top and bottom histograms means that the family of features making up the tail of ‘helpful’ features in the top histogram also has lots of activations where the model’s performance is weakly hurt by SAE intervention, as well as a few activations where the model’s performance is strongly helped; these outliers push the mean into the positive in the top panel. Still, there is a very small family of features with positive scores, especially in early layers.

Perhaps nothing here is too surprising: model performance gets worse overall when we replace activations with SAE outputs, so the fact that most SAE features hurt model performance on a binary benchmark makes sense. We haven’t looked into any individual features in detail, but it would be interesting to investigate the small families of features that are helpful more often than not, or the features that really help the model (large $\Delta \log p(\mathrm{correct})$) some of the time.

Takeaways

  • Our SAEs perform poorly on contexts that are longer than their training context. This probably means that future SAEs just need to be trained on longer contexts, but we can’t rule out the possibility that longer contexts are simply harder to train good SAEs for. Regardless, training SAEs on longer contexts is low-hanging fruit that should be done.
  • For short contexts, SAE performance decreases with model layer (presumably because there are more features that the model has stuffed into later layers). For too-long contexts, roughly the opposite is true, and earlier layers (~2-4) have the worst performance.
  • CE loss on random tokens does not necessarily map one-to-one to downstream model performance or CE loss on specific tasks. A set of benchmarks evaluating different model capabilities could help us understand what kinds of families of features are missing from our SAEs more robustly.
  • Some features in SAEs have activations which correlate with strong improvements in model performance (see benchmark histograms above). On the other hand, most features are active only when the model has poor performance. It’s unclear if these features themselves are problematic, or if the SAE is missing crucial features that those features should be paired with to recreate more perfect model performance.

Future work

This work raised more questions than it answered, and there’s a lot of work that we didn’t get around to in this blog post. We’d be keen to dig into the following questions, and we may or may not have the bandwidth to get around to this any time soon. If any of this sounds exciting to you, we’d be happy to collaborate!

  • How do SAEs change model performance on a well-documented circuit in gpt-2 like IOI?
  • We briefly looked into how an SAE intervention in an earlier layer affects downstream SAE reconstruction. We found in particular that the layer 9 SAE was negatively affected by any earlier intervention. Why is this? If we were to build a near-perfect layer 9 SAE, would this effect hold? If so, perhaps we could use feature activations in a later layer to learn more about what’s missing from early-layer SAEs.
    • In the case of residual-stream features like the ones we’re examining here, we can ask whether these downstream effects are caused directly by bad earlier SAE outputs persisting in the residual stream, or by those outputs disrupting the workings of subsequent attention and MLP layers.
  • We found a small family of features that helped model performance when they were active more often than not, and a slightly larger family of features which really helped model performance on some prompts. These outliers would be interesting to investigate in more detail. Joseph has a python function that lets you look at feature dashboards in your web browser if you just have the SAE and feature ID handy:
import webbrowser

def open_neuronpedia(feature_id, layer=0):
    # Open the Neuronpedia dashboard for a gpt2-small residual stream SAE feature.
    path_to_html = f"https://www.neuronpedia.org/gpt2-small/{layer}-res-jb/{feature_id}"
    print(f"Feature {feature_id}")
    webbrowser.open_new_tab(path_to_html)

open_neuronpedia(3407, layer=8)
  • In addition to studying the helpful features, it could be interesting to examine some of the worst SAE features – the ones that are always active when model performance is harmed, for example. Looking at individual features is interesting in itself, and ablating these features to see how performance changes could also be informative.
  • On the topic of ‘good’ and ‘bad’ features (in terms of how they affect model performance): do these features have different distributions in interpretability space? It would be very unfortunate if the most interpretable features were also the worst performing ones.
    • Here we only measured frequency of feature activation, but Anthropic examines both the frequency and the strength of feature activations, and all of this information should be used in this analysis. 
  • In order to more robustly establish causality (we only looked at correlation) between features and good or poor model performance, it would be a good experiment to counterfactually ablate features and rank them by how much they hurt or help downstream performance.

Code

The code used to produce the analysis and plots in this post is available at https://github.com/evanhanders/sae_validation_blogpost; see especially gpt2_small_autoencoder_benchmarks.ipynb.

Acknowledgments

We thank Adam Jermyn, Logan Smith, and Neel Nanda for comments on drafts of this post, which helped us improve its clarity, find bugs, and identify important results in our data. EA thanks Neel Nanda and Joseph Bloom for encouraging him to pull the research thread that led to this post. EA also thanks Stefan Heimersheim for productive conversations at EAG which helped lead to this post. Thanks to Ben Wright for suggesting we look at error propagation through layers. EA is also grateful to Adam Jermyn, Xianjung Yang, Jason Hoelscher-Obermaier, and Clement Neo for support, guidance, and mentorship during his upskilling.

Funding: EA is a KITP Postdoctoral Fellow, so this research was supported in part by grant NSF PHY-2309135 to the Kavli Institute for Theoretical Physics (KITP). JB is funded by Manifund Regrants and Private Donors, LightSpeed Grants and the Long Term Future Fund. 

Compute: Use was made of computational facilities purchased with funds from the National Science Foundation (CNS-1725797) and administered by the Center for Scientific Computing (CSC). The CSC is supported by the California NanoSystems Institute and the Materials Research Science and Engineering Center (MRSEC; NSF DMR 2308708) at UC Santa Barbara. Some computations were also conducted on the RCAC Anvil Supercomputer using NSF ACCESS grant PHY230163 and we are grateful to Purdue IT research support for keeping the machine running! 

Citing this post

@misc{anders_bloom_2024_gpt2saeacts,
   title = {Examining Language Model Performance with Reconstructed Activations using Sparse Autoencoders},
   author = {Anders, Evan and Bloom, Joseph},
   year = {2024},
   howpublished = {\url{https://www.lesswrong.com/posts/8QRH8wKcnKGhpAu2o/examining-language-model-performance-with-reconstructed}},
}

Appendix: Children’s Book Test (CBT, Common Noun [CN] split)

As a part of this project we also looked at an additional benchmark. We chose to punt this to an appendix because a lot of the findings here are somewhat redundant with our Open WebText exploration above, but we include it for completeness.

We examine 1000 entries from the Common Noun split of the Children’s Book Test. These entries contain a 20-sentence context and a one-sentence fill-in-the-blank prompt. This context is long compared to what the SAEs were trained on, so we expect them to perform fairly poorly here. An example of a CBT prompt is below (the model must predict the word that comes next):

Did n't you boast you were very sharp ? You undertook to guard our water ; now show us how much is left for us to drink ! ' ` It is all the fault of the jackal , ' replied the little hare , ` He told me he would give me something nice to eat if I would just let him tie my hands behind my back . ' Then the animals said , ` Who can we trust to mount guard now ? ' And the panther answered , ` Let it be the tortoise . ' The following morning the animals all went their various ways , leaving the tortoise to guard the spring . When they were out of sight the jackal came back . ` Good morning , tortoise ; good morning . ' But the tortoise took no notice . ` Good morning , tortoise ; good morning . ' But still the tortoise pretended not to hear . Then the jackal said to himself , ` Well , to-day I have only got to manage a bigger idiot than before . I shall just kick him on one side , and then go and have a drink . ' So he went up to the tortoise and said to him in a soft voice , ` Tortoise ! tortoise ! ' but the tortoise took no notice . Then the jackal kicked him out of the way , and went to the well and began to drink , but scarcely had he touched the water , than the tortoise seized him by the leg . The jackal shrieked out : ` Oh , you will break my leg ! ' but the tortoise only held on the tighter . The jackal then took his bag and tried to make the tortoise smell the honeycomb he had inside ; but the tortoise turned away his head and smelt nothing . At last the jackal said to the tortoise , ' I should like to give you my bag and everything in it , ' but the only answer the tortoise

Now let’s examine the same types of plots as we examined in the Lambada section.

Baseline gpt2-small accuracy on these 1000 questions is 41.9%. Below we plot accuracy and the average CE loss across the 1000 prompts.

Interestingly, we see the opposite trend to the one we saw for Lambada! For these long-context prompts, early layers manage about 0% accuracy! Later layers recover about ⅔ of the accuracy of baseline gpt2-small, but are still quite a bit worse than the baseline (and have significantly higher loss).

Below are scatterplot/histogram plots showing the distribution of datapoints in MSE, $\Delta\mathrm{CE}$, and log-prob space.

The outlier points on the left-hand side are interesting, and seem to be out-of-distribution activations (see scatterplots below). Otherwise, there again does not seem to be a clear story between $\Delta\mathrm{CE}$, CBT performance, and SAE MSE.

Next, we look at the layer effect:

The middle layers (2-5) are very high-variance and add most of the noise here, and they can totally ruin the model’s ability to get the right answer (or, in edge cases, make it much much more likely!). In the right-hand plot, there’s no strong prompt-level effect, but it seems like there might be something worth digging into further here. The gradient of light (at the bottom) to dark (at the top) seems more consistent than it did for the Lambada test above.

Residual stream activations follow a rather different trend than the Lambada ones did (and this involves no SAE intervention!):

The intermediate layers have a huge amount of variance in the residual stream magnitudes, and later layers can have smaller residual stream norms than layer 2. This is very counter-intuitive, and it must mean that early layers are outputting vectors into the residual stream which later layers are canceling out.

Next we look at scatter plots again:

There seem to be some trends here. In all but the bottom-left panel, each layer shows a ‘hook’ effect: low-norm activations follow one trend, then at some activation strength there is a hook and the trend changes. Strangely, the only layer where MSE increases monotonically with input activation strength is layer 1; all other layers reach a maximum MSE at a lower activation strength and then overturn (MSE decreasing with input activation norm).

Another interesting distribution here: layers ~2-5 have a cloud of data points with very large residual stream $L_2$ norms and very negative downstream $\Delta \log p(\mathrm{correct})$! So despite the interesting trends seen in MSE, replacing model activations with SAE reconstructions when the original residual stream has a large norm harms the model’s ability to correctly predict the answer token. It’s unclear if these high-norm prompts are somehow out-of-distribution compared to the text corpus that gpt2-small and the SAEs were trained on.

Finally, we look at feature histograms:


While there are again a few features which lead to large (positive) $\Delta \log p(\mathrm{correct})$, and thus a few features to the right of the 0 line in the top panel, there are zero features which are, on average, active only when the model’s probability of guessing correctly is improved. Interestingly, there are a few features in the layer 0 encoder which are almost always active, and the model never sees an improvement when they’re active.

Comments

Thanks for doing this work -- I'm really happy people are doing the basic stress testing of SAEs, and I agree that this is important and urgent given the sheer amount of resources being invested into SAE research. 

For me, this was actually a positive update that SAEs are pretty good on distribution -- you trained SAE on length 128 sequences from OpenWebText, and the log loss was quite low up to ~200 tokens! This is despite its poor downstream use case performance.

I expected to see more negative results along the lines of your Lambada and Children's Book test results (that is, substantial degradation of loss, as soon as you go a tiny bit off distribution): 

I do think these results add on to the growing pile of evidence that SAEs are not good "off distribution" (even a small amount off distribution, as in Sam Marks's results you link). This means they're somewhat problematic for OOD use cases like treacherous turn detection or detecting misgeneralization. That doesn't mean they're useless -- e.g. it's plausible that SAEs could be useful for steering, mechanistic anomaly detection, or helping us do case analysis for heuristic arguments or proofs. 


As an aside, am I reading this plot incorrectly, or does the figure on the right suggest that SAE reconstructed representations have lower log loss than the original unmodified model?

For me, this was actually a positive update that SAEs are pretty good on distribution -- you trained SAE on length 128 sequences from OpenWebText, and the log loss was quite low up to ~200 tokens! This is despite its poor downstream use case performance.


Yes, this was nice to see. I originally just looked at context positions at powers of 2 (...64, 128, 256,...) and there everything looked terrible above 128, but Logan recommended looking at all context positions and this was a cool result! 

But note that there's a layer effect here. I think layer 12 is good up to ~200 tokens while layer 0 is only really good up to the training context size. I think this is most clear in the MSE/L1 plots (and this is consistent with later layers performing ok-ish on the long context CBT while early layers are poor).

This means they're somewhat problematic for OOD use cases like treacherous turn detection or detecting misgeneralization.

Yeah, agreed, which is a bummer because that's one thing I'd really like to see SAEs enable! Wonder if there's a way to change the training of SAEs to shift this over to on-distribution where they perform well.

As an aside, am I reading this plot incorrectly, or does the figure on the right suggest that SAE reconstructed representations have lower log loss than the original unmodified model?

Oof, that figure does indeed suggest that, but it's because of a bug in the plot script. Thank you for pointing that out, here's a fixed version: 

I've fixed the repo and I'll edit the original post shortly.

This means they're somewhat problematic for OOD use cases like treacherous turn detection or detecting misgeneralization.


I kinda want to push back on this since OOD in behavior is not obviously OOD in the activations. Misgeneralization especially might be better thought of as an OOD environment and on-distribution activations? 

I think we should come back to this question when SAEs have tackled something like variable binding. Right now it's hard to say how SAEs are going to help us understand more abstract thinking and therefore I think it's hard to say how problematic they're going to be for detecting things like a treacherous turn. I think this will depend on how representations factor. In the ideal world, they generalize with the model's ability to generalize (apologies for how high level / vague that idea is).

Some experiments I'd be excited to look at:

  • If the SAE is trained on a subset of the training distribution, can we distinguish it being used to decompose activations on those data points off the training distribution?
  • How does that compare to an SAE trained on the whole training distribution from the model, but then looking at when the model is being pushed off distribution? 

I think I'm trying to get at - can we distinguish:

  • Anomalous activations. 
  • Anomalous data points. 
  • Anomalous mechanisms. 

Lots of great work to look forward to!

problems


prompts*

@Evan Anders "For each feature, we find all of the problems where that feature is active, and we take the two measurements of “feature goodness" <- typo? 

Ah! That's the context, thanks for the clarification and for pointing out the error.  Yes "problems" should say "prompts"; I'll edit the original post shortly to reflect that. 

Has anyone tried training an SAE using the performance of the patched model as the loss function? I guess this would be a lot more expensive, but given that is the metric we actually care about, it seems sensible to optimise for it directly.

This is a good idea and is something we're (Apollo + MATS stream) working on atm.  We're planning on releasing our agenda related to this and, of course, results whenever they're ready to share.

I've heard this idea floated a few times and am a little worried that "When a measure becomes a target, it ceases to be a good measure" will apply here. OTOH, you can directly check whether the MSE / variance explained diverges significantly so at least you can track the resulting SAE's use for decomposition. I'd be pretty surprised if an SAE trained with this objective became vastly more performant and you could check whether downstream activations of the reconstructed activations were off distribution. So overall, I'm pretty excited to see what you get!

Am I right in thinking that your ΔCE metric is equivalent to the KL Divergence between the SAE patched model and the normal model?

After seeing this comment, if I were to re-write this post, maybe it would have been better to use the KL Divergence over the simple CE metric that I used. I think they're subtly different.

Per the TL implementation for CE, I'm calculating: $\mathrm{CE} = -\frac{1}{|B|\,|N|}\sum_{b \in B}\sum_{n \in N} \log p_{b,n}$, where $b$ runs over the batch dimension, $n$ over context positions, and $p_{b,n}$ is the probability the model assigns to the actual next token.

So $\Delta\mathrm{CE} = \frac{1}{|B|\,|N|}\sum_{b \in B}\sum_{n \in N} \log\left(p_{b,n} / \hat{p}_{b,n}\right)$ for $p_{b,n}$ the baseline probability and $\hat{p}_{b,n}$ the patched probability.

So this is missing a factor of $p_{b,n}$ to be the true KL divergence.

I think it is the same. When training next-token predictors we model the ground truth probability distribution as having probability $1$ for the actual next token and $0$ for all other tokens in the vocab. This is how the cross-entropy loss simplifies to negative log likelihood. You can see that the transformer lens implementation doesn't match the equation for cross entropy loss because it is using this simplification.

So the missing factor of $p_{b,n}$ would just be $1$, I think.

Oh! You're right, thanks for walking me through that, I hadn't appreciated that subtlety. Then in response to the first question: yep! $\Delta\mathrm{CE}$ = KL divergence.

Could the problem with long context lengths simply be that the positional embedding cannot be represented by the SAE? You could test this by manually adding the positional embedding vector when you patch in the SAE reconstruction.

I think this is most of what the layer 0 SAE gets wrong. The layer 0 SAE just reconstructs the activations after embedding (positional + token), so the only real explanation I see for what it's getting wrong is the positional embedding.

But I'm less convinced that this explains later layer SAEs. If you look at e.g., this figure:

then you see that the layer 0 model activations are an order of magnitude smaller than any later-layer activations, so the positional embedding itself is only making up a really small part of the signal going into the SAE for any layer > 0 (so I'm skeptical that it's accounting for a large fraction of the large MSE that shows up there).

Regardless, this seems like a really valuable test! It would be fun to see what happens if you just feed the token embedding into the SAE and then add in the positional embedding after reconstructing the token embedding. I'd naively assume that this would go poorly -- if the SAE for layer 0 learns concepts more complex than just individual token embeddings, I think that would have to be the result of mixing positional and token embeddings?

My mental model is the encoder is working hard to find particular features and distinguish them from others (so it's doing a compressed sensing task) and that out of context it's off distribution and therefore doesn't distinguish noise properly. Positional features are likely a part of that but I'd be surprised if it was most of it.