This is a special post for quick takes by jake_mendel. Only they can create top-level comments. Comments here also appear on the Quick Takes page and All Posts page.

I keep coming back to the idea of interpreting the embedding matrix of a transformer. It’s appealing for several reasons: we know the entire data distribution is just the independent probability of each token, so there’s no mystery about which features are data features vs model features. We also know one sparse basis for the activations: the rows of the embedding. But that’s also clearly not satisfactory because the embedding learns something! The thing it learns could be a sparse overbasis of non-token features, but the story for this would have to be different to the normal superposition story, which involves features being placed into superposition by model components after they are computed (I find this story suss in other parts of the model too).
SAEs trained on the embedding do pretty well, but the task is much easier than in other layers because the dataset is deceptively small. Nonetheless, if the error were exactly zero, this would mean that a sparse overbasis is certainly real here (even if not the full story). If the error were small enough, we might want to conclude that the remaining error is just training noise. Therefore I have some experimental questions that would start this off:

  • Since the dataset of activations is so small, we can probably afford to do full basis pursuit (probably with some sort of weighting for token frequencies). How small does the error get? How does this scale with pretraining checkpoint? I.e., is the model trying to reduce this noise? Presumably a UMAP of basis directions shows semantic clusters like with every SAE, implying there is more structure to investigate, but it would be super cool if that weren't the case.
  • How much interesting stuff is actually contained in the embedding? If we randomise the weights of the embedding (perhaps with rejection sampling to avoid rows having too high a cosine sim) and pretrain GPT-2 from scratch without ever updating the embedding weights, how much worse does training go? What about if we reset one row of the GPT-2 embedding at a time to a random vector and finetune?
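
For concreteness, here is a rough sketch of the rejection-sampling step in the second experiment. Everything specific in it is my assumption rather than part of the proposal: unit-norm Gaussian rows, a signed cosine threshold `max_cos`, and the function name `sample_random_embedding`.

```python
import numpy as np

def sample_random_embedding(vocab_size, d_model, max_cos=0.5, seed=0,
                            max_tries=100_000):
    """Draw unit-norm Gaussian rows one at a time, rejecting any candidate
    whose cosine similarity with an already accepted row exceeds max_cos.
    (Hypothetical helper; the threshold and row distribution are assumptions.)"""
    rng = np.random.default_rng(seed)
    rows = np.zeros((vocab_size, d_model))
    n_accepted = 0
    for _ in range(max_tries):
        cand = rng.standard_normal(d_model)
        cand /= np.linalg.norm(cand)
        # All rows are unit norm, so dot products are cosine similarities.
        if n_accepted == 0 or np.max(rows[:n_accepted] @ cand) <= max_cos:
            rows[n_accepted] = cand
            n_accepted += 1
            if n_accepted == vocab_size:
                return rows
    raise RuntimeError("max_cos is too strict for this vocab_size and d_model")
```

The check is quadratic in vocab size, so at GPT-2 scale it would probably want batching. The frozen part would then be a matter of copying these rows into the embedding and excluding it from the optimizer (in PyTorch, something like calling `requires_grad_(False)` on the embedding weight). GPT-2 ties the embedding and unembedding, so one would also have to decide whether the unembedding stays frozen too.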

If we find that 1) random embeddings do a lot worse and 2) basis pursuit doesn’t lead to error nodes that tend to zero over training, then we’re in business: the embedding matrix contains important structure that is outside the superposition hypothesis. Is matrix binding going on? Are circles common? WHAT IS IT
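
To make condition 2 measurable, here is a minimal sketch of a frequency-weighted reconstruction error that could be tracked across checkpoints. It swaps in basis pursuit denoising (the lasso relaxation, solved with plain ISTA) rather than exact basis pursuit, and it assumes a fixed dictionary `D` is already in hand. `D`, `lam`, `token_freq` and both function names are hypothetical.

```python
import numpy as np

def ista_codes(X, D, lam=0.01, n_steps=1000):
    """Approximately solve min_A 0.5*||X - A D||_F^2 + lam*||A||_1 with ISTA.
    X: (n_tokens, d_model) embedding rows; D: (n_atoms, d_model) dictionary."""
    # Step size is 1 / Lipschitz constant of the gradient of the smooth term.
    step = 1.0 / np.linalg.norm(D @ D.T, 2)
    A = np.zeros((X.shape[0], D.shape[0]))
    for _ in range(n_steps):
        grad = (A @ D - X) @ D.T
        Z = A - step * grad
        # Soft thresholding is the proximal step for the L1 penalty.
        A = np.sign(Z) * np.maximum(np.abs(Z) - step * lam, 0.0)
    return A

def weighted_relative_error(X, A, D, token_freq):
    """Frequency-weighted fraction of squared norm left unexplained by A @ D."""
    w = token_freq / token_freq.sum()
    resid = np.sum((X - A @ D) ** 2, axis=1)
    total = np.sum(X ** 2, axis=1)
    return float(np.sum(w * resid) / np.sum(w * total))
```

Running `ista_codes` on the embedding at each pretraining checkpoint and plotting `weighted_relative_error` would give one version of the "does the error tend to zero" curve. Note that the lasso penalty leaves some residual by construction, so a small nonzero floor would not by itself be evidence of extra structure.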

I hadn't seen that Wattenberg-Viegas paper before, nice.

Tangentially relevant: this paper by Jacob Andreas' lab shows you can get pretty far on some algorithmic tasks by just training a randomly initialized network's embedding parameters. This is in some sense the opposite to experiment 2.