Awesome work with this! Definitely looks like a big improvement over standard SAEs for absorption. Some questions/thoughts:
In the decoder cos sim plot, it looks like there's still some slight mixing of features in co-occurring latent groups including some slight negative cos sim, although definitely a lot better than in the standard SAE. Given the underlying features are orthogonal, I'm curious why the Matryoshka SAE doesn't fully drive this to 0 and perfectly recover the underlying true features? Is it due to the sampling, so there's still some chance for the SAE to learn some absorption-esque feature mixes when the SAE latent isn't sampled? If there was no sampling and each latent group had its loss calculated each step (I know this is really inefficient in practice), would the SAE perfectly recover the true features?
It looks like Matryoshka SAEs will solve absorption as long as the parent feature in the hierarchy is learned before the child feature, but this doesn't seem like it's guaranteed to be the case. If the child feature happens to fire with much higher magnitude than the parent, then I would suspect the SAE would learn the child latent first to minimize expected MSE loss, and end up with absorption still. E.g. if a parent feature fires with probability 0.3 and magnitude 2.0 (expected MSE = 0.3 * 2.0^2 = 1.2), and a child feature fires with probability 0.15 but magnitude 10.0 (expected MSE = 0.15 * 10^2 = 15.0), I would expect the SAE would learn the child feature before the parent, and merge the parent representation into the child, resulting in absorption. This might never come up in real LLMs, so it may not be an issue in practice, but it could be something to look out for when training Matryoshka SAEs on real LLMs.
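To make the arithmetic in that example easy to play with, here is a minimal sketch. The function name `expected_sq_magnitude` is just something I made up for illustration; the probabilities and magnitudes are the ones from the example above.

```python
def expected_sq_magnitude(prob: float, magnitude: float) -> float:
    """Expected squared error left behind if a feature is not reconstructed at all:
    P(feature fires) * magnitude^2."""
    return prob * magnitude ** 2


# Numbers from the example above.
parent = expected_sq_magnitude(prob=0.3, magnitude=2.0)
child = expected_sq_magnitude(prob=0.15, magnitude=10.0)

print(f"parent: {parent:.2f}, child: {child:.2f}, ratio: {child / parent:.1f}x")
```

If the order in which latents get learned roughly follows this quantity (which is my guess, not something I've verified), the child would win by a wide margin here.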
Thank you for sharing this! I clearly didn't read the original "Towards Monosemanticity" closely enough! It seems like the main argument is that when the weights are untied, the encoder and decoder learn different vectors, and that this is evidence that the encoder and decoder should be untied. But this is consistent with the feature absorption work - we see the encoder and decoder learning different things, but that's not because the SAE is learning better representations but instead because the SAE is finding degenerate solutions which increase sparsity.
Are there any known patterns of feature firings where untying the encoder and decoder results in the SAE finding the correct or better representations, but where tying the encoder and decoder does not?
I'm not as familiar with the history of SAEs - were tied weights used in the past, but then abandoned because they resulted in lower sparsity? If that sparsity is gained by creating feature absorption, then it's not a good thing, since absorption does lead to higher sparsity but worse interpretability. I'm uncomfortable with the idea that higher sparsity is always better, since the model might just have some underlying features it's tracking that are dense, and IMO the goal should be to recover the model's "true" features, if such a thing can be said to exist, rather than maximizing sparsity, which is just a proxy metric.
The thesis of this feature absorption work is that absorption causes latents that look interpretable but actually aren't. We initially found this by trying to evaluate the interpretability of Gemma Scope SAEs, where we found that latents which seemed to be tracking an interpretable feature had holes in their recall that didn't make sense. I'd be curious if tied weights were used in the past and if so, why they were abandoned. Regardless, it seems like the thing we need to do next for this work is to just try out variants of tied weights for real LLM SAEs and see if the results are more interpretable, regardless of the sparsity scores.
That's an interesting idea! That might help if training a new SAE with tied encoder/decoder (or some loss which encourages the same thing) isn't an option. It seems like with absorption you're still going to get mixes of multiple features in the decoder, and a mix of the correct feature and the negative of excluded features in the encoder, which isn't ideal. Still, it's a good question whether it's possible to take a trained SAE with absorption and somehow identify the absorption and remove it or mitigate it rather than training from scratch. It would also be really interesting if we could find a way to detect absorption and use that as a way to quantify the underlying feature co-occurrences somehow.
I think you're correct that tying the encoder and decoder will mean that the SAE won't be as sparse. But then, maybe the underlying features we're trying to reconstruct are themselves not necessarily all sparse, so that could potentially be OK. E.g. things like "noun", "verb", "is alphanumeric", etc... are all things the model certainly knows, but would be dense if tracked in an SAE. The true test will be to try training some real tied SAEs and seeing how interpretable the results look.
Also worth noting, in the paper we only classify something as "absorption" if the main latent doesn't fire at all. We also saw cases which I would call "partial absorption" where the main latent fires, but weakly, and both the absorbing latent and the main latent have positive cosine sim with the probe direction, and both have an ablation effect on the spelling task.
Another intuition I have is that when the SAE absorbs a dense feature like "starts with S" into a sparse latent like "snake", it loses the ability to adjust the relative levels of the various component features relative to each other. So maybe the "snake" latent is 80% snakiness, and 20% starts with S, but then in a real activation the SAE needs to reconstruct 75% snakiness and 25% starts with S. So to do this, it might fire a proper "starts with S" latent but weakly to make up the difference.
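As a quick sanity check on that intuition, here's a tiny worked example. It treats the percentages as coefficients in a made-up 2-d basis ("snakiness", "starts with S") and assumes a clean "starts with S" latent exists alongside the absorbing "snake" latent. All of the names and the basis are hypothetical, just to show the arithmetic.

```python
# Hypothetical 2-d basis: index 0 = "snakiness", index 1 = "starts with S".
snake_latent = (0.8, 0.2)           # absorbing latent: 80% snakiness, 20% starts with S
starts_with_s_latent = (0.0, 1.0)   # assumed clean "starts with S" latent
target = (0.75, 0.25)               # what the real activation contains

# Solve a * snake_latent + b * starts_with_s_latent = target.
# Snakiness can only come from the snake latent, so it fixes a.
a = target[0] / snake_latent[0]
# Whatever "starts with S" the snake latent doesn't already supply must come from b.
b = target[1] - a * snake_latent[1]

print(f"snake activation: {a:.4f}, starts-with-S activation: {b:.4f}")
```

Under these assumptions the "starts with S" latent only needs to fire at a small fraction of its normal strength, which is the sort of weak firing I'd call partial absorption.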
Hopefully this is something we can validate with toy models. I suspect that the specific values of L1 penalty and feature co-occurrence rates / magnitudes will lead to different levels of absorption.
My take is that I'd expect to see absorption happen any time there's a dense feature that co-occurs with more sparse features. So for example things like parts of speech, where you could have a "noun" latent, and things that are nouns (e.g. "dogs", "cats", etc...) would probably show this as well. If there's co-occurrence, then the SAE can maximize sparsity by folding some of the dense feature into the sparse features. This is something that would need to be validated experimentally though.
It's also problematic that it's hard to know where this will happen, especially with features where it's less obvious what the ground-truth labels should be. E.g. if we want to understand if a model is acting deceptively, we don't have strong ground-truth to know that a latent should or shouldn't fire.
Still, it's promising that this should be something that's easily testable with toy models, so hopefully we can test out solutions to absorption in an environment where we can control every feature's frequency and co-occurrence patterns.
I'd also like to humbly submit the Steering Vectors Python library to the list as well. We built this library on PyTorch hooks, similar to Baukit, but with the goal that it should work automatically out-of-the-box on any LLM on Hugging Face. It's different from some of the other libraries in that regard, since it doesn't need a special wrapper class, but works directly with a Hugging Face model/tokenizer. It's also more narrowly focused on steering vectors than some of the other libraries.
I tried digging into this some more and think I have an idea what's going on. As I understand it, the base assumption for why Matryoshka SAE should solve absorption is that a narrow SAE should perfectly reconstruct parent features in a hierarchy, so then absorption patterns can't arise between child and parent features. However, it seems like this assumption is not correct: narrow SAEs still learn messed up latents when there's co-occurrence between parent and child features in a hierarchy, and this messes up what the Matryoshka SAE learns.
I did this investigation in the following colab: https://colab.research.google.com/drive/1sG64FMQQcRBCNGNzRMcyDyP4M-Sv-nQA?usp=sharing
Apologies for the long comment, this might make more sense as its own post. I'm curious to get others' thoughts on this - it's also possible I'm doing something dumb.
The problem: Matryoshka latents don't perfectly match true features
In the post, the Matryoshka latents seem to have the following problematic properties:
- The latent tracking a parent feature contains components of child features
- The latents tracking child features have negative components of each other child feature
The setup: simplified hierarchical features
I tried to investigate this using a simpler version of the setup in this post, focusing on a single parent/child relationship between latents. This is like a zoomed-in version on a single set of parent/child features. Our setup has 3 true features in a hierarchy as below:
These features have higher firing probabilities compared to the setup in the original post to make the trends highlighted more obvious. All features fire with magnitude 1.0 and have a 20d representation with no superposition (all features are mutually orthogonal).
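For anyone who wants to poke at this without opening the colab, here's a rough sketch of a data generator for this kind of setup. To be clear about what's assumed: the firing probabilities below are placeholders rather than the values I actually used, the children are treated as mutually exclusive, and the features are axis-aligned one-hot directions (which is one easy way to get mutually orthogonal features in 20d).

```python
import random

D_MODEL = 20

# Assumption: axis-aligned one-hot directions, so features are orthogonal by construction.
FEATURES = [[1.0 if j == i else 0.0 for j in range(D_MODEL)] for i in range(3)]

# Placeholder probabilities, not the values from the post or the colab.
P_PARENT = 0.3
P_CHILD_GIVEN_PARENT = (0.4, 0.4)  # the remaining 0.2 is "parent fires alone"


def sample_active(rng: random.Random) -> list[int]:
    """Sample which features fire. Children (1, 2) only fire if the parent (0) fires."""
    if rng.random() >= P_PARENT:
        return []
    u = rng.random()
    if u < P_CHILD_GIVEN_PARENT[0]:
        return [0, 1]
    if u < P_CHILD_GIVEN_PARENT[0] + P_CHILD_GIVEN_PARENT[1]:
        return [0, 2]
    return [0]


def sample_input(rng: random.Random) -> tuple[list[float], list[int]]:
    """Build one input as the sum of the active features, each with magnitude 1.0."""
    active = sample_active(rng)
    x = [0.0] * D_MODEL
    for i in active:
        for j in range(D_MODEL):
            x[j] += FEATURES[i][j]
    return x, active


rng = random.Random(0)
batch = [sample_input(rng) for _ in range(10_000)]
parent_rate = sum(0 in active for _, active in batch) / len(batch)
print(f"empirical parent firing rate: {parent_rate:.3f}")
```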
Simplified Matryoshka SAE
I used a simpler Matryoshka SAE that doesn't use feature sampling or reshuffling of latents and doesn't take the log of losses. Since we already know the hierarchy of the underlying features in this setup, I just used a Matryoshka SAE with a single inner SAE width of 1 latent to track the 1 parent feature, and the outer SAE width of 3 to match the number of true features. So the Matryoshka SAE sizes are as below:
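In code, the simplified objective looks roughly like the sketch below: one reconstruction loss per prefix (widths 1 and 3), summed directly, with no sampling and no log. This is a plain-Python illustration rather than the colab code; the function names are made up, I'm using summed squared error per sample, and I've left out any sparsity penalty for brevity.

```python
D_MODEL = 20
PREFIX_WIDTHS = (1, 3)  # inner SAE: latent 0 only; outer SAE: all 3 latents


def encode(x, W_enc, b_enc):
    """ReLU(W_enc @ x + b_enc), with W_enc given as one row per latent."""
    return [
        max(sum(w * xi for w, xi in zip(row, x)) + b, 0.0)
        for row, b in zip(W_enc, b_enc)
    ]


def decode_prefix(acts, W_dec, b_dec, width):
    """Reconstruct using only the first `width` latents."""
    recon = list(b_dec)
    for a, row in zip(acts[:width], W_dec[:width]):
        for j, w in enumerate(row):
            recon[j] += a * w
    return recon


def matryoshka_loss(x, W_enc, b_enc, W_dec, b_dec, widths=PREFIX_WIDTHS):
    """Sum of squared reconstruction errors over every prefix width."""
    acts = encode(x, W_enc, b_enc)
    total = 0.0
    for width in widths:
        recon = decode_prefix(acts, W_dec, b_dec, width)
        total += sum((xi - ri) ** 2 for xi, ri in zip(x, recon))
    return total


def one_hot(i):
    return [1.0 if j == i else 0.0 for j in range(D_MODEL)]


# Usage: an SAE whose encoder and decoder rows are exactly the true features.
true_features = [one_hot(i) for i in range(3)]
x = [a + b for a, b in zip(one_hot(0), one_hot(1))]  # parent + child 1
loss = matryoshka_loss(x, true_features, [0.0] * 3, true_features, [0.0] * D_MODEL)
print(f"loss with the true features as latents: {loss}")
```

Note that even with the true features as latents, the width-1 prefix can't reconstruct the child, so the inner loss is nonzero on this input. That residual is what gives the optimizer a reason to move latent 0 away from the pure parent direction.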
The cosine similarities between the encoder and decoder of the Matryoshka SAE and the true features are shown below:
The Matryoshka decoder matches what we saw in the original post: the latent tracking the parent feature has positive cosine sim with the child features, and the latents tracking the child features have negative cosine sim with the other child feature. Our Matryoshka inner SAE consisting of just latent 0 does track the parent feature as we expected though! What's going on here? How is it possible for the inner Matryoshka SAE to represent a merged version of the parent and child features?
Narrow SAEs do not correctly reconstruct parent features
The core idea behind Matryoshka SAEs is that in a narrower SAE, the SAE should learn a clean representation of parent features despite co-occurrence with child features. Once we have a clean representation of a parent feature in a hierarchy, adding child latents to the SAE should not allow any absorption.
Surprisingly, this assumption is incorrect: narrow SAEs merge child representations into the parent latent.
I tried training a standard SAE with a single latent on our toy example, expecting that the 1-latent SAE would learn only the parent feature without any signs of absorption. Below is the plot of the cosine similarities between the SAE encoder and decoder with the true features.
This single-latent SAE learns a representation that merges the representation of the child features into the parent latent, exactly as we saw in our Matryoshka SAE and in the original post's results! Our narrow SAE is not learning the correct representation of feature 0 as we would hope. Instead, it's learning feature 0 + weaker representations of child features 1 and 2.
Why does this happen?
Likely this reduces MSE loss compared with learning the correct representation of feature 0 on its own. When there are fewer latents than features, the SAE always has to accept some MSE error, and folding part of each child feature into the parent latent likely gives a lower error on average than representing feature 0 alone.
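Here's a small worked example of why I think that's the case. It's a simplification rather than what the SAE literally does: I assume the single latent fires with activation exactly 1.0 whenever the parent is active, ignore the encoder, bias, decoder norm constraints and any sparsity penalty, and use made-up conditional probabilities with mutually exclusive children. Under those assumptions the squared-error-minimizing decoder is just the mean of the inputs on which the latent fires, which is feature 0 plus a fraction of each child.

```python
D_MODEL = 20


def one_hot(i):
    return [1.0 if j == i else 0.0 for j in range(D_MODEL)]


def add(*vectors):
    return [sum(components) for components in zip(*vectors)]


def scale(vector, s):
    return [s * c for c in vector]


F0, F1, F2 = one_hot(0), one_hot(1), one_hot(2)

# Hypothetical input distribution, conditional on the parent firing.
CASES = [
    (0.2, F0),           # parent alone
    (0.4, add(F0, F1)),  # parent + child 1
    (0.4, add(F0, F2)),  # parent + child 2
]


def expected_sq_error(decoder):
    """Expected squared error if the one latent fires with activation 1.0 on every case."""
    return sum(
        p * sum((xi - di) ** 2 for xi, di in zip(x, decoder))
        for p, x in CASES
    )


clean_decoder = F0  # the "correct" parent-only direction
mixed_decoder = add(F0, scale(F1, 0.4), scale(F2, 0.4))  # conditional mean of the inputs

clean_err = expected_sq_error(clean_decoder)
mixed_err = expected_sq_error(mixed_decoder)
print(f"parent-only decoder: {clean_err:.3f}, merged decoder: {mixed_err:.3f}")
```

With these numbers the merged decoder has noticeably lower expected error than the parent-only decoder, so a 1-latent SAE trained on MSE has no incentive to settle on the clean parent direction.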
What does this mean for Matryoshka SAEs?
This issue should affect any Matryoshka SAE, since the base assumption underlying Matryoshka SAEs is that a narrow SAE will correctly represent general parent features without any issues due to co-occurrence from specific child features. Since that assumption is not correct, we should not expect a Matryoshka SAE to completely fix absorption issues. I would expect that the topk SAEs from https://www.lesswrong.com/posts/rKM9b6B2LqwSB5ToN/learning-multi-level-features-with-matryoshka-saes would also suffer from this problem, although I didn't test that in this toy setting since topk SAEs are trickier to evaluate in toy settings (it's not obvious what K to pick).
It's possible the issues shown in this toy setting are more extreme than in a real LLM since the firing probabilities of the features may be higher than many features in a real LLM. That said, it's hard to say anything concrete about the firing probabilities of features in real LLMs since we have no ground truth data on true LLM features.