Transformers Represent Belief State Geometry in their Residual Stream
Produced while an affiliate at PIBBSS[1]. The work was initially funded by a Lightspeed Grant and then continued at PIBBSS. Work done in collaboration with @Paul Riechers, @Lucas Teixeira, @Alexander Gietelink Oldenziel, and Sarah Marzen. Paul was a MATS scholar during part of this work. Thanks to Paul, Lucas, Alexander, Sarah, and @Guillaume Corlouer for suggestions on this writeup.

Update May 24, 2024: See our manuscript based on this work.

Introduction

What computational structure are we building into LLMs when we train them on next-token prediction? In this post we present evidence that this structure is given by the meta-dynamics of belief updating over hidden states of the data-generating process. We'll explain exactly what this means in the post. We are excited by these results because

* We have a formalism that relates training data to internal structures in LLMs.
* Conceptually, our results mean that LLMs synchronize to their internal world model as they move through the context window.
* The computation associated with synchronization can be formalized with a framework called Computational Mechanics. In the parlance of Computational Mechanics, we say that LLMs represent the Mixed-State Presentation of the data-generating process.
* The structure of synchronization is, in general, richer than the world model itself. In this sense, LLMs learn more than a world model.
* We have increased hope that Computational Mechanics can be leveraged for interpretability and AI Safety more generally.
* There's just something inherently cool about making a non-trivial prediction - in this case that the transformer will represent a specific fractal structure - and then verifying that the prediction is true.

Concretely, we are able to use Computational Mechanics to make an a priori and specific theoretical prediction about the geometry of residual stream activations (below on the left), and then show that this prediction holds true (below on the right).
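To make "belief updating over hidden states of the data-generating process" concrete, here is a minimal sketch in NumPy. It uses a toy 3-state, 2-symbol hidden Markov model with made-up transition matrices (not the actual process studied in this post) and runs the standard Bayesian filter: after each observed token, the belief over hidden states is updated and renormalized. The set of reachable beliefs, viewed as points in the probability simplex, is exactly the Mixed-State Presentation referred to above.

```python
import numpy as np

# Toy HMM with 3 hidden states and 2 output symbols (illustrative parameters only,
# not the data-generating process used in the post).
# T[s][i, j] = P(emit symbol s, move to state j | current state i),
# so sum_s T[s] must be a row-stochastic matrix.
T = np.array([
    [[0.4, 0.1, 0.0],
     [0.0, 0.3, 0.1],
     [0.1, 0.0, 0.3]],   # symbol 0
    [[0.3, 0.2, 0.0],
     [0.1, 0.4, 0.1],
     [0.2, 0.1, 0.3]],   # symbol 1
])
assert np.allclose(T.sum(axis=0).sum(axis=1), 1.0)

def update_belief(belief, symbol):
    """One Bayesian update of the belief over hidden states: eta' ∝ eta @ T[symbol]."""
    unnormalized = belief @ T[symbol]
    return unnormalized / unnormalized.sum()

def belief_trajectory(sequence, prior):
    """Belief state after each prefix of the observed token sequence."""
    beliefs, b = [], prior
    for s in sequence:
        b = update_belief(b, s)
        beliefs.append(b)
    return np.array(beliefs)

# Start from the stationary distribution of the overall transition matrix.
M = T.sum(axis=0)
eigvals, eigvecs = np.linalg.eig(M.T)
stationary = np.real(eigvecs[:, np.argmax(np.real(eigvals))])
stationary = stationary / stationary.sum()

# Each distinct reachable belief is a state of the Mixed-State Presentation;
# plotting these points in the 2-simplex gives the geometry the post predicts
# should appear (linearly embedded) in the residual stream.
print(belief_trajectory([0, 1, 1, 0, 1], stationary))
```

This is only a sketch of the update rule under assumed parameters; the post's actual prediction concerns the geometry traced out by these belief states for the specific process the transformer is trained on.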