This is a linkpost for the recent paper "White-Box Transformers via Sparse Rate Reduction: Compression Is All There Is?"

It would obviously be lovely to have arguments about AI Safety techniques that included actual formal mathematical proofs of safety, or even just proofs of some of the premises that safety relies upon. A number of people have been pursuing this goal in agent foundations, but a rather different (and possibly complementary) avenue now appears to have opened up.

The paper "White-Box Transformers via Sparse Rate Reduction: Compression Is All There Is?" (from a team led by Prof. Li Ma mostly at U.C. Berkeley, plus a variety of other well-known institutions) claims to prove that they have derived a transformer-like architecture[1] with the property that the trained model optima found by the Stochastic Gradient Descent (SGD) training process applied to it are themselves unrolled alternating optimization processes that optimize a new information compression metric, which they describe mathematically and name sparse rate reduction. If this paper is correct (sadly I'm not enough of an expert in Learning Theory to fully assess it, but the paper appears plausible on a couple of readings, includes some solid-looking experimental results from training their new architecture, builds on an extensive series of earlier papers including some from well-known authors, and is primarily from a well-respected institution), then that would mean that there is a mathematically tractable description of what models trained using this architecture are actually doing. If so, rather than just being an enormous inscrutable black-box made of tensors that we don't understand, their behavior could be reasoned about mathematically in terms of a description of what the resulting trained mesa-optimizer is actually optimizing — or as the paper's title puts it, the model is now a "white-box".[2] That in turn could well have huge implications for the feasibility of doing mathematical analysis of the AI safety of machine-learned models.

The paper includes experimental results comparing the new architecture to standard transformers and diffusion models on both image and text tasks. These seem to show that (without extensive fine-tuning) the new models are comparable, suggesting the theory they're derived from is on the right track, and in some cases even somewhat more parameter-efficient, but not otherwise dramatically improved. [So I rather suspect we will sooner or later get a proof that the architectural differences between standard transformer models and this theoretically-derived transformer-like architecture are in some sense not very important, and that existing transformers are mostly approximating the same thing, if possibly less well or less efficiently than this principled new architecture.]

The authors also argue that they would expect the internal features of models trained using this architecture to be particularly sparse, orthogonal, and axis-aligned,[3] and thus easily interpretable, as well as coming with a good theoretical understanding of how and why the model generates and then uses these features. If true, this claim sounds like it could have large implications for Mechanistic Interpretability (and likely also for AI safety approaches based on it, such as Activation Engineering, Eliciting Latent Knowledge, and Just Retarget the Search).
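
These claims also seem directly testable on a trained model's learned feature directions. As a sketch of the kind of quick checks I have in mind (entirely my own hypothetical diagnostics, not anything from the paper; `W` and `acts` are stand-ins for whatever learned directions and activations one extracts):

```python
import numpy as np

def feature_geometry_report(W: np.ndarray, acts: np.ndarray) -> dict:
    """Crude diagnostics for the 'sparse, orthogonal, axis-aligned' claims.

    W:    (d, m) matrix whose columns are candidate feature directions
          (e.g. learned dictionary / decoder columns).
    acts: (n_samples, m) matrix of feature activations on some dataset.
    """
    W = W / np.linalg.norm(W, axis=0, keepdims=True)

    # Orthogonality: how large are the off-diagonal cosine similarities?
    gram = W.T @ W
    off_diag = np.abs(gram[~np.eye(W.shape[1], dtype=bool)])

    # Axis-alignment: what fraction of each direction's squared mass sits on
    # its single largest coordinate? (1.0 would mean perfectly basis-aligned.)
    axis_mass = np.max(np.abs(W), axis=0) ** 2

    # Sparsity: what fraction of activations are (near) zero?
    sparsity = np.mean(np.abs(acts) < 1e-6)

    return {
        "mean_abs_cosine": float(off_diag.mean()),
        "max_abs_cosine": float(off_diag.max()),
        "mean_top_coord_mass": float(axis_mass.mean()),
        "activation_sparsity": float(sparsity),
    }
```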

Is anyone with the appropriate mathematical or Mech. Interpretability skills looking at this paper, and indeed at the interpretability properties of models trained using the authors' principled variant of the transformer architecture? This paper's claims sound like they could have huge implications for AI Safety and Alignment — I'd love to hear in the comments what people think both of the paper's arguments and of its implications for reducing AI risk.

To me, this paper feels like a significant step in a process that in recent years has been gradually transforming machine learning from a field like alchemy, relying almost entirely on trial-and-error discovery, into one more like chemistry or even engineering, with solid theoretical underpinnings. For AI safety, this seems a very welcome development — I hope it's completed before we get Transformative AI. Helping to ensure this sounds like it could be a valuable AI Alignment funding goal.

  1. ^

    Their architecture has attention-like layers that are slightly simplified compared to a standard transformer: each attention head, rather than having three separate Query, Key, and Value projection matrices, has just a single matrix that plays all three roles (thus reducing the parameter count for a network with the same number of layers and layer sizes). There are also some minor differences in the details of how the alternating layers are interleaved between the first and second halves of the layer stack, explicitly making this an encoder-decoder (i.e. autoencoder) architecture. Interestingly, the authors also claim to have proven that diffusion models are in some meaningful sense equivalent to the second half of the layers in their modified-transformer-architecture models, thus theoretically unifying two very effective but apparently very different architectures.
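
    For concreteness, here is a minimal PyTorch-style sketch of the parameter sharing being described: one matrix per head playing the Query, Key, and Value roles at once. This is my own illustrative reconstruction of the idea, not the authors' exact operator (which also fixes normalization details and step sizes that come out of their derivation).

    ```python
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class SharedProjectionAttentionHead(nn.Module):
        """Illustrative attention head with a single projection matrix U standing in
        for the usual separate W_Q, W_K, and W_V (a rough sketch of the idea described
        above, not the paper's exact operator)."""

        def __init__(self, d_model: int, d_head: int):
            super().__init__()
            # One learned projection per head instead of three.
            self.U = nn.Linear(d_model, d_head, bias=False)

        def forward(self, z: torch.Tensor) -> torch.Tensor:
            # z: (batch, tokens, d_model)
            p = self.U(z)                               # project once: (batch, tokens, d_head)
            # Similarities are computed between the *same* projected tokens...
            attn = F.softmax(p @ p.transpose(-1, -2) / p.shape[-1] ** 0.5, dim=-1)
            # ...and those projected tokens are also what gets mixed (the "value" role).
            return attn @ p                             # (batch, tokens, d_head)
    ```

    In a multi-head version the heads' outputs would be concatenated and mapped back into the residual stream, much as in a standard transformer; the point of the sketch is just that a single matrix per head does the work of all three of the usual Query/Key/Value projections.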

  2. ^

    The paper builds on work published last year by Prof. Yi Ma and a different team of students, part of a line of research he has been pursuing for a while, similarly constructing a principled "white-box" version of convolutional neural networks (CNNs) [which are primarily used for image processing]: ReduNet: A White-box Deep Network from the Principle of Maximizing Rate Reduction. Sadly, this mathematically-derived ReduNet architecture was not in practice as effective as the previous CNN and RNN architectures that it was attempting to replace, though it did have a number of striking advantages in training and sample efficiency — one could directly construct a moderately good approximation to the optimum network from as little as a single sample of each class of object, without any training or use of backpropagation (!), and then further improve it by training in the conventional way. As the new paper explains in its detailed Introduction, people have been attempting things along these lines for about a decade, but until this paper, the various architectures constructed using principled mathematical reasoning were not in practice as effective on real-world tasks as previously-known ad-hoc architectures discovered by trial-and-error — suggesting that the mathematical inspiration was missing something important. The paper's authors have now combined two different previous approaches ("sparsity" and "rate reduction"), and the resulting "sparse rate reduction" combination finally predicts an architecture that looks a lot like both of the previously-known, trial-and-error-discovered transformer and diffusion architectures, and that in their experiments actually works about as well as, or in some cases even a little better than, transformers and diffusion models on a variety of real tasks. This strongly suggests that their sparse rate reduction mathematical model of what should be optimized is now a close fit to what current transformer and diffusion architectures are actually doing.

  3. ^

    More specifically, the prediction is that in their architecture the features will fall into a large number of subspaces, each of small and variable dimension, each individually linear (and modeled in the paper as having a uniform Gaussian probability distribution within it), with the individual subspaces being approximately orthogonal to each other (probability distributions are modeled in the paper as being independent across subspaces), and with the representation of each individual subspace (once fully optimized) tending to be both approximately orthogonal and approximately basis-aligned, approximately spanned by a number of basis vectors matching its dimension. [This suggests an interpretation for the phenomenon of "feature splitting" observed in Towards Monosemanticity: Decomposing Language Models With Dictionary Learning: a cluster of features could correspond to a subspace, and the splitting process should then stop once it reaches that subspace's dimension.] This is basically a generalization of Principal Component Analysis.
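
    To make this picture concrete, here is a tiny NumPy toy of the structure being described: activations drawn from a union of low-dimensional, mutually (approximately) orthogonal Gaussian subspaces inside a larger activation space. All the specific numbers are made up for illustration; this is not code from the paper.

    ```python
    import numpy as np

    rng = np.random.default_rng(0)
    d = 512                              # ambient activation dimension (made up)
    K = 16                               # number of feature subspaces (made up)
    dims = rng.integers(2, 9, size=K)    # each subspace has a small, variable dimension

    # Build bases for K mutually orthogonal subspaces by slicing up one random
    # orthonormal basis of R^d (the paper only needs *approximate* orthogonality;
    # exact orthogonality keeps the toy example simple).
    Q, _ = np.linalg.qr(rng.standard_normal((d, d)))
    bases, start = [], 0
    for dk in dims:
        bases.append(Q[:, start:start + dk])
        start += dk

    def sample_from_subspace(k: int, n: int = 1) -> np.ndarray:
        """Sample activations lying in subspace k, Gaussian-distributed within it."""
        coeffs = rng.standard_normal((dims[k], n))
        return bases[k] @ coeffs           # shape (d, n)

    x = sample_from_subspace(3, n=100)
    energy_own = np.sum((bases[3].T @ x) ** 2)
    energy_other = np.sum((bases[7].T @ x) ** 2)
    print(energy_own, energy_other)        # essentially all the energy lies in subspace 3
    ```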

    How this claim would combine with the non-monosemanticity implied by LLMs having many more features than internal dimensions is not fully clear to me, and will presumably hinge on the various occurrences of the word 'approximately' above: for an internal activation vector space with thousands of dimensions, there are vastly more directions that are approximately orthogonal and approximately basis-aligned than there are dimensions. From what I understand so far of the paper, the optimization algorithm encourages the feature subspaces to span the activation vector space, to be normalized, to be approximately orthogonal, and to be as basis-aligned as they can. If there were more common features than dimensions, I would a priori expect the most common ones to be the most orthogonal and basis-aligned, and additional ones to slot themselves in, presumably in descending order of frequency, in ways that are as close to orthogonal to all the other feature subspaces as possible while being as close to basis-aligned as possible: in an embedding space with thousands of dimensions, there are billions of directions for which these approximations can hold reasonably well.
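
    The claim that high-dimensional spaces have room for vastly more approximately-orthogonal directions than they have dimensions is easy to check numerically. A quick sketch (numbers purely illustrative):

    ```python
    import numpy as np

    rng = np.random.default_rng(0)
    d, n = 1000, 20000                    # 1000 dimensions, 20x as many random directions

    # Random unit vectors in high dimensions are nearly orthogonal in pairs:
    # pairwise cosine similarities concentrate around 0 with spread ~ 1/sqrt(d).
    V = rng.standard_normal((n, d))
    V /= np.linalg.norm(V, axis=1, keepdims=True)

    subset = V[rng.choice(n, size=2000, replace=False)]   # check a random subsample of pairs
    cos = subset @ subset.T
    off_diag = np.abs(cos[~np.eye(len(subset), dtype=bool)])
    print(off_diag.mean(), off_diag.max())  # mean ~0.025; even the largest overlap stays small
    ```

    This is the same geometric fact behind superposition-style pictures of LLM features: near-orthogonality is cheap in high dimensions, so 'approximately orthogonal' is a much weaker constraint than 'orthogonal'.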

    The paper includes quite a number of interpretability results in section 4.3. They are for image nets, not LLMs, and I am not personally very familiar with interpretability for image nets, but some of them, around image segmentation and subsegmentation of parts of animals (head, legs, etc.), on the face of it look like the sort of thing one would hope for in image interpretability.

4 comments:

I second this request for a second opinion. It sounds interesting.

Without understanding it deeply, there are a few meta-markers of quality:

  • they share their code
  • they manage to get results that are competitive with normal transformers (ViTs on ImageNet; Table 1, Table 3).

However

  • while claiming interpretability, on a quick click-through I can't see many measurements or concrete examples of interpretability. There are the self-attention maps in fig. 17, but nothing else that strikes me.

Agreed. I'm working on a 3rd detailed reading, working my way through the math and references (the paper's around 100 pages, so this isn't a quick process) and will add more detail here on interpretability as I locate it. My recollection from previous readings is that the interpretability analysis they included was mostly of types commonly done on image-processing networks, not LLM interpretability, so I was less familiar with it.

I've played around with ELK and so on, and my impression is that we don't know how to read the hidden_states/residual stream (beyond like 80% accuracy, which isn't good enough). But learning a sparse representation (e.g. the sparse autoencoders paper from Anthropic) helps, and is seen as quite promising by people across the field (for example, EleutherAI has a research effort).

So this does seem quite promising. Sadly, we would really need a foundation model of 7B+ parameters to test it well, which is quite expensive in terms of compute.

Does this paper come with extra train time compute costs?

Their interpretability results seem to be in section 4.3, and there are quite a number of them, from a variety of different networks. I'm not very familiar with interpretability for image nets, but the model spontaneously learning to do high-quality image segmentation was fairly striking, as was its spontaneous subsegmentation of portions of objects, such as a neuron that responded to animals' heads and another that responded to their legs. That certainly looks like what I would hope for if told that an image network was highly interpretable. The largest LLMs they trained were around the size of BERT and GPT-2, so their performance was fairly basic.