The paper argues that auto-regressive transformers implement in-context learning via gradient-based optimization on in-context data.
The authors start by pointing out that with a single linear self-attention (LSA) layer (that is, no softmax), a Transformer can implement one step of gradient descent on the L2 regression loss (a fancy way of saying W -= LR * Σ_i (W x_i - y_i) x_i^T, where the sum runs over the in-context examples), and confirm this result empirically. They extend this result by showing that an N-layer LSA-only transformer behaves similarly to N steps of gradient descent on small linear regression tasks, both in and out of distribution. They also find that the results pretty much hold with softmax self-attention (which isn’t super surprising, given that a softmax is approximately linear in its inputs when they are small, since exp(z) ≈ 1 + z).
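To make the one-step claim concrete, here's a minimal sketch. It is my own simplification rather than the paper's exact weight matrices. It assumes scalar targets, weights initialized at zero, and an idealized softmax-free attention layer whose keys are the inputs x_i, whose values are the targets y_i, and whose query is the test input. Starting from W = 0, one GD step gives ΔW = LR * Σ_i y_i x_i^T, so the updated prediction is LR * Σ_i y_i (x_i · x_query). That is exactly what linear attention computes.

```python
import random


def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(ai * bi for ai, bi in zip(a, b))


def gd_step_prediction(xs, ys, x_query, lr):
    """Predict y at x_query after one GD step, starting from w = 0, on
    L(w) = 1/2 * sum_i (w . x_i - y_i)^2 (scalar-output linear regression)."""
    d = len(x_query)
    w = [0.0] * d
    grad = [0.0] * d
    for x, y in zip(xs, ys):
        err = dot(w, x) - y  # residual (w . x_i - y_i)
        for j in range(d):
            grad[j] += err * x[j]
    w = [wj - lr * gj for wj, gj in zip(w, grad)]  # w -= LR * grad
    return dot(w, x_query)


def linear_attention_prediction(xs, ys, x_query, lr):
    """Idealized softmax-free attention: keys x_i, values y_i, query x_query,
    with LR folded in as a scale on the values."""
    return lr * sum(y * dot(x, x_query) for x, y in zip(xs, ys))


if __name__ == "__main__":
    rng = random.Random(0)
    d, n = 3, 8
    w_true = [rng.gauss(0, 1) for _ in range(d)]
    xs = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
    ys = [dot(w_true, x) for x in xs]
    x_query = [rng.gauss(0, 1) for _ in range(d)]
    print(gd_step_prediction(xs, ys, x_query, lr=0.1))
    print(linear_attention_prediction(xs, ys, x_query, lr=0.1))  # same value up to float rounding
```

Nothing here is trained. The sketch only shows that the two computations agree when the weights start at zero.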
Next, they show empirically that the forward pass of a small transformer with MLPs behaves similarly to a meta-learned MLP plus one step of gradient descent on a toy non-linear regression task, again in terms of both in-distribution and OOD performance.
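The MLP result can be pictured as the same one-step update carried out in feature space. The toy sketch below uses a fixed, hand-picked tanh feature map `phi` as a hypothetical stand-in for the meta-learned MLP. In the paper the MLP is learned, but here nothing is. The sketch then takes one GD step from zero on a linear readout over phi(x) for a sine-shaped target. The names `phi` and `HIDDEN` and the choice of target are all my own assumptions.

```python
import math

# Hypothetical stand-in for the meta-learned MLP: fixed (weight, bias) pairs for three tanh units.
HIDDEN = [(1.0, 0.0), (-2.0, 0.5), (0.5, -1.0)]


def phi(x):
    """Feature map: three tanh hidden units plus a constant bias feature."""
    return [math.tanh(a * x + b) for a, b in HIDDEN] + [1.0]


def dot(u, v):
    return sum(ui * vi for ui, vi in zip(u, v))


def one_gd_step_readout(xs, ys, lr):
    """One GD step from v = 0 on L(v) = 1/2 * sum_i (v . phi(x_i) - y_i)^2.
    Because v starts at zero, this works out to v = lr * sum_i y_i * phi(x_i)."""
    v = [0.0] * len(phi(0.0))
    grad = [0.0] * len(v)
    for x, y in zip(xs, ys):
        f = phi(x)
        err = dot(v, f) - y
        for j, fj in enumerate(f):
            grad[j] += err * fj
    return [vj - lr * gj for vj, gj in zip(v, grad)]


def loss(v, xs, ys):
    return 0.5 * sum((dot(v, phi(x)) - y) ** 2 for x, y in zip(xs, ys))


if __name__ == "__main__":
    xs = [-1.0, -0.5, 0.0, 0.5, 1.0]
    ys = [math.sin(2 * x) for x in xs]  # toy non-linear target
    v = one_gd_step_readout(xs, ys, lr=0.05)
    print("loss before:", loss([0.0] * 4, xs, ys))
    print("loss after: ", loss(v, xs, ys))
    print("prediction at x=0.25:", dot(v, phi(0.25)))
```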
They then show how you can interpret an induction head as a single step of gradient descent, and provide circumstantial evidence that this explains some of the in-context learning observed in Olsson et al. 2022. Specifically, they show that 1) two-layer attention-only transformers converge to a loss consistent with one step of GD on this task, and 2) the first layer of the network learns to copy tokens one sequence position over, prior to the emergence of in-context learning.
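Here's a toy sketch of the induction-head reading. It assumes one-hot token embeddings, an idealized first layer that copies each position's previous token perfectly, and a softmax-free second layer. This is a simplification, since Olsson et al.'s induction heads use softmax attention. Each (previous token, current token) pair is treated as a regression example (x_i, y_i). One GD step from W = 0 then scores each vocabulary item by how often it followed an earlier occurrence of the current token. That is the induction head's "find the earlier match, copy what came next" behavior. All function names are hypothetical.

```python
def one_hot(token, vocab):
    return [1.0 if t == token else 0.0 for t in vocab]


def copy_previous_token(seq):
    """Idealized layer 1: pair every position (from the second on) with the token before it."""
    return [(seq[i - 1], seq[i]) for i in range(1, len(seq))]


def induction_scores(seq, vocab, lr=1.0):
    """Idealized layer 2: one GD step from W = 0 on the regression
    one_hot(prev) -> one_hot(curr), evaluated at the last token.
    Equivalently: softmax-free attention with keys one_hot(prev),
    values one_hot(curr), and query one_hot(seq[-1])."""
    query = one_hot(seq[-1], vocab)
    scores = [0.0] * len(vocab)
    for prev, curr in copy_previous_token(seq):
        key, value = one_hot(prev, vocab), one_hot(curr, vocab)
        match = sum(k * q for k, q in zip(key, query))  # 1.0 iff prev == last token
        for j, v in enumerate(value):
            scores[j] += lr * match * v
    return scores


def predict_next(seq, vocab):
    scores = induction_scores(seq, vocab)
    return vocab[max(range(len(vocab)), key=scores.__getitem__)]


if __name__ == "__main__":
    vocab = ["A", "B", "C", "D"]
    print(predict_next(["A", "B", "C", "A"], vocab))  # B: "A" was followed by "B" earlier
```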
(EDIT:) davidad says below:
this is strong empirical evidence that mesa-optimizers are real in practice
Personally, while I'd place this in the same category as papers like RL^2 or In-context RL with Algorithmic Distillation, which also exhibit mesa-optimization, I think the more interesting results are the mechanistic ones -- i.e., that some forms of mesa-optimization in the model seem to be implemented via something like gradient descent.
(EDIT 2) nostalgebraist pushes back on this claim in this comment:
Calling something like this an optimizer strikes me as vacuous: if you don't require the ability to adapt to a change of objective function, you can always take any program and say it's "optimizing" some function. Just pick a function that's maximal when you do whatever it is that the program does.
It's not vacuous to say that the transformers in the paper "implement gradient descent," as long as one means they "implement [gradient descent on L2 loss]" rather than "implement [gradient descent] on [L2 loss]." They don't implement general gradient descent, but happen to coincide with the gradient step for L2 loss.
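nostalgebraist's vacuity point can be made concrete with a toy construction. This is my illustration and is not from the comment. Take an arbitrary `program`, here hypothetically list reversal, and define an objective that scores 1 exactly when an output matches what the program returns. The program then trivially "maximizes" that objective. Because the objective is derived from the program, it can't be swapped out, which is why the comment asks for the ability to adapt to a change of objective function.

```python
from itertools import permutations


def program(xs):
    """An arbitrary program that nobody would call an optimizer: reverse a list."""
    return list(reversed(xs))


def make_objective(prog):
    """Build an objective that is maximal exactly when the output equals prog's output."""
    def objective(inp, out):
        return 1.0 if out == prog(inp) else 0.0
    return objective


def brute_force_argmax(objective, inp, candidates):
    """A genuine (if dumb) search: try every candidate and keep the best-scoring one."""
    return max(candidates, key=lambda out: objective(inp, out))


if __name__ == "__main__":
    objective = make_objective(program)
    inp = [1, 2, 3]
    candidates = [list(p) for p in permutations(inp)]
    print(program(inp), objective(inp, program(inp)))  # [3, 2, 1] 1.0
    print(brute_force_argmax(objective, inp, candidates))  # [3, 2, 1]
```

Only `brute_force_argmax` is retargetable. Hand it a different objective and it searches for something else, whereas `program` does the same thing no matter what.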
(Nitpick: I do want to push back a bit on their claim to "mechanistically understand the inner workings of optimized Transformers that learn in-context", since they've only really looked at the mechanism by which single-layer attention-only transformers perform in-context learning.)
I think the claim that an optimizer is a retargetable search process makes a lot of sense,* and I've edited the post to link to this clarification.
That being said, I'm still confused about the details.
Suppose that I do a goal-conditioned version of the paper, where (hypothetically) I exhibit a transformer circuit that, conditioned on some prompt or other, is able to switch between performing gradient descent on three types of objectives (say, L1, L2, L∞) -- would this suffice? What if, instead, no prompt lets me switch between the three objectives, but there is a parameter inside the neural network that I could change to make the circuit optimize different objectives? How much of the circuit do I have to change before it becomes a new circuit rather than a retargeted optimizer?
I guess part of the answer to these questions might look like, "there might not be a clear cutoff, in the same sense that there's no clear cutoff for most other definitions we use in AI alignment ('agent' or 'deceptive alignment', for example)", while another part might be "this is left for future work".
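The goal-conditioned thought experiment can be made a little more concrete with a toy sketch. It is not from the paper and is nothing like a transformer circuit. A single subgradient-descent loop fits a scalar c to some data. One switch, standing in for the prompt or the internal parameter, retargets the loop between L1, L2, and L∞ objectives. Their minimizers are the median, mean, and midrange of the data respectively, so the same loop lands in three different places. The names and the 1/sqrt(t+1) step-size schedule are my own choices.

```python
import math


def sign(z):
    return (z > 0) - (z < 0)


# (Sub)gradients with respect to a scalar estimate c of the data ys.
def grad_l2(c, ys):
    # derivative of (1/n) * sum_i 0.5 * (c - y_i)^2
    return sum(c - y for y in ys) / len(ys)


def grad_l1(c, ys):
    # a subgradient of (1/n) * sum_i |c - y_i|
    return sum(sign(c - y) for y in ys) / len(ys)


def grad_linf(c, ys):
    # a subgradient of max_i |c - y_i|: the sign of the worst residual
    worst = max(ys, key=lambda y: abs(c - y))
    return sign(c - worst)


OBJECTIVES = {"L1": grad_l1, "L2": grad_l2, "Linf": grad_linf}


def retargetable_gd(ys, objective, steps=10_000, c0=0.0):
    """One fixed descent loop; `objective` is the switch that retargets it."""
    grad = OBJECTIVES[objective]
    c = c0
    for t in range(steps):
        c -= grad(c, ys) / math.sqrt(t + 1)  # diminishing steps, as subgradient methods need
    return c


if __name__ == "__main__":
    ys = [0.0, 1.0, 2.0, 3.0, 10.0]  # mean 3.2, median 2, midrange 5
    for name in OBJECTIVES:
        print(name, round(retargetable_gd(ys, name), 2))
```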
*This is also similar to the definition used for inner misalignment in Shah et al.'s Goal Misgeneralization paper: