There is a new paper by Liu et al. that claims to have identified the key mechanism underlying grokking (and potentially generalization more broadly).
They argue:
1. Grokking can be explained via the norm of the weights. They claim there is a particular level of the weight norm that is optimal for generalization.
2. If there is an optimal level of the weight norm, then the weight norm of your model after initialization can be too low, too high, or optimal. They claim that grokking is the phenomenon where we initialize the model with a too-large weight norm; it then slowly moves toward the optimal weight norm and only then generalizes.
3. They also claim that you can get the same results as grokking, but much faster, if you set the weight norm correctly at every step.
4. They set the norm "correctly" by rescaling the weights after each unconstrained optimization step (so after every single optimizer step following loss.backward()?!). A sketch of how I imagine this looks is below.
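To make point 4 concrete, here is a minimal PyTorch sketch of what I understand the procedure to be: take a normal optimizer step, then rescale all parameters back to a fixed L2 norm. The toy model, data, `target_norm` value, and the choice to rescale the global (rather than per-layer) norm are my own assumptions, not necessarily what the paper does:

```python
import torch
import torch.nn as nn

def rescale_to_norm(model: nn.Module, target_norm: float) -> None:
    # Project all parameters back onto the sphere of global L2 radius target_norm.
    with torch.no_grad():
        current = torch.sqrt(sum((p ** 2).sum() for p in model.parameters()))
        for p in model.parameters():
            p.mul_(target_norm / current)

# Toy setup; the architecture, data, and target_norm are placeholders.
model = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 2))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
target_norm = 5.0  # the hypothesized "optimal" norm level

x, y = torch.randn(256, 10), torch.randint(0, 2, (256,))
for step in range(1000):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()                     # unconstrained update...
    rescale_to_norm(model, target_norm)  # ...then snap back to the target norm
```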
Implications:
- I think they have found a very important insight into grokking, and into finding generalizing circuits more broadly.
- I'm still a bit skeptical of some of the claims and results. On some level "just fix your weight norm and the model generalizes" sounds too simple to be true for all tasks.
- I think this result could have big implications but I'm not yet sure whether they are positive or negative. On the one hand, finding generalizing circuits seems to solve some of the problems associated with bad out-of-distribution generalization. On the other hand, it likely speeds up capabilities.
I'm very unsure about this paper but intuitively it feels important. Thoughts?
The existence of an optimal L2 norm makes no sense at all to me. The L2 norm is an extremely unnatural metric for neural networks. For instance, in ReLU networks, if you multiply all the weights in one layer by β and all the weights in the layer above by 1/β, the output of the network doesn't change at all, yet the L2 norm will have changed (the squared norm of those two layers becomes $\beta^2 \|w_1\|^2 + \frac{1}{\beta^2} \|w_2\|^2$). In fact you can get any value for the L2 norm (above some minimum) you damn well please by just scaling the layers. An optimal average entropy of the output distribution over the course of training would make a hell of a lot more sense if this is somehow changing training dynamics.
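Concretely, here is a quick PyTorch check of that symmetry (β = 3 and the tiny network are arbitrary; note the first layer's bias has to be scaled too for the invariance to be exact):

```python
import torch
import torch.nn as nn

def l2_norm(model: nn.Module) -> torch.Tensor:
    return torch.sqrt(sum((p ** 2).sum() for p in model.parameters()))

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
x = torch.randn(5, 4)
out_before, norm_before = net(x), l2_norm(net)

beta = 3.0
with torch.no_grad():
    net[0].weight.mul_(beta)      # scale layer 1 up by beta
    net[0].bias.mul_(beta)        # (bias too, so pre-activations scale exactly)
    net[2].weight.mul_(1 / beta)  # scale layer 2 down by 1/beta

out_after, norm_after = net(x), l2_norm(net)
print(torch.allclose(out_before, out_after, atol=1e-5))  # True: same function
print(norm_before.item(), norm_after.item())             # but a different L2 norm
```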