The Singular Value Decompositions of Transformer Weight Matrices are Highly Interpretable
Please see the Colab notebook to view and play with these phenomena interactively. For space reasons, not all results in the Colab are included here, so please visit it for the full story. A GitHub repository with the Colab notebook and accompanying data can be found here. This post is part of the work done at Conjecture.

TLDR

If we take the SVD of the weight matrices of the OV circuit and of the MLP layers of GPT models, and project the resulting singular vectors to token embedding space, we find that they form highly interpretable semantic clusters. This means that the network learns to align the principal directions of each MLP weight matrix or attention head to read from or write to semantically interpretable directions in the residual stream.

We can use this both to improve our understanding of transformer language models and to edit their representations. First, we design a natural language query locator: you write a set of natural language concepts, and it finds all weight directions in the network that correspond to them. Second, we edit the network's representations by deleting specific singular vectors, which has relatively large effects on the logits related to the semantics of that vector and relatively small effects on semantically different clusters.

Introduction

Understanding the internal representations of language models, and of deep neural networks in general, has been the primary focus of the field of mechanistic interpretability, with clear applications to AI alignment. If we can understand the internal dimensions along which language models store and manipulate representations, then we can get a much better grasp on their behaviour. Ultimately, we may be able both to make provable statements about bounds on that behaviour and to make precise edits to the network to prevent unwanted behaviours or enhance desired ones.

Interpretability, however, is a young field where we still do not fully understand what the basic units of the