Operations Research PhD student at MIT working on interpretability.
Find out more here: https://wesg.me/
This was also my hypothesis when I first looked at the table. However, I think this is mostly an illusion. The sample means for rare tokens have very high standard errors, so it is mostly rare tokens that show both the unusually high and the unusually negative average KL gaps. And indeed, the correlation between token frequency and KL gap is approximately 0.
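For concreteness, here is roughly how one can check this (a minimal sketch with placeholder data; in practice kl_gap and token_id would be the per-example KL gaps and next-token ids from the actual eval):

import numpy as np
import pandas as pd

# Placeholder per-example data; replace with the real KL gaps and token ids.
rng = np.random.default_rng(0)
kl_gap = rng.normal(size=100_000)
token_id = rng.integers(0, 32_000, size=100_000)

df = pd.DataFrame({'token': token_id, 'kl_gap': kl_gap})
per_token = df.groupby('token')['kl_gap'].agg(['mean', 'count', 'std'])
per_token['stderr'] = per_token['std'] / np.sqrt(per_token['count'])

# Rare tokens have few samples and hence large standard errors on their mean
# KL gap, so they dominate both tails of the per-token mean distribution.
print(per_token.sort_values('count').head())

# Correlation between (log) token frequency and per-token mean KL gap
print(np.corrcoef(np.log(per_token['count']), per_token['mean'])[0, 1])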
Yes, this is a good consideration. I think
This is a great comment! The basic argument makes sense to me, though given how much variability there is in this plot, I think the story is more complicated. Specifically, I think your theory predicts that the SAE-reconstructed KL should always be out on the tail and that these random perturbations should have low variance in their effect on KL.
I will do some follow-up experiments to test different versions of this story.
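Roughly, the kind of experiment I have in mind (a sketch, not the exact setup; it assumes a TransformerLens model, a batch of tokens, an sae callable that returns reconstructions, and a hypothetical hook point):

import torch
from functools import partial

def patch_resid(resid, hook, replacement):
    # Replace the residual stream at this hook with a precomputed tensor.
    return replacement

def kl_from_patch(model, tokens, hook_name, replacement):
    # KL(clean || patched), averaged over positions.
    clean_logits = model(tokens)
    patched_logits = model.run_with_hooks(
        tokens, fwd_hooks=[(hook_name, partial(patch_resid, replacement=replacement))]
    )
    logp_clean = clean_logits.log_softmax(-1)
    logp_patch = patched_logits.log_softmax(-1)
    return (logp_clean.exp() * (logp_clean - logp_patch)).sum(-1).mean()

hook_name = 'blocks.8.hook_resid_pre'  # hypothetical layer
_, cache = model.run_with_cache(tokens)
clean_acts = cache[hook_name]
sae_recon = sae(clean_acts)  # hypothetical SAE forward pass
err_norm = (sae_recon - clean_acts).norm(dim=-1, keepdim=True)

kl_sae = kl_from_patch(model, tokens, hook_name, sae_recon)
kl_rand = []
for seed in range(10):
    torch.manual_seed(seed)
    noise = torch.randn_like(clean_acts)
    noise = noise / noise.norm(dim=-1, keepdim=True) * err_norm
    kl_rand.append(kl_from_patch(model, tokens, hook_name, clean_acts + noise))

If the story above is right, kl_sae should sit well out in the tail of the kl_rand distribution, and the kl_rand values should be tightly clustered.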
Right, I suppose there could be two reasons scale finetuning works
The SAE-norm patch baseline tests (1), but since your scale factors only vary within 1-2x, it seems more likely that your improvements come from (2).
I don’t see your code, but you could test this easily by evaluating your SAEs with a norm-matching hook (rough sketch below).
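For example, something like this (the sae forward pass and the hook point are assumptions about your setup):

from functools import partial

def sae_norm_patch_hook(resid, hook, sae):
    # Replace the residual stream with the SAE reconstruction, rescaled so that
    # each position's reconstruction has the same L2 norm as the original activation.
    recon = sae(resid)  # hypothetical SAE forward pass
    scale = resid.norm(dim=-1, keepdim=True) / (recon.norm(dim=-1, keepdim=True) + 1e-8)
    return recon * scale

# Hypothetical usage with a TransformerLens model:
# logits = model.run_with_hooks(
#     tokens,
#     fwd_hooks=[('blocks.8.hook_resid_pre', partial(sae_norm_patch_hook, sae=sae))],
# )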
Yup! I think something like this is probably going on. I blamed this on the L1 penalty, but it could also be some other learning or architectural failure (e.g., not enough capacity):
Some features are dense (or groupwise dense, i.e., they frequently co-occur). Due to the L1 penalty, some of these dense features are not represented. However, for KL it ends up being better to noisily represent all the features than to accurately represent some fraction of them.
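As a sanity check on that intuition, here is a toy comparison (purely illustrative, not a real SAE): logits are a linear readout of an activation that is a sum of feature directions, and we compare the KL cost of dropping half the features outright vs. shrinking all of them, as an L1 penalty tends to do.

import torch

torch.manual_seed(0)
d, n_feat, n_vocab = 64, 32, 1000

feats = torch.randn(n_feat, d) / d**0.5   # toy feature directions
W_U = torch.randn(d, n_vocab) / d**0.5    # toy unembedding
act = feats.sum(0)                        # clean activation

def kl(p_logits, q_logits):
    logp, logq = p_logits.log_softmax(-1), q_logits.log_softmax(-1)
    return (logp.exp() * (logp - logq)).sum()

clean_logits = act @ W_U

# (a) accurately represent half of the features, drop the rest
drop = feats[:n_feat // 2].sum(0)
# (b) noisily represent all features by shrinking each of them
shrink = 0.5 * feats.sum(0)

print('KL, drop half the features:', kl(clean_logits, drop @ W_U).item())
print('KL, shrink all features   :', kl(clean_logits, shrink @ W_U).item())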
Huh, I am surprised models fail this badly. That said, I think
"We argue that there are certain properties of language that our current large language models (LLMs) don't learn."
is too strong a claim based on your experiments. For instance, these models definitely have representations for uppercase letters:
In my own experiments I have found it hard to get models to answer multiple-choice questions. It seems like there may be a disconnect between prompting a model and eliciting information it has in fact learned.
Here is the code to reproduce the plot if you want to look at some of your other tasks:
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from transformer_lens import HookedTransformer

model_name = "mistralai/Mistral-7B-Instruct-v0.1"
model = HookedTransformer.from_pretrained(
    model_name,
    device='cpu',
    torch_dtype=torch.float32,
    fold_ln=True,
    center_writing_weights=False,
    center_unembed=True,
)

def is_cap(s):
    s = s.replace('▁', '')  # strip the SentencePiece marker; may need to change for other tokenizers
    return len(s) > 0 and s[0].isupper()

# Map token ids -> token strings, sorted by id so they align with the embedding rows
decoded_vocab = pd.Series({v: k for k, v in model.tokenizer.vocab.items()}).sort_index()
uppercase_labels = np.array([is_cap(s) for s in decoded_vocab.values])

W_E = model.W_E.detach().numpy()
# Difference of class means defines an "uppercase direction" in embedding space
uppercase_dir = W_E[uppercase_labels].mean(axis=0) - W_E[~uppercase_labels].mean(axis=0)
uppercase_proj = W_E @ uppercase_dir

uc_range = (uppercase_proj.min(), uppercase_proj.max())
plt.hist(uppercase_proj[uppercase_labels], bins=100, alpha=0.5, label='uppercase', range=uc_range)
plt.hist(uppercase_proj[~uppercase_labels], bins=100, alpha=0.5, label='lowercase', range=uc_range)
plt.legend(title='token')
plt.ylabel('vocab count')
plt.xlabel('projection onto uppercase direction of W_E Mistral-7b')
This is further evidence that there's no single layer at which individual outputs are learned; instead, they're smoothly spread across the full set of available layers.
I don't think this simple experiment is by any means decisive, but to me it makes it more likely that features in real models are in large part refined iteratively layer-by-layer, with (more speculatively) the intermediate parts not having any particularly natural representation.
I've also updated more and more in this direction.
I think my favorite explanation/evidence of this in general comes from Appendix C of the tuned lens paper.
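For reference, a quick logit-lens-style version of that check in TransformerLens (a sketch, not the tuned lens itself; model is a HookedTransformer and tokens a batch of tokenized text): measure the KL between the final logits and the logits you get by unembedding the accumulated residual stream after each layer.

import torch

def per_layer_kl(model, tokens):
    # KL(final prediction || logit-lens prediction after each layer)
    with torch.no_grad():
        final_logits, cache = model.run_with_cache(tokens)
        resid_stack, labels = cache.accumulated_resid(layer=-1, incl_mid=False, return_labels=True)
        resid_stack = cache.apply_ln_to_stack(resid_stack, layer=-1)
        layer_logits = resid_stack @ model.W_U + model.b_U  # [layer, batch, pos, vocab]
        logp_final = final_logits.log_softmax(-1)
        logp_layer = layer_logits.log_softmax(-1)
        kl = (logp_final.exp() * (logp_final - logp_layer)).sum(-1)
        return labels, kl.mean(dim=(1, 2))

# labels, kls = per_layer_kl(model, tokens)
# A smooth, gradual decay of KL across layers points to iterative refinement
# rather than a single layer where the output gets computed.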
This seems like a not-so-small issue for SAEs? If there are lots of half-baked features in the residual stream (or feature updates/computations in the MLPs), then many of the dictionary elements have to be spent reconstructing something that is not yet finalized, and hence are less likely to be meaningful. Are there any ideas on how to fix this?
For mechanistic interpretability research, we just released a new paper on neuron interpretability in LLMs, with an extensive discussion of superposition! See:
Paper: https://arxiv.org/abs/2305.01610
Summary: https://twitter.com/wesg52/status/1653750337373880322
Here is Sonnet 3.6's 1-shot output (colab) and plot below. I asked for PCA for simplicity.
Looking at the PCs vs x, PC2 is kinda close to giving you x^2, but indeed this is not an especially helpful interpretation of the network.
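In case it's useful, the analysis itself is only a few lines (a sketch; H is a placeholder for the [n_samples, n_hidden] matrix of hidden activations from the network and x the corresponding inputs):

import numpy as np
import matplotlib.pyplot as plt

# Placeholder data; replace H with the network's hidden activations over inputs x.
x = np.linspace(-2, 2, 200)
H = np.stack([x, x**2, np.tanh(x)], axis=1) + 0.01 * np.random.randn(200, 3)

# PCA via SVD on mean-centered activations
Hc = H - H.mean(axis=0)
U, S, Vt = np.linalg.svd(Hc, full_matrices=False)
pcs = Hc @ Vt.T  # principal component scores

for i in range(pcs.shape[1]):
    plt.plot(x, pcs[:, i], label=f'PC{i + 1}')
plt.plot(x, x**2 - (x**2).mean(), '--', label='x^2 (centered)')
plt.xlabel('x')
plt.legend(title='component')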
Good post!