Very exciting that JumpReLU works well with STE gradient estimation! I think this fixes one of the biggest flaws with TopK, which is that having a fixed number of latents k on each token is kind of wonky. I also like the argument in section 4 a lot - in particular the point about how this works because we're optimizing the expectation of the loss. Because of how sparse the features are, I wonder if it would reduce gradient noise substantially to use a KDE with state persisting across a few recent steps.
Thanks! I was worried about how well KDE would work because of sparsity, but in practice it seems to work quite well. My intuition is that the penalty only matters for more frequent features (i.e. features that typically do fire, perhaps multiple times, within a 4096 batch), plus the threshold is only updated a small amount at each step (so the noise averages out reliably over many steps). In any case, we tried increasing batch size and turning on momentum (with the idea this would reduce noise), but neither seemed to particularly help performance (keeping the number of tokens constant). One thing I suspect would further improve things is to have a per-feature bandwidth parameter, as I imagine the constant bandwidth we use is too narrow for some features and too wide for others; but to do this we also need to find a reliable way to adapt the bandwidth during training.
Did y'all do any ablations on your loss terms? For example:
1. JumpReLU() -> ReLU
2. L0 (w/ STE) -> L1
I'd be curious to see if the Pareto improvements and high-frequency features are due to one, the other, or both.
You can't do JumpReLU -> ReLU in the current setup, as there would be no threshold parameters to train with the STE (i.e. it would basically train a ReLU SAE with no sparsity penalty). In principle you should be able to train the other SAE parameters by adding more pseudo derivatives, but we found this didn't work as well as just training the threshold (so there was then no point in trying this ablation). L0 -> L1 leads to worse Pareto curves (I can't remember off the top of my head which side of Gated, but definitely not significantly better than Gated) - it's a good question whether this resolves the high frequency features; my guess is it would (as I think Gated SAEs basically approximate JumpReLU SAEs with an L1 loss) but we didn't check this.
Re the choice of kernel, my intuition would have been that something smoother (e.g. approximating a Gaussian, or perhaps Epanechnikov) would have given the best results. Did you use rect just because it's very cheap or was there a theoretical reason?
Good question! The short answer is that we tried the simplest thing and it happened to work well. (The additional computational cost of higher order kernels is negligible.) We did look at other kernels but my intuition, borne out by the results (to be included in a v2 shortly, along with fixes for some typos and bugs in the pseudo-code), was that this would not make much difference. This is because we're using KDE in a very high variance regime already, and yet the SAEs seem to train fine: given a batch size of 4096 and bandwidth of 0.001, hardly any activations (even for the most frequent features) end up having kernels that capture the threshold (i.e. that are included in the gradient estimate). So it's not obvious that improving the bias-variance trade-off slightly by switching to a different kernel is going to make that much difference. In this sense, the way we use KDE here is very different from e.g. using KDE to visualise empirical distributions, and so we need to be careful about how we transfer intuitions between these domains.
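To make the "hardly any activations capture the threshold" point concrete, here is a rough sketch (not the paper's code) of a rectangle-kernel pseudo-derivative of the L0 term with respect to one feature's threshold. It assumes the pseudo-derivative takes the form -(1/bandwidth) * K((z - theta) / bandwidth) with K equal to 1 on a unit-width window centred at zero; the function names and the synthetic uniform batch are made up for illustration, and real pre-activations will be distributed quite differently.

```python
import random

def rect_kernel(x):
    """Rectangle kernel: 1 inside a unit-width window centred at 0, else 0."""
    return 1.0 if abs(x) <= 0.5 else 0.0

def l0_threshold_pseudo_grad(pre_acts, theta, bandwidth):
    """Batch-mean pseudo-derivative of the L0 term w.r.t. the threshold.

    Assumed form: d/dtheta H(z - theta) ~= -(1/bandwidth) * K((z - theta) / bandwidth).
    Also returns how many pre-activations fell inside the kernel window,
    i.e. how many samples contribute anything to the estimate.
    """
    total = 0.0
    in_window = 0
    for z in pre_acts:
        k = rect_kernel((z - theta) / bandwidth)
        total += -k / bandwidth
        in_window += int(k)
    return total / len(pre_acts), in_window

# Synthetic batch (an assumption, purely illustrative): 4096 pre-activations
# of a single feature, drawn uniformly from [0, 2].
random.seed(0)
batch = [random.uniform(0.0, 2.0) for _ in range(4096)]
grad, n_contributing = l0_threshold_pseudo_grad(batch, theta=1.0, bandwidth=0.001)
print(f"{n_contributing} of {len(batch)} samples contribute; pseudo-grad = {grad:.3f}")
```

With a window this narrow, only a handful of samples per step are non-zero in the sum, so each single-step estimate is very noisy; swapping in a smoother kernel changes the weights on those few samples but not how few of them there are.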
Nice work! I was wondering what context length you were using when you extracted the LLM activations to train the SAE. I could not find it in the paper but I might also have missed it. I know that OpenAI used a context length of 64 tokens in all their experiments which is probably not sufficient to elicit many interesting features. Do you use a variable context length or also a fixed value?
New paper from the Google DeepMind mechanistic interpretability team, led by Sen Rajamanoharan!
We introduce JumpReLU SAEs, a new SAE architecture that replaces the standard ReLUs with discontinuous JumpReLU activations, and seems to be (narrowly) state of the art over existing methods like TopK and Gated SAEs for achieving high reconstruction at a given sparsity level, without a hit to interpretability. We train through the discontinuity with straight-through estimators, which also let us directly optimise the L0.
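For readers who have not seen the activation before, a minimal sketch of the forward pass follows. It assumes the usual definition, JumpReLU(z) = z if z exceeds a learned per-feature threshold theta and 0 otherwise; the function names are illustrative rather than taken from the paper, and the strict inequality at z == theta is an arbitrary choice here.

```python
def relu(z):
    """Standard ReLU, shown for contrast."""
    return z if z > 0.0 else 0.0

def jumprelu(z, theta):
    """JumpReLU: pass z through unchanged above the threshold, output 0 below it.

    Unlike ReLU, the output jumps from 0 to theta at the threshold, so the
    function is discontinuous and its derivative w.r.t. theta is zero almost
    everywhere. That is why a straight-through estimator is needed to train theta.
    """
    return z if z > theta else 0.0

def l0(pre_acts, thetas):
    """Number of active latents on one token: the quantity the sparsity penalty counts."""
    return sum(1 for z, t in zip(pre_acts, thetas) if z > t)

print(relu(0.5), jumprelu(0.5, theta=1.0), jumprelu(1.5, theta=1.0))
```

Small positive pre-activations that a ReLU would let through are zeroed out, while activations above the threshold are passed through without shrinkage.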
To accompany this, we will release the weights of hundreds of JumpReLU SAEs on every layer and sublayer of Gemma 2 2B and 9B in a few weeks. Apply now for early access to the 9B ones! We're keen to get feedback from the community, and to get these into the hands of researchers as fast as possible. There are a lot of great projects that we hope will be much easier with open SAEs on capable models!
Gated SAEs already reduced to JumpReLU activations after weight tying, so this can be thought of as Gated SAEs++, but less computationally intensive to train, and better performing. They should be runnable in existing Gated implementations.
Abstract: