For comparing CE difference (or the mean reconstruction score), did these have similar L0s? If not, it's an unfair comparison (higher L0 usually means higher reconstruction accuracy).
Good point. First, the mean L0 of the experiment and baseline groups is within a factor of 2, so they are in a reasonably close range. I also added a new set of figures comparing the reconstruction score on the one layer where the L0s of the two groups match most closely. Spoiler: the scores are still almost the same at the end of training. You can find them under Experiments > Performance Validation.
I want to mention that, in my experience, a factor of 2 difference in L0 makes a pretty huge difference in reconstruction score/L2 norm. IMO, if you want to compare two architectures, ideally you should compare Pareto curves for each architecture or get two datapoints that have almost exactly the same L0.
The additional experiment under Experiments > Performance Validation (Figure 11) compares normalized_1 and baseline_1 on layer 5, which have almost identical L0. The result showed no observable difference.
Hi Hengyu! Really nice work here! I am wondering if you have released the pre-trained SAE for llama-2?
It would be good to benchmark the normalized and baseline SAEs using the standard metrics of patch loss and L0.
Patch loss is different to L2. It's the KL Divergence between the normal model and the model when you patch in the reconstructed activations at some layer.
Oh I see. I'll have to look into that cuz I used the AI-safety-foundation's implementation and they don't measure the KL divergence. That said, there is a validation metric called reconstruction score that measures how replacing activations changes the total loss of the model, and the scores are pretty similar for the original and normalized SAEs.
there is a validation metric called reconstruction score that measures how replacing activations changes the total loss of the model
That's equivalent to the KL metric. Would be good to include as I think it's the most important metric of performance.
I think these aren't equivalent? KL divergence between the original model's outputs and the outputs of the patched model is different than reconstruction loss. Reconstruction loss is the CE loss of the patched model. And CE loss is essentially the KL divergence of the prediction with the correct next token, as opposed to with the probability distribution of the original model.
Also reconstruction loss/score is in my experience the more standard metric here, though both can say something useful.
Reconstruction loss is the CE loss of the patched model
If this is accurate then I agree that this is not the same as "the KL Divergence between the normal model and the model when you patch in the reconstructed activations". But Fengyuan described reconstruction score as:
measures how replacing activations changes the total loss of the model
which I still claim is equivalent.
Hmm maybe I'm misunderstanding something, but I think the reason I'm disagreeing is that the losses being compared are wrt a different distribution (the ground truth actual next token) so I don't think comparing two comparisons between two distributions is equivalent to comparing the two distributions directly.
Eg, I think for these to be the same it would need to be the case that something along the lines
or
were true, but I don't think either of those are true. To connect that to this specific case, have be the data distribution, and and the model with and without replaced activations
on a separate note that could also be a crux,
measures how replacing activations changes the total loss of the model
quite underspecifies what "reconstruction score" is. So I'll give a brief explanation:
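let:
$L_{\text{clean}}$ = CE loss with no intervention, $L_{\text{zero}}$ = CE loss with the activations replaced by zeros, $L_{\text{reconstruction}}$ = CE loss with the activations replaced by the SAE reconstruction
then
$$\text{reconstruction score} = \frac{L_{\text{zero}} - L_{\text{reconstruction}}}{L_{\text{zero}} - L_{\text{clean}}}$$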
so, this has the property that when the value is 0 the SAE is as bad as replacement with zeros and when it's 1 the SAE is not degrading performance at all
It's not clear that normalizing by the zero-ablation loss makes a ton of sense, but since it's an emerging domain it's not fully clear what metrics to use, and this one is pretty standard/common. I'd prefer if bits/nats lost were the norm, but I haven't ever seen someone use that.
I don't think is very informative here, as it's highly impacted by the input batch. Both the raw and have large variances at different verification steps, and since we mainly care about how good our reconstruction is compared with the original, I think the reconstruction score is good as is. I also don't follow why the noisiness of leads to showing .
TL;DR
Sparse autoencoders (SAEs) present a promising direction towards automating mechanistic interpretability, but they are not without flaws. One known issue of the original sparse autoencoders is the feature suppression effect, which is caused by the conflict between the L2 and L1 losses and the unit-norm constraint on the SAE decoder. In theory, this effect should be more pronounced for inputs with high norms. Another observation is that training SAEs on multiple layers simultaneously results in inconsistent L0 norms of the feature activations across layers: in some layers L0 is on the order of 10², while in others it is on the order of 10¹. Moreover, the residual states fed to the SAEs for training also have different norms across layers. Hence, I argue that the current SAE architecture is not robust to inputs of varying norms, which are common in modern LLMs. In this post, I proposed a modified SAE architecture, the Normalized Sparse Autoencoder (NSAE), and gave a theoretical proof that it will not have the feature suppression problem. I then conducted experiments to verify the effectiveness of the proposed method, which showed that NSAEs exhibit little to no feature suppression, that the L0 of their feature activations is no longer positively correlated with the input norm, that their L1 loss agrees better with L0, and that their reconstruction score is comparable to the baseline.
I then further investigated the learned feature dictionaries and identified three types of feature vectors: correction vectors, pillar vectors, and direction vectors. Finally, I concluded this post with a discussion of the limitations of NSAEs and my suggestions for future directions.
Introduction
Training Sparse Autoencoders (SAEs) on the residual states of pretrained models is a recently proposed method in mechanistic interpretability to tackle the problem of superposition. This method is scalable and unsupervised, making it promising for auto-interpretability research.
More specifically, an SAE consists of an encoder and a decoder. The encoder maps the residual states of a source model to sparse feature activations, and the decoder reconstructs the residual states from those activations. The expectation is that by training the SAE on a large set of activations while jointly optimizing a sparsity loss on the feature activations and an L2 reconstruction loss, the model learns to decompose residual states into monosemantic feature vectors that are more interpretable.
In this post, I identified a flaw in the original SAE implementation, namely the inconsistency of the L1 loss across layers, and proposed a method to mitigate this problem. With the new method, we can significantly decrease the correlation between the norm of the source model's residual activations and the L0 norm of the feature activations, making the training process more robust and controllable. The code is available on GitHub (note that you should use the dev branch rather than the others).
Motivations
Feature suppression is a known problem for SAEs. It originates from a conflict between the L1 sparsity loss and the L2 reconstruction loss: the reconstruction's norm is correlated with L1, so the SAE learns to produce reconstructions with smaller norms to obtain a lower L1 loss. This is undesirable, as we would like the reconstruction to match the original input activations as closely as possible. Therefore, finding a way to disentangle the input norms from L1 and L2 would be beneficial.
Also, in my own experiments training SAEs with this implementation from the AI Safety Foundation, I observed that the L1 sparsity loss is inconsistent across layers:
The two figures above show the L1 losses of two different layers from the same training run, yet their L1 scales differ by a factor of 10.
Moreover, the sparsity measured by L0 is also vastly different across layers:
I argue that this is also undesirable, as we introduced the L1 coefficient α in an attempt to control the balance between the L1 and L2 losses across layers. Ideally, α should provide consistent control across layers, which is not what we observe.
Moreover, the norms of the source model's residual states are also inconsistent across layers. We can plot the distribution of the norms of residual states[1] in GPT-2 small across layers:
It is obvious that the mean and variance of the norms differ across layers.
This effect is common among LLMs, and we can find similar effects in more recent models like LLaMA-2 and Gemma:
This provides some evidence that the inconsistency of input norms might have caused the undesirable behaviors in SAEs. Thus, I will conduct a theoretical analysis in the next section to further illustrate this problem.
Theoretical Analysis
Definitions
With these observations in mind, let's analyze the SAE loss theoretically to see why they might happen.
Formally, a SAE can be defined as the following:
$$\mathrm{Encoder}(x) = \mathrm{ReLU}(W_e x + b_e) = c$$
$$\mathrm{Decoder}(c) = W_d c$$
$$x' = \mathrm{SAE}(x) = \mathrm{Decoder}(\mathrm{Encoder}(x))$$
We call the output of the encoder, $c$, the feature activation.
The loss function for optimization is defined as
$$L_1 = \|c\|_1$$
$$L_2 = \|x' - x\|_2^2$$
$$\mathcal{L}(x', x) = \alpha L_1 + L_2 = \alpha\|c\|_1 + \|x' - x\|_2^2$$
where the L1 coefficient $\alpha \in \mathbb{R}^+$ is a hyperparameter chosen by the user and $\|\cdot\|_p$ denotes the $p$-norm of a vector.
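We also set another hyperparameter, the expansion factor $k \in \mathbb{N}^+$, and denote the source model's residual dimension by $n$. Then we define $m = kn$, and we have $x, x' \in \mathbb{R}^n$, $c, b_e \in \mathbb{R}^m$, $W_e \in \mathbb{R}^{m \times n}$, and $W_d \in \mathbb{R}^{n \times m}$.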
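To make the definitions above concrete, here is a minimal pure-Python sketch of the original SAE forward pass and loss (matrices stored as lists of rows, no decoder bias, as in the definition). The helper names and toy weights are hypothetical and only illustrate the shapes; this is not the AI Safety Foundation implementation.

```python
import math

def matvec(W, v):
    """Multiply a matrix stored as a list of rows by a vector."""
    return [sum(w * t for w, t in zip(row, v)) for row in W]

def sae_forward(x, W_e, b_e, W_d):
    """Original SAE: c = ReLU(W_e x + b_e), x' = W_d c (no decoder bias)."""
    pre = [p + b for p, b in zip(matvec(W_e, x), b_e)]
    c = [max(0.0, p) for p in pre]  # ReLU
    x_hat = matvec(W_d, c)
    return c, x_hat

def sae_loss(x, x_hat, c, alpha):
    """L = alpha * ||c||_1 + ||x' - x||_2^2."""
    l1 = sum(abs(t) for t in c)
    l2 = sum((a - b) ** 2 for a, b in zip(x_hat, x))
    return alpha * l1 + l2

# Toy sizes: n = 2, expansion factor k = 2, so m = 4.
# Decoder columns (unit norm): e1, e2, -e1, -e2. W_d has shape n x m.
W_d = [[1.0, 0.0, -1.0, 0.0],
       [0.0, 1.0, 0.0, -1.0]]
W_e = [list(col) for col in zip(*W_d)]  # W_d transposed, shape m x n
b_e = [0.0] * 4

x = [3.0, -2.0]
c, x_hat = sae_forward(x, W_e, b_e, W_d)
print("c =", c, "x' =", x_hat, "loss =", sae_loss(x, x_hat, c, alpha=0.1))
print("||x'||_2 =", math.sqrt(sum(t * t for t in x_hat)), "||c||_1 =", sum(c))
```

Because the two active decoder columns are orthonormal, $\|x'\|_2 = \|c\|_2 = \sqrt{13} \approx 3.61$ here, while $\|c\|_1 = 5$; the two norms coincide only when a single feature is active.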
In the original implementation, the authors constrained the decoder columns to have unit norm, so that during optimization the model cannot minimize the L1 loss by increasing the column norms of the decoder while producing dense feature activations with a small L1. This design choice leads to a potential flaw in the method, which will be discussed in a later section of this post.
The Effect of Input Norms on Feature Suppression
The authors who identified feature suppression provided a nice theoretical analysis in their Feature Suppression section, but for completeness, I will conduct a similar analysis using the notation defined in this post.
We first consider the extreme case where the feature activation $c$ of an input $x$ has only one positive entry $i$, with all other entries equal to 0. Then $x' = \mathrm{Decoder}(c) = W_d c = c_i w_{d,i}$, where $w_{d,i}$ is the $i$-th column vector of $W_d$. Since the columns of $W_d$ have unit norm, we must have $\|x'\|_2 = \|c_i w_{d,i}\|_2 = c_i = \|c\|_1$.
More generally, I will show that when $c$ is sparse, we also have $\|c\|_1 \approx \|x'\|_2$.
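Define $I = \{i : c_i \neq 0\}$ as the index set of all nonzero entries of the feature activation. We assume that the feature vectors in $\{w_{d,i} : i \in I\}$ are (almost) mutually orthogonal[2], i.e., $w_{d,i} \cdot w_{d,j} \approx 0$ for all $i, j \in I$ with $i \neq j$. Combined with the unit-norm constraint on the decoder columns, $w_{d,i} \cdot w_{d,i} = 1$ for all $i \in I$, we have
$$\|x'\|_2 = \sqrt{x' \cdot x'} = \sqrt{\Big(\sum_{i \in I} c_i w_{d,i}\Big) \cdot \Big(\sum_{j \in I} c_j w_{d,j}\Big)} = \sqrt{\sum_{i \in I}\sum_{j \in I} c_i c_j \, w_{d,i} \cdot w_{d,j}} \approx \sqrt{\sum_{i \in I} c_i^2} = \|c\|_2$$
In the case of sparse $c$, we have $\|c\|_1 \gtrsim \|c\|_2 \approx \|x'\|_2$.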
Then our loss function becomes the following:
$$\mathcal{L} \approx \alpha\|x'\|_2 + \|x' - x\|_2^2$$
If we attempt to minimize this loss, there is always a tradeoff between the reconstruction accuracy and the norm of the reconstruction. In most cases, the model will learn to construct an $x'$ that is close to $x$ but has a slightly smaller norm, achieving low values for both terms.
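A quick numerical check of this tradeoff, under the same approximation $\|c\|_1 \approx \|x'\|_2$. For a reconstruction of fixed norm $t$, $\|x' - x\|_2^2 = t^2 + r^2 - 2tr\cos(x', x)$ is smallest when $x'$ points along $x$, so we can restrict to $x' = t\,x/\|x\|_2$ with $r = \|x\|_2$. The loss becomes $\alpha t + (t - r)^2$, whose minimizer is $t^* = r - \alpha/2$ (for $r > \alpha/2$), i.e. the optimal reconstruction is shorter than the input. The sketch below confirms this with a brute-force search; the values of $r$ and $\alpha$ are arbitrary.

```python
def approx_loss(t, r, alpha):
    """alpha * ||x'||_2 + ||x' - x||_2^2 for x' = t * x / ||x||_2, where r = ||x||_2."""
    return alpha * t + (t - r) ** 2

def best_reconstruction_norm(r, alpha, steps=100_000):
    """Brute-force search for the norm t in [0, r] that minimizes approx_loss."""
    candidates = [r * i / steps for i in range(steps + 1)]
    return min(candidates, key=lambda t: approx_loss(t, r, alpha))

r, alpha = 10.0, 1.0
t_star = best_reconstruction_norm(r, alpha)
print(t_star, r - alpha / 2)  # brute force vs. closed form r - alpha / 2
```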
The Effect of Input Norms on the Inconsistency of L0 Across Layers
Here, we make the same assumption as before: when $c$ is sparse, $\|c\|_1 \approx \|x'\|_2$.
For the L2 term, we have
$$\|x' - x\|_2^2 = \|x'\|_2^2 + \|x\|_2^2 - 2x' \cdot x = \|x'\|_2^2\left(1 + \frac{\|x\|_2^2}{\|x'\|_2^2} - \frac{2x' \cdot x}{\|x'\|_2^2}\right)$$
This form may not look useful at first, but if our reconstruction $x'$ is close enough to $x$, we can take $\|x'\|_2 \approx \|x\|_2$[3], and the equation simplifies to
$$\|x' - x\|_2^2 \approx \|x\|_2^2\left(1 + 1 - \frac{2x' \cdot x}{\|x\|_2 \, \|x'\|_2}\right) = 2\|x\|_2^2\,\big(1 - \cos(x', x)\big)$$
Now we can rewrite our loss:
$$\mathcal{L}(x', x) \approx \alpha\|x'\|_2 + 2\|x'\|_2^2\,\big(1 - \cos(x', x)\big)$$
Notice that if $1 - \cos(x', x)$ stays on a roughly fixed scale, then the first term scales with $\|x'\|_2$ while the second term scales with $\|x'\|_2^2$. So, for a fixed $\alpha$, a larger $\|x\|_2$ biases the loss towards the second (L2) term, which weakens the relative pressure on sparsity. This agrees with my earlier observation: the source model's residual states in deeper layers have larger norms than those in shallower layers, and the L1 loss was significantly higher in deeper layers because the loss was dominated by the larger L2 term.
Normalizing SAEs
After this analysis, it is natural to ask: is there a way to solve these problems?
My answer is yes!
Here, I propose an architectural modification to the original SAE architecture, which I have named the Normalized Sparse Autoencoder (NSAE).
Architecture
The modified architecture is defined as the following:
$$c = \tanh(\mathrm{ReLU}(W_e x + b_e + \epsilon))$$
$$\mathrm{NSAE}(x) = W_d c$$
In this definition, $c$ is the new feature activation, and $W_d$ is no longer constrained to have unit-norm columns. A Gaussian noise term $\epsilon$, sampled from $\mathcal{N}(0, \sigma^2)$ for some hyperparameter $\sigma$, is introduced to regularize the feature activation.
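A minimal pure-Python sketch of the NSAE forward pass defined above. The post does not say whether the noise is also applied at inference time; here it is only added when a random generator `rng` is passed (assumed to be during training). The function name and toy weights are hypothetical.

```python
import math
import random

def nsae_forward(x, W_e, b_e, W_d, sigma=0.0, rng=None):
    """NSAE: c = tanh(ReLU(W_e x + b_e + eps)), x' = W_d c, eps ~ N(0, sigma^2).

    Noise is only added when an rng is supplied (assumed: training time).
    """
    pre = [sum(w * t for w, t in zip(row, x)) + b for row, b in zip(W_e, b_e)]
    if rng is not None and sigma > 0:
        pre = [p + rng.gauss(0.0, sigma) for p in pre]
    c = [math.tanh(max(0.0, p)) for p in pre]  # each entry lies in [0, 1)
    x_hat = [sum(w * t for w, t in zip(row, c)) for row in W_d]
    return c, x_hat

# Toy encoder; the decoder is no longer constrained to unit-norm columns.
W_e = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
b_e = [0.0] * 4
W_d = [[3.0, 0.0, -3.0, 0.0],
       [0.0, 2.5, 0.0, -2.5]]

x = [3.0, -2.0]
c_clean, x_hat_clean = nsae_forward(x, W_e, b_e, W_d)
c_noisy, _ = nsae_forward(x, W_e, b_e, W_d, sigma=0.1, rng=random.Random(0))
print(c_clean, x_hat_clean, c_noisy)
```

In this toy example the decoder column norms (3 and 2.5) carry the magnitude information, since each entry of $c$ is bounded by 1.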
The tanh normalizes every entry of $c$ to the range $[0, 1)$. The benefits of doing this are threefold:
The Gaussian noise term is also essential in this architecture. Without it, the model could minimize L1 by mapping inputs to very small positive feature activations and compensating with decoder columns of extremely large norm.
To show why adding Gaussian noise solves this problem, I plotted the activation function $\tanh(\mathrm{ReLU}(\cdot))$ in the following figure:
From the figure, we can see that for small inputs, the output of $\tanh(\mathrm{ReLU}(\cdot))$ is relatively sensitive to its input, so adding Gaussian noise can significantly perturb small feature activations. In contrast, larger inputs are much more robust to perturbation, as they all map to similar values close to 1. Hence, this perturbation forces the model to produce feature activations that are either exactly 0 or close to 1, which makes $\|c\|_1$ behave even more like $\|c\|_0$, especially when $\sigma$ is large.
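The sensitivity argument can be checked numerically: perturb a pre-activation $z$ with Gaussian noise and measure how much $\tanh(\mathrm{ReLU}(z))$ moves on average. This is a sketch with arbitrarily chosen values of $z$ and $\sigma$, not the code used to produce the figure.

```python
import math
import random

def activation(z):
    """The NSAE activation tanh(ReLU(z))."""
    return math.tanh(max(0.0, z))

def mean_perturbation(z, sigma, n=20_000, seed=0):
    """Monte Carlo estimate of E|f(z + eps) - f(z)| with eps ~ N(0, sigma^2)."""
    rng = random.Random(seed)
    base = activation(z)
    return sum(abs(activation(z + rng.gauss(0.0, sigma)) - base) for _ in range(n)) / n

for z in (0.1, 1.0, 3.0):
    print(f"z = {z}: mean change = {mean_perturbation(z, sigma=0.1):.4f}")
```

Small activations move by roughly the full noise magnitude, while activations near saturation barely move, so only values near 0 or near 1 are stable under the noise.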
Loss
We also have to redefine the loss as follows:
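$$X_i = \{x : x \text{ is in the } i\text{-th layer of the input batch}\}$$
$$\beta_i = \alpha \cdot \Big(\operatorname{mean}_{x \in X_i} \|x\|_2\Big)^2$$
$$\mathcal{L}_i(x', x) = \beta_i\|c\|_1 + \|x' - x\|_2^2$$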
We introduced the additional step of scaling $\alpha$ by the square of the mean input norm of each layer. This is because $\|x' - x\|_2^2 \approx 2\|x'\|_2^2(1 - \cos(x', x))$. If we assume that, without the L1 constraint, the best an optimizer can do is reach a fixed cosine similarity between $x$ and $x'$, then we can treat $(1 - \cos(x', x))$ as a constant, so the L2 loss scales as $\|x'\|_2^2 \approx \|x\|_2^2$, while $\|c\|_1 \approx \|c\|_0$, which should be constant across layers. Therefore, we can manually scale the L1 loss to match the scale of the L2 loss. Another option is to scale by the actual $\|x\|_2^2$ of each sample. In theory, this might cause the model to overfit to inputs with large norms; for conciseness, I leave this question to future work and use only the mean normalization in all the following experiments.
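The per-layer scaling can be sketched as follows. The helper names are hypothetical, and the toy layers assume, as in the argument above, the same relative reconstruction quality in every layer ($x' = 0.9x$, the empirical ratio from footnote [3]) and the same $\|c\|_1$. Under these assumptions the ratio between the L2 term and the $\beta_i$-scaled L1 term is identical across layers, whereas in this toy setup an unscaled $\alpha$ gives a ratio 100 times larger for the layer whose inputs are 10 times larger.

```python
import math

def l2_norm(v):
    return math.sqrt(sum(t * t for t in v))

def layer_beta(alpha, layer_inputs):
    """beta_i = alpha * (mean of ||x||_2 over the layer-i inputs in the batch)^2."""
    mean_norm = sum(l2_norm(x) for x in layer_inputs) / len(layer_inputs)
    return alpha * mean_norm ** 2

def term_ratio(x, x_hat, c, coeff):
    """Ratio of the L2 term ||x' - x||_2^2 to the sparsity term coeff * ||c||_1."""
    l2 = sum((a - b) ** 2 for a, b in zip(x_hat, x))
    return l2 / (coeff * sum(abs(t) for t in c))

alpha = 0.01
layers = {"shallow": [[1.0, 0.0], [0.0, 1.0]],   # input norms of 1
          "deep": [[10.0, 0.0], [0.0, 10.0]]}    # input norms of 10
c = [1.0, 0.0, 0.0, 0.0]  # assumption: same ||c||_1 in both layers

for name, xs in layers.items():
    x = xs[0]
    x_hat = [0.9 * t for t in x]  # assumption: same relative reconstruction quality
    beta = layer_beta(alpha, xs)
    print(name, "alpha:", term_ratio(x, x_hat, c, alpha), "beta_i:", term_ratio(x, x_hat, c, beta))
```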
Experiments
I trained two groups of SAEs, one baseline and one experiment, on all layers of GPT-2, and each group contains 2 training runs trained on 100,000,000 activations. These four runs used different combinations of L1 coefficient and learning rate; the baseline runs used the original SAE, while the experiment runs used the normalized SAE. I will use "the experiment group" and "the normalized group" interchangeably.
Feature Suppression is Suppressed in Normalized SAE
To investigate feature suppression, I added a new validation metric that measures the ratio between the norm of the reconstructions and the norm of the source activations. Here is this metric during training:
Clearly, the normalized group has a significantly higher norm ratio than the baseline group, and its ratio is very close to 1. Considering that this NSAE did not fully converge, as it only went through 200M training examples, and that the ratio shows no sign of flattening, I claim that NSAEs have little to no feature suppression.
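The norm-ratio metric can be sketched as below. The post does not specify whether the ratio is averaged per sample or computed from mean norms; this sketch averages per-sample ratios, which is an assumption, and the function name is hypothetical. A value below 1 indicates that reconstructions are systematically shorter than the inputs.

```python
import math

def norm_ratio(xs, x_hats):
    """Mean over the batch of ||x'||_2 / ||x||_2; 1.0 means no shrinkage."""
    ratios = [math.sqrt(sum(t * t for t in xh)) / math.sqrt(sum(t * t for t in x))
              for x, xh in zip(xs, x_hats)]
    return sum(ratios) / len(ratios)

xs = [[3.0, 4.0], [0.0, 2.0]]
x_hats = [[2.7, 3.6], [0.0, 2.0]]  # first sample shrunk to 90% of its norm
print(norm_ratio(xs, x_hats))
```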
Normalizing L1 Removes the Correlation Between Input Norm and L0
To investigate the effect of normalization, I collected the L0 norms of different layers at the end of training and plotted them against the mean input norm of each layer:
The red and blue datapoints are from the baseline group, whereas the cyan and purple datapoints are from the experiment group. We can fit lines to these datapoints to find linear relationships between the mean input norm and the mean L0 norm of the feature activations. Although the fits are poor, the fitted lines still show a rough positive linear correlation between the mean input norm and the feature activation L0 norm in the baseline. In contrast, the two normalized runs did not exhibit a statistically significant positive linear relationship between input norm and L0.
This linear fit definitely does not look satisfactory, and I further investigated the reasons behind it. I plotted the normalized group's L0 against layer index, and here is what it looks like:
I conjecture that L0 in the normalized group reflects the level of discreteness of the source model's activations, as it exhibits an increase-then-decrease pattern. In the source model, earlier activations are more discrete because they originate from discrete input embeddings, while deeper activations might be less discrete as they aggregate information. In the last layers, since the model has to make the next-token prediction as accurate as possible, the activations might become more discrete again for better next-token decoding, as the decoding layer is discrete. This discreteness might also be positively correlated with the monosemanticity of the activations, as more discrete activations are often more interpretable. I will not verify this conjecture in this post due to length considerations, and I welcome others to study this problem.
L1 Agrees with L0 Better
To investigate the agreement between L1 and L0, I plot the mean L0 and L1 of the feature activations for both groups:
Clearly, the cyan and purple solid lines (which are L1) are much closer to their corresponding dashed lines (L0) than the baselines, indicating better agreement between L0 and L1.
Performance Validation
To validate that the normalization did not heavily impact performance, I present the reconstruction score metric. I first calculate the loss with no intervention, with zero intervention (replacing the hidden states in one layer with zero vectors), and with reconstruction intervention (replacing the hidden states in one layer with the SAE's reconstructed vectors), denoted $L_{\text{clean}}$, $L_{\text{zero}}$, and $L_{\text{reconstruction}}$, respectively. The score is then calculated as
$$S_{\text{reconstruction}} = \frac{L_{\text{zero}} - L_{\text{reconstruction}}}{L_{\text{zero}} - L_{\text{clean}}}$$
Since we expect $L_{\text{zero}}$ to be higher than $L_{\text{clean}}$, and we want $L_{\text{reconstruction}}$ to be close to $L_{\text{clean}}$, a higher score is better, and we expect a value close to 1. The score during training is shown below:
There is no observable difference between the normalized group and the baseline group, except that the normalized group's score seems slightly more stable during training. This indicates that the normalization did not heavily impact performance and might have improved training stability.
Since the mean reconstruction score is heavily impacted by the sparsity of the feature activation, I also compared a layer where the L0 of the baseline and experiment group best agrees with each other:
Still, there is no observable difference between the experiment group and the baseline after convergence. This provides further evidence that the normalization did not have an observable negative impact on the performance of SAEs.
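The reconstruction score defined above is simple to compute once the three losses are available. A minimal sketch, where the function name is hypothetical and the loss values in the example are illustrative placeholders, not measured results:

```python
def reconstruction_score(loss_clean, loss_zero, loss_reconstruction):
    """S = (L_zero - L_reconstruction) / (L_zero - L_clean).

    1.0: patching in the SAE reconstruction does not increase the loss at all.
    0.0: the reconstruction is as harmful as zero ablation.
    """
    denom = loss_zero - loss_clean
    if denom == 0:
        raise ValueError("L_zero equals L_clean; the score is undefined")
    return (loss_zero - loss_reconstruction) / denom

# Illustrative values only.
print(reconstruction_score(loss_clean=3.0, loss_zero=8.0, loss_reconstruction=3.5))
```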
NSAE Statistics
To further investigate what the new SAE has learned, I did some statistical analysis on the NSAE feature dictionary from the first run. For comparison, I used the original SAE trained in the first baseline run.
I first analyzed the norm distribution of the feature vectors across layers:
Interestingly, a large proportion of feature vectors have norms in the range $(0, 0.5)$, which might indicate that these are small correction vectors that are added to a bigger vector to make the reconstruction as close as possible. In contrast, I hypothesize that feature vectors with large norms should have good interpretability, as they represent general directions of the reconstruction. Hence, I will call these vectors pillar vectors.
Next, I calculate the distribution of cosine similarity of the feature dictionary:
From the figure, it is clear that the cosine similarity distributions of the NSAE and the SAE are very similar, except that the NSAE has some pairs with cosine similarity very close to 1. My hypothesis is that in the NSAE, some directions appear frequently with different norms in the decomposition of source model activations, so the NSAE has to learn several vectors with the same direction but different norms. I call these direction vectors.
A natural question is: do pillar vectors and direction vectors overlap? To answer this, I picked the top 100 vectors (by norm) of each layer from the feature dictionary as a set of pillar vectors and calculated their pairwise cosine similarities. Here is the distribution:
Since there are few to no pairs with very high cosine similarity, there is minimal overlap between pillar vectors and direction vectors.
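The pillar-vector check can be sketched in pure Python as follows: sort the dictionary vectors (the columns $w_{d,i}$ of $W_d$) by norm, keep the top $k$, and compute all pairwise cosine similarities. The helper names and the toy dictionary are hypothetical; for a real dictionary with tens of thousands of features, a vectorized implementation (and possibly subsampling the similarity matrix) would be needed.

```python
import math
from itertools import combinations

def norm(v):
    return math.sqrt(sum(t * t for t in v))

def cosine(u, v):
    return sum(a * b for a, b in zip(u, v)) / (norm(u) * norm(v))

def pillar_cosines(dictionary, k=100):
    """Take the k largest-norm feature vectors (the 'pillar vectors') and
    return the cosine similarity of every pair among them."""
    pillars = sorted(dictionary, key=norm, reverse=True)[:k]
    return [cosine(u, v) for u, v in combinations(pillars, 2)]

# Toy dictionary: each entry is one decoder column w_{d,i}.
dictionary = [[3.0, 0.0], [0.0, 2.0], [0.1, 0.1], [0.05, 0.05]]
print(pillar_cosines(dictionary, k=2))
```

In the toy dictionary, the two short diagonal vectors have cosine similarity 1 (like the direction vectors above) but are not selected as pillar vectors.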
As this post is already pretty long, I will leave a more comprehensive analysis on the learned feature dictionary to a future post and conclude this post.
Discussion
Limitations
The normalization did not come without cost. NSAEs generally have slightly higher reconstruction losses compared with the original, and it takes longer for NSAE to converge, as shown in the following figure:
I suspect this is because the NSAE learns a dictionary without the unit-norm constraint, and, since the feature activations are bounded in $[0, 1)$, this fixed-size dictionary has to capture all of the norm information, whereas the original SAE can learn directions and encode norm information in the feature activations.
Another metric that I don't know how to interpret is neuron activity. In the NSAE, neuron activity is significantly higher than in the original SAE:
Lastly, the experiments conducted are relatively small in scale due to limitations in compute. Moreover, due to the change of the loss function, it's hard to directly match the scales of L0 between the baseline and the experiment group.
Future Work
I suggest that future work pursue the following directions:
Appendix
Hyperparameters
I varied the hyperparameters l1_coefficients and the optimizer learning rate lr. For the two normalized runs, I also set the standard deviation σ of the Gaussian noise.
Related Work
Riggs et al. proposed using Sparse Autoencoders (SAEs) to discover interpretable features in large language models. Later, Wright et al. identified the feature suppression effect in SAEs and argued that the L1 loss induces smaller feature activations that harm reconstruction performance. Wes Gurnee observed that the reconstruction errors of SAEs are empirically pathological, and compared different norm-aware interventions on the source model's inference. The results showed that replacing the original residual state with the SAE reconstruction significantly changed the model's predictions, especially in deeper layers.
In this and the following examples, I used the residual states from the MLP layer.
This is a reasonable assumption, as data in Figure 13 (baseline) show that most feature vector pairs in the original sparse autoencoder have cosine similarities in the range of (−0.2,0.2).
Empirically, ||x′||2≈0.9||x||2, which is close enough for our analysis.
For computational efficiency, I randomly sampled 1,000,000 features from the cosine similarity matrix.
L0 collected at step 3000. Input norms were computed on a relatively small sample of random text, the same text used to generate figures 3, 4a, and 4b.