Sometimes FLOP/s isn't the bottleneck for training models; e.g. it could be memory bandwidth. My impression from poking around with Nsight and some other observations is that wide SAEs might actually be FLOP/s bottlenecked, but I don't trust my impression that much. I'd be interested in someone comparing these SAE architectures in terms of H100-seconds or something like that, in addition to FLOPs.
Did it seem to you like this architecture also trained faster in terms of wall-time?
Anyway, nice work! It's cool to see these results.
Thanks for the comment -- I trained TopK SAEs with various widths (all fitting within a single GPU) and observed wider SAEs take substantially longer to train, which leads me to believe that the encoder forward pass is a major bottleneck for wall-clock time. The Switch SAE also improves memory efficiency because we do not need to store all latents.
I'm currently working on implementing expert-parallelism, which I hope will lead to substantial improvements to wall-clock time.
Great work! Very excited to see work in this direction (In fact, I didn't know you were working on this, so I'd expressed enthusiasm for MoE SAEs in our recent list of project ideas published just a few days ago!)
Comments:
Following Fedus et al., we route to a single expert SAE. It is possible that selecting several experts will improve performance. The computational cost will scale with the number of experts chosen.
If there are some very common features in particular layers (e.g. an 'attend to BOS' feature), then restricting one expert to be active at a time will potentially force SAEs to learn common features in every expert.
+1 to similar concerns -- I would have probably left one expert always on. This should remove some redundant features.
Hi Lee and Arthur, thanks for the feedback! I agree that routing to a single expert will force redundant features and will experiment with Arthur's suggestion. I haven't taken a close look at the router/expert geometry yet but plan to do so soon.
Hi Lee, if I may ask, when you say "geometric analysis" of the router, do you mean analysis of the parameters or the activations? Are there any papers that perform the sort of analysis you'd like to see done? Asking from the perspective of someone who understands NNs thoroughly but is new to mechinterp.
Both of these seem like interesting directions (I had parameters in mind, but params and activations are too closely linked to ignore one or the other). And I don't have a super clear idea but something like representational similarity analysis between SwitchSAEs and regular SAEs could be interesting. This is just one possibility of many though. I haven't thought about it for long enough to be able to list many more, but it feels like a direction with low hanging fruit for sure. For papers, here's a good place to start for RSA: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3730178/
For a batch $B$ with $T$ activations, we first compute vectors $f \in \mathbb{R}^N$ and $P \in \mathbb{R}^N$. $f$ represents what proportion of activations are sent to each expert
Hi, I'm not exactly sure where f fits in here. In Figure 1/section 2.2, it seems like x is fed into the router layer, which produces a distribution over the N experts, from which the "best expert" is chosen. I'm not sure where the "proportion of activations" is in that process. To me that sounds like it's describing something that would be multiplied by x before it's fed into an expert, but I don't see that reflected in the diagram or described in section 2.2.
Nice work, these seem like interesting and useful results!
High-level question/comment which might be totally off: one benefit of having a single, large SAE neuron space that each token gets projected into is that features don't get in each other's way, except insofar as you're imposing sparsity. Like, your "I'm inside a parenthetical" and your "I'm attempting a coup" features will both activate in the SAE hidden layer, as long as they're in the top k features (for some sparsity). But introducing Switch SAEs breaks that: if these two features are in different experts, only one of them will activate in the SAE hidden layer (based on whatever your gating learned).
The obvious reply is "but look at the empirical results, you fool! The Switch SAEs are pretty good!" And that's fair. I weakly expect what is happening in your experiment is that similar but slightly specialized features are being learned by each expert (a testable hypothesis), and maybe you get enough of this redundancy that it's fine, e.g., the expert with "I'm inside a parenthetical" also has a "words relevant to coups" feature and this is enough signal for coup detection in that expert.
Again, maybe this worry is totally off or I'm misunderstanding something.
Thanks for your comment! I believe your concern was echoed by Lee and Arthur in their comments and is completely valid. This work is primarily a proof-of-concept that we can successfully scale SAEs by directly applying MoE, but I suspect that we will need to make tweaks to the architecture.
Can I ask what you used to implement the MoE routing? Did you use megablocks? I would love to expand on this research, but I can't find any straightforward implementation of efficient PyTorch MoE routing online.
Do you simply iterate over each max probability expert every time you feed in a batch?
wait a minute... could you just...
you don't just literally do this do you?
import torch

input = torch.tensor([
    [1., 2.],
    [1., 2.],
    [1., 2.],
])  # (bs, input_dim)
enc_expert_1 = torch.tensor([
    [1., 1., 1., 1.],
    [1., 1., 1., 1.],
])  # (input_dim, latent_dim)
enc_expert_2 = torch.tensor([
    [3., 3., 0., 0.],
    [0., 0., 2., 0.],
])
dec_expert_1 = torch.tensor([
    [-1., -1.],
    [-1., -1.],
    [-1., -1.],
    [-1., -1.],
])  # (latent_dim, input_dim)
dec_expert_2 = torch.tensor([
    [-10., -10.],
    [-10., -10.],
    [-10., -10.],
    [-10., -10.],
])

def moe(input, enc, dec, nonlinearity):
    # one (enc, dec) pair has been stacked per batch element, so bmm pushes each
    # input through "its" expert
    input = input.unsqueeze(1)                    # (bs, 1, input_dim)
    latent = torch.bmm(input, enc)                # (bs, 1, latent_dim)
    recon = torch.bmm(nonlinearity(latent), dec)  # (bs, 1, input_dim)
    return recon.squeeze(1), latent.squeeze(1)

# not this! some kind of actual routing algorithm, but you end up with something similar
enc = torch.stack([enc_expert_1, enc_expert_2, enc_expert_1])
dec = torch.stack([dec_expert_1, dec_expert_2, dec_expert_1])
nonlinearity = torch.nn.ReLU()
recons, latent = moe(input, enc, dec, nonlinearity)
This must in some way be horrifically inefficient, right?
Just to close the loop on this one: the official huggingface transformers library just uses a for-loop to implement MoE routing. I also implemented a version myself using a for loop, and for large latent and batch sizes it's much more efficient than either vanilla matrix multiplication or that weird batch matmul I wrote up there.
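For anyone landing here later, a minimal sketch of the for-loop approach (my own illustration, not the transformers implementation): route each input to its argmax expert, then loop over experts and process each expert's inputs in a single dense matmul.

import torch

def forloop_moe(x, router_logits, enc_experts, dec_experts, nonlinearity):
    """x: (bs, d); router_logits: (bs, N); enc_experts[i]: (d, m); dec_experts[i]: (m, d)."""
    assignments = router_logits.argmax(dim=-1)   # best expert per input
    recon = torch.zeros_like(x)
    for i, (enc, dec) in enumerate(zip(enc_experts, dec_experts)):
        mask = assignments == i
        if mask.any():
            # one dense matmul per expert, over only the inputs routed to it
            recon[mask] = nonlinearity(x[mask] @ enc) @ dec
    return recon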
Produced as part of the ML Alignment & Theory Scholars Program - Summer 2024 Cohort
0. Summary
To recover all the relevant features from a superintelligent language model, we will likely need to scale sparse autoencoders (SAEs) to billions of features. Using current architectures, training extremely wide SAEs across multiple layers and sublayers at various sparsity levels is computationally intractable. Conditional computation has been used to scale transformers (Fedus et al.) to trillions of parameters while retaining computational efficiency. We introduce the Switch SAE, a novel architecture that leverages conditional computation to efficiently scale SAEs to many more features.
1. Introduction
The internal computations of large language models are inscrutable to humans. We can observe the inputs and the outputs, as well as every intermediate step in between, and yet we have little to no sense of what the model is actually doing. For example, is the model inserting security vulnerabilities or backdoors into the code that it writes? Is the model lying, deceiving or seeking power? Deploying a superintelligent model into the real world without being aware of when these dangerous capabilities may arise leaves humanity vulnerable. Mechanistic interpretability (Olah et al.) aims to open the black box of neural networks and rigorously explain the underlying computations. Early attempts to identify the behavior of individual neurons were thwarted by polysemanticity, the phenomenon in which a single neuron is activated by several unrelated features (Olah et al.). Language models must pack a vast amount of information (e.g., the entire internet) into a limited capacity, encouraging the model to rely on superposition to represent many more features than there are dimensions in the model state (Elhage et al.).
Sharkey et al. and Cunningham et al. propose to disentangle superimposed model representations into monosemantic, cleanly interpretable features by training unsupervised sparse autoencoders (SAEs) on intermediate language model activations. Recent work (Templeton et al., Gao et al.) has focused on scaling sparse autoencoders to frontier language models such as Claude 3 Sonnet and GPT-4. Despite scaling SAEs to 34 million features, Templeton et al. estimate that they are likely orders of magnitude short of capturing all features. Furthermore, Gao et al. train SAEs on a series of language models and find that larger models require more features to achieve the same reconstruction error. Thus, to capture all relevant features of future large, superintelligent models, we will likely need to scale SAEs to several billions of features. With current methodologies, training SAEs with billions of features at various layers, sublayers and sparsity levels is computationally infeasible.
Training a sparse autoencoder generally consists of six major computations: the encoder forward pass, the encoder gradient, the decoder forward pass, the decoder gradient, the latent gradient and the pre-bias gradient. Gao et al. introduce kernels and tricks that leverage the sparsity of the TopK activation function to dramatically optimize all computations excluding the encoder forward pass, which is not (yet) sparse. After implementing these optimizations, Gao et al. attribute the majority of the compute to the dense encoder forward pass and the majority of the memory to the latent pre-activations. No work has attempted to accelerate or improve the memory efficiency of the encoder forward pass, which remains the sole dense matrix multiplication.
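To illustrate why the decoder (but not the encoder) can exploit sparsity: after TopK, only k of the M latents are nonzero, so the decoder only needs to gather and combine k columns, whereas the encoder must multiply against all M rows to find out which k survive. A sketch, with illustrative names rather than the kernels from Gao et al.:

import torch

def sparse_decode(top_vals, top_idx, W_dec, b_pre):
    """top_vals/top_idx: (batch, k) TopK latent values and indices; W_dec: (d, M); b_pre: (d,)."""
    cols = W_dec.T[top_idx]                                    # gather only the k active decoder columns: (batch, k, d)
    return (top_vals.unsqueeze(-1) * cols).sum(dim=1) + b_pre  # weighted sum of k feature vectors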
In a standard deep learning model, every parameter is used for every input. An alternative approach is conditional computation, where only a small subset of the parameters are active depending on the input. This allows us to scale model capacity and parameter count without suffering from commensurate increases in computational cost. Shazeer et al. introduce the Sparsely-Gated Mixture-of-Experts (MoE) layer, the first general purpose architecture to realize the potential of conditional computation at huge scales. The Mixture-of-Experts layer consists of (1) a set of expert networks and (2) a routing network that determines which experts should be active on a given input. The entire model is trained end-to-end, simultaneously updating the routing network and the expert networks. The underlying intuition is that each expert network will learn to specialize and perform a specific task, boosting the overall model capacity. Shazeer et al. successfully use MoE to scale LSTMs to 137 billion parameters, surpassing the performance of previous dense models on language modeling and machine translation benchmarks.
Shazeer et al. restrict their attention to settings in which the input is routed to several experts. Fedus et al. introduce the Switch layer, a simplification to the MoE layer which routes to just a single expert. This simplification reduces communication costs and boosts training stability. By replacing the MLP layer of a transformer with a Switch layer, Fedus et al. scale transformers to over a trillion parameters.
In this work, we introduce the Switch Sparse Autoencoder, which combines the Switch layer (Fedus et al.) with the TopK SAE (Gao et al.). The Switch SAE is composed of many smaller expert SAEs as well as a trainable routing network that determines which expert SAE will process a given input. We demonstrate that the Switch SAE is a Pareto improvement over existing architectures while holding training compute fixed. We additionally show that Switch SAEs are significantly more sample-efficient than existing architectures.
2. Methods
2.1 Baseline Sparse Autoencoder
Let $d$ be the dimension of the language model activations. The linear representation hypothesis states that each feature is represented by a unit vector $f_i \in \mathbb{R}^d$. Under the superposition hypothesis, there exists a dictionary of $M \gg d$ features $(f_1, f_2, \ldots, f_M)$ represented as almost orthogonal unit vectors in $\mathbb{R}^d$. A given activation $x$ can be written as a sparse, weighted sum of these feature vectors. Let $w$ be a sparse vector in $\mathbb{R}^M$ representing how strongly each feature is activated. Then, we have:

$$x = x_0 + \sum_{i=1}^{M} w_i f_i.$$

A sparse autoencoder learns to detect the presence and strength of the features $f_i$ given an input activation $x$. SAE architectures generally share three main components: a pre-bias $b_{\text{pre}} \in \mathbb{R}^d$, an encoder matrix $W_{\text{enc}} \in \mathbb{R}^{M \times d}$ and a decoder matrix $W_{\text{dec}} \in \mathbb{R}^{d \times M}$. The TopK SAE defined by Gao et al. takes the following form:

$$z = \text{TopK}(W_{\text{enc}}(x - b_{\text{pre}})), \qquad \hat{x} = W_{\text{dec}} z + b_{\text{pre}}.$$

The latent vector $z \in \mathbb{R}^M$ represents how strongly each feature is activated. Since $z$ is sparse, the decoder forward pass can be optimized by a suitable kernel. The bias term $b_{\text{pre}}$ is designed to model $x_0$, so that $x - b_{\text{pre}} = \sum_{i=1}^{M} w_i f_i$. Note that $W_{\text{enc}}$ and $W_{\text{dec}}$ are not necessarily transposes of each other. Row $i$ of the encoder matrix learns to detect feature $i$ while simultaneously minimizing interference with the other almost orthogonal features. Column $i$ of the decoder matrix corresponds to $f_i$. Altogether, the SAE consists of $2Md + d$ parameters.
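For concreteness, a minimal sketch of this forward pass in PyTorch (names are illustrative, not the authors' code):

import torch

def topk_sae_forward(x, W_enc, W_dec, b_pre, k):
    """x: (batch, d); W_enc: (M, d); W_dec: (d, M); b_pre: (d,)."""
    pre_acts = (x - b_pre) @ W_enc.T                         # dense encoder forward pass: (batch, M)
    vals, idx = pre_acts.topk(k, dim=-1)                     # keep only the k largest pre-activations
    z = torch.zeros_like(pre_acts).scatter_(-1, idx, vals)   # sparse latent vector
    x_hat = z @ W_dec.T + b_pre                              # decoder forward pass (sparse in principle)
    return x_hat, z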
We additionally benchmark against the ReLU SAE (Conerly et al.) and the Gated SAE (Rajamanoharan et al.). The ReLU SAE applies an L1 penalty to the latent activations to encourage sparsity. The Gated SAE separately determines which features should be active and how strongly activated they should be to avoid activation shrinkage (Wright and Sharkey).
2.2 Switch Sparse Autoencoder Architecture
The Switch Sparse Autoencoder avoids the dense $W_{\text{enc}}$ matrix multiplication. Instead of being one large sparse autoencoder, the Switch Sparse Autoencoder is composed of $N$ smaller expert SAEs $\{E_i\}_{i=1}^{N}$. Each expert SAE $E_i$ resembles a TopK SAE with no bias term:

$$E_i(x) = W_{\text{dec}}^i \, \text{TopK}(W_{\text{enc}}^i x).$$

Each expert SAE $E_i$ is $N$ times smaller than the original SAE. Specifically, $W_{\text{enc}}^i \in \mathbb{R}^{\frac{M}{N} \times d}$ and $W_{\text{dec}}^i \in \mathbb{R}^{d \times \frac{M}{N}}$. Across all $N$ experts, the Switch SAE represents $M$ features.

The Switch layer takes in an input activation $x$ and routes it to the best expert. To determine the expert, we first subtract a bias $b_{\text{router}} \in \mathbb{R}^d$. Then, we multiply by $W_{\text{router}} \in \mathbb{R}^{N \times d}$, which produces logits that we normalize via a softmax. Let $\sigma$ denote the softmax function. The probability distribution over the experts $p \in \mathbb{R}^N$ is given by:

$$p = \sigma(W_{\text{router}}(x - b_{\text{router}})).$$

We route the input to the expert with the highest probability and weight the output by that probability to allow gradients to propagate. We subtract a bias before passing $x$ to the selected expert and add it back after weighting by the corresponding probability:

$$i^* = \arg\max_i p_i, \qquad \hat{x} = p_{i^*} \cdot E_{i^*}(x - b_{\text{pre}}) + b_{\text{pre}}.$$

In total, the Switch Sparse Autoencoder contains $2Md + Nd + 2d$ parameters, whereas the TopK SAE has $2Md + d$ parameters. The additional $Nd + d$ parameters we introduce through the router are an insignificant proportion of the total parameters because $M \gg N$.
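A minimal sketch of this forward pass for a single activation vector (hypothetical names; the stacked W_encs/W_decs tensors and this function are illustrative, not the authors' implementation):

import torch

def switch_sae_forward(x, W_router, b_router, b_pre, W_encs, W_decs, k):
    """x: (d,); W_router: (N, d); W_encs: (N, M/N, d); W_decs: (N, d, M/N)."""
    p = torch.softmax(W_router @ (x - b_router), dim=-1)     # router probabilities over N experts
    i_star = int(p.argmax())                                 # route to the single best expert
    pre_acts = W_encs[i_star] @ (x - b_pre)                  # only one expert's encoder runs
    vals, idx = pre_acts.topk(k)
    z = torch.zeros_like(pre_acts).scatter_(-1, idx, vals)
    x_hat = p[i_star] * (W_decs[i_star] @ z) + b_pre         # weight by p_{i*} so the router receives gradients
    return x_hat, p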
During the forward pass of a TopK SAE, $Md$ parameters are used during the encoder forward pass, $kd$ parameters are used during the decoder forward pass and $d$ parameters are used for the bias, for a total of $Md + kd + d$ parameters used. Since $M \gg k$, the number of parameters used is dominated by $Md$. During the forward pass of a Switch SAE, $Nd$ parameters are used for the router, $\frac{M}{N}d$ parameters are used during the encoder forward pass, $kd$ parameters are used during the decoder forward pass and $2d$ parameters are used for the biases, for a total of $\frac{M}{N}d + kd + Nd + 2d$ parameters used. Since the encoder forward pass takes up the majority of the compute, we effectively reduce the compute by a factor of $N$. This approximation becomes better as we scale $M$, which will be required to capture all the safety-relevant features of future superintelligent language models. Furthermore, the TopK SAE must compute and store $M$ pre-activations. Due to the sparse router, the Switch SAE only needs to store $\frac{M}{N}$ pre-activations, improving memory efficiency by a factor of $N$ as well.
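As a concrete illustration, using the GPT-2 small width $d = 768$ and SAE width $M = 24576$ from the experiments below, and an example value of $N = 32$ experts:

$$Md = 24576 \cdot 768 \approx 18.9 \text{ million}, \qquad \frac{M}{N}d + Nd = 768 \cdot 768 + 32 \cdot 768 \approx 0.61 \text{ million},$$

so the dominant encoder cost drops by roughly $31\times$; the $kd$ and bias terms are small in comparison.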
2.3 Switch Sparse Autoencoder Training
We train the Switch Sparse Autoencoder end-to-end. Weighting $E_{i^*}(x - b_{\text{pre}})$ by $p_{i^*}$ in the calculation of $\hat{x}$ allows the router to be differentiable. We adopt many of the training strategies described in Bricken et al. and Gao et al., with a few exceptions. We initialize the rows (features) of $W_{\text{enc}}^i$ to be parallel to the columns (features) of $W_{\text{dec}}^i$ for all $i$. We initialize both $b_{\text{pre}}$ and $b_{\text{router}}$ to the geometric median of a batch of samples (but we do not tie $b_{\text{pre}}$ and $b_{\text{router}}$). We additionally normalize the decoder column vectors to unit norm at initialization and after each gradient step. We remove gradient information parallel to the decoder feature directions. We set the learning rate based on the $\frac{1}{\sqrt{M}}$ scaling law from Gao et al. and linearly decay the learning rate over the last 20% of training. We do not include neuron resampling (Bricken et al.), ghost grads (Jermyn et al.) or the AuxK loss (Gao et al.).
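A minimal sketch of the decoder constraints described above, assuming each expert's decoder is stored as a (d, m) tensor (this is my own sketch, not the authors' implementation):

import torch

@torch.no_grad()
def constrain_decoder(W_dec):
    """Renormalize decoder columns to unit norm (applied at init and after each gradient step)."""
    W_dec /= W_dec.norm(dim=0, keepdim=True)

@torch.no_grad()
def remove_parallel_grad(W_dec):
    """Remove the gradient component parallel to each decoder feature direction."""
    cols = W_dec / W_dec.norm(dim=0, keepdim=True)            # unit feature directions: (d, m)
    parallel = (W_dec.grad * cols).sum(dim=0, keepdim=True)   # projection coefficients: (1, m)
    W_dec.grad -= parallel * cols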
The ReLU SAE loss consists of a weighted combination of the reconstruction MSE and an L1 penalty on the latents to encourage sparsity. The TopK SAE directly enforces sparsity via its activation function and thus directly optimizes the reconstruction MSE. Following Fedus et al., we train our Switch SAEs using a weighted combination of the reconstruction MSE and an auxiliary loss which encourages the router to send an equal number of activations to each expert to reduce overhead. Empirically, we also find that the auxiliary loss improves reconstruction fidelity.
For a batch $B$ with $T$ activations, we first compute vectors $f \in \mathbb{R}^N$ and $P \in \mathbb{R}^N$. $f$ represents what proportion of activations are sent to each expert, while $P$ represents what proportion of router probability is assigned to each expert. Formally,

$$f_i = \frac{1}{T} \sum_{x \in B} \mathbf{1}\{i^*(x) = i\}, \qquad P_i = \frac{1}{T} \sum_{x \in B} p_i(x).$$

The auxiliary loss $L_{\text{aux}}$ is then defined to be:

$$L_{\text{aux}} = N \cdot \sum_{i=1}^{N} f_i \cdot P_i.$$

The auxiliary loss achieves its minimum when the expert distribution is uniform. We scale by $N$ so that $L_{\text{aux}} = 1$ for a uniformly random router. The inclusion of $P$ allows the loss to be differentiable.
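A sketch of how this auxiliary loss might be computed from a batch of router probabilities (shapes and names are my own, not the authors' code):

import torch

def load_balancing_loss(p):
    """p: (T, N) router probabilities for a batch of T activations over N experts."""
    T, N = p.shape
    assignments = p.argmax(dim=-1)                            # expert chosen for each activation
    f = torch.bincount(assignments, minlength=N).float() / T  # fraction of activations routed to each expert
    P = p.mean(dim=0)                                         # mean router probability per expert
    return N * (f * P).sum()                                  # equals 1 for a uniformly random router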
The reconstruction loss $L_{\text{recon}}$ is defined to be:

$$L_{\text{recon}} = \frac{1}{T} \sum_{x \in B} \|x - \hat{x}\|_2^2.$$

Note that $L_{\text{recon}} \propto d$. Let $\alpha$ represent a tunable load balancing hyperparameter. The total loss $L_{\text{total}}$ is then defined to be:

$$L_{\text{total}} = L_{\text{recon}} + \alpha \cdot d \cdot L_{\text{aux}}.$$

We optimize $L_{\text{total}}$ using Adam ($\beta_1 = 0.9$, $\beta_2 = 0.999$).
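Putting the pieces together, a single training step might look roughly like this (a sketch relying on the hypothetical helpers from the earlier snippets, with a hypothetical batched forward pass `forward_batch`; not the authors' code):

def train_step(batch, forward_batch, W_decs, optimizer, alpha, d):
    """batch: (T, d); forward_batch returns reconstructions (T, d) and router probs (T, N);
    W_decs: list of per-expert decoder parameters."""
    x_hat, p = forward_batch(batch)
    recon = ((batch - x_hat) ** 2).sum(dim=-1).mean()       # L_recon
    loss = recon + alpha * d * load_balancing_loss(p)       # L_total = L_recon + alpha * d * L_aux
    optimizer.zero_grad()
    loss.backward()
    for W_dec in W_decs:
        remove_parallel_grad(W_dec)                         # drop gradient components parallel to feature directions
    optimizer.step()
    for W_dec in W_decs:
        constrain_decoder(W_dec)                            # restore unit-norm decoder columns
    return loss.item()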
3. Results
We train SAEs on the residual stream activations of GPT-2 small (d=768). In this work, we follow Gao et al. and focus on layer 8. Using text data from OpenWebText, we train for 100K steps using a batch size of 8192, for a total of ~820M tokens. We benchmark the Switch SAE against the ReLU SAE (Conerly et al.), the Gated SAE (Rajamanoharan et al.) and the TopK SAE (Gao et al.). We present results for two settings.
For a wide range of sparsity (L0) values, we report the reconstruction MSE and the proportion of cross-entropy loss recovered when the sparse autoencoder output is patched into the language model. A loss recovered value of 1 corresponds to a perfect reconstruction, while a loss recovered value of 0 corresponds to a zero-ablation.
3.1 Fixed Width Results
We train Switch SAEs with 16, 32, 64 and 128 experts (Figures 2, 3). The Switch SAEs consistently underperform the TopK SAE in terms of MSE and loss recovered. The Switch SAE with 16 experts is a Pareto improvement over the Gated SAE in terms of both MSE and loss recovered, despite using roughly 16x fewer FLOPs per activation. The Switch SAE with 32 experts is a Pareto improvement over the Gated SAE in terms of loss recovered. The Switch SAE with 64 experts is a Pareto improvement over the ReLU SAE in terms of both MSE and loss recovered. The Switch SAE with 128 experts is a Pareto improvement over the ReLU SAE in terms of loss recovered, and also in terms of MSE except when k=192. The k=192 setting for the 128-expert Switch SAE is an extreme case: each expert SAE has 24576/128 = 192 features, meaning that the TopK activation is effectively irrelevant. When L0 is low, Switch SAEs perform particularly well. This suggests that the features that improve reconstruction fidelity the most for a given activation lie within the same cluster.
These results demonstrate that Switch SAEs can reduce the number of FLOPs per activation by up to 128x while still retaining the performance of a ReLU SAE. Switch SAEs can likely achieve greater acceleration on larger language models.
3.2 FLOP-Matched Results
We train Switch SAEs with 2, 4 and 8 experts (Figures 4, 5, 6). The Switch SAEs are a Pareto improvement over the TopK, Gated and ReLU SAEs in terms of both MSE and loss recovered. As we scale up the number of experts and represent more features, performance continues to increase while keeping computational costs and memory costs (from storing the pre-activations) roughly constant.
Fedus et al. find that their sparsely-activated Switch Transformer is significantly more sample-efficient compared to FLOP-matched, dense transformer variants. We similarly find that our Switch SAEs are 5x more sample-efficient compared to the FLOP-matched, TopK SAE baseline. Our Switch SAEs achieve the reconstruction MSE of a TopK SAE trained for 100K steps in less than 20K steps. This result is consistent across 2, 4 and 8 expert Switch SAEs.
Switch SAEs speed up training while capturing more features and keeping the number of FLOPs per activation fixed. Kaplan et al. similarly find that larger models are more sample efficient.
4. Conclusion
The diverse capabilities (e.g., trigonometry, 1960s history, TV show trivia) of frontier models suggest the presence of a huge number of features. Templeton et al. and Gao et al. make massive strides by successfully scaling sparse autoencoders to millions of features. Unfortunately, millions of features are not sufficient to capture all the relevant features of frontier models. Templeton et al. estimate that Claude 3 Sonnet may have billions of features, and Gao et al. empirically predict that future larger models will require more features to achieve the same reconstruction fidelity. If we are unable to train sufficiently wide SAEs, we may miss safety-crucial features such as those related to security vulnerabilities, deception and CBRN. Thus, further research must be done to improve the efficiency and scalability of SAE training. To monitor future superintelligent language models, we will likely need to perform SAE inference during the forward pass of the language model to detect safety-relevant features. Large-scale labs may be unwilling to perform this extra computation unless it is both computationally and memory efficient and does not dramatically slow down model inference. It is therefore crucial that we additionally improve the inference time of SAEs.
Thus far, the field has been bottlenecked by the encoder forward pass, the sole dense matrix multiplication involved in SAE training and inference. This work presents the first attempt to overcome the encoder forward pass bottleneck. Taking inspiration from Shazeer et al. and Fedus et al., we introduce the Switch Sparse Autoencoder, which replaces the standard large SAE with many smaller expert SAEs. The Switch Sparse Autoencoder leverages a trainable router that determines which expert is used, allowing us to scale the number of features without increasing the computational cost. When keeping the width of the SAE fixed, we find that we can reduce the number of FLOPs per activation by up to 128x while still maintaining a Pareto improvement over the ReLU SAE. When fixing the number of FLOPs per activation, we find that Switch SAEs train 5x faster and are a Pareto improvement over TopK, Gated and ReLU SAEs.
Future Work
This work is the first to combine Mixture-of-Experts with Sparse Autoencoders to improve the efficiency of dictionary learning. There are many potential avenues to expand upon this work.
Acknowledgements
This work was supervised by Christian Schroeder de Witt and Josh Engels. I used the dictionary learning repository to train my SAEs. I would like to thank Samuel Marks and Can Rager for advice on how to use the repository. I would also like to thank Jacob Goldman-Wetzler, Achyuta Rajaram, Michael Pearce, Gitanjali Rao, Satvik Golechha, Kola Ayonrinde, Rupali Bhati, Louis Jaburi, Vedang Lad, Adam Karvonen, Shiva Mudide, Sandy Tanwisuth, JP Rivera and Juan Gil for helpful discussions.