Introduction
It's now well known that simple neural network models often "grok" algorithmic tasks. That is, when trained for many epochs on a subset of the full input space, the model quickly attains perfect train accuracy and then, much later, near-perfect test accuracy. In the former phase, the model memorizes the training set; in the latter, it generalizes out-of-distribution to the test set.
In the algorithmic grokking literature, there is typically exactly one natural generalization from the training set to the test set. What if, however, the training set were instead under-specified in such a way that there were multiple possible generalizations? Would the model grok at all? If so, which of the generalizing solutions would it choose? If the model followed Occam's razor, it would choose the simplest solution -- but what does "simplest" mean here? We explore these questions for the task of computing a finite group operation.
Setup
This section assumes some basic familiarity with group theory and can be skipped. The point is just that each model is trained on the intersection of the multiplication tables of two groups G1 and G2, and that this intersection always covers at least 50% of the full input space.
In existing work on grokking finite group operations (e.g. Chughtai et al.), a one-hidden-layer MLP with two inputs, f: R^G × R^G → R^G, is trained on the multiplication table of a finite group G. The model takes as input a pair of elements x,y∈G, encoded as one-hot vectors in R^G, and is expected to output logits maximized at the product xy∈G. Thus, the input space is the set of all pairs G^2, and the model is evaluated on its accuracy over this entire input space:
$$\mathrm{acc}(f) \;=\; |G|^{-2} \sum_{x,y \in G} \mathbf{1}\!\left\{\operatorname*{argmax}_{z \in G} f(x,y)_z = xy\right\}.$$

Previous work finds that training on an iid subsample of 40% of the total input space (so the training set has 0.4|G|^2 points) is enough to grok the full multiplication table for various choices of G. (The most well-studied choice is the symmetric group G=S5.) In our setup, we leave everything the same except for the choice of training set. We now choose two groups G1 and G2 such that |G1|=|G2|. Thinking of the two groups as two operations ∗1 and ∗2 on the same underlying set of elements, our ambiguous training set is the set of all pairs of elements such that the two group operations agree:
$$X_{\text{train}} \;=\; \{(x,y) : x *_1 y = x *_2 y\}.$$

Note then that Xtrain can be completed to the full multiplication table for either G1 or G2.
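As a concrete illustration, both the ambiguous training set and the accuracy metric above can be computed directly from the two groups' Cayley tables. The sketch below is schematic PyTorch rather than our exact training harness, and it assumes the model takes the concatenation of the two one-hot vectors as input:

```python
import torch
import torch.nn.functional as F

def ambiguous_train_set(T1: torch.Tensor, T2: torch.Tensor):
    """All pairs (x, y) on which the two group operations agree.

    T1 and T2 are |G| x |G| Cayley tables with T[x, y] = index of the product x*y.
    """
    xs, ys = torch.nonzero(T1 == T2, as_tuple=True)
    return xs, ys  # X_train

def accuracy(model, T: torch.Tensor) -> float:
    """acc(f): fraction of all |G|^2 pairs on which argmax_z f(x, y)_z equals T[x, y]."""
    n = T.shape[0]
    x, y = torch.meshgrid(torch.arange(n), torch.arange(n), indexing="ij")
    x, y = x.reshape(-1), y.reshape(-1)
    inputs = torch.cat([F.one_hot(x, n), F.one_hot(y, n)], dim=-1).float()
    with torch.no_grad():
        logits = model(inputs)            # shape (|G|^2, |G|)
    return (logits.argmax(dim=-1) == T[x, y]).float().mean().item()
```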
To ensure that there are enough elements in Xtrain for grokking to be possible, we need to construct G1 and G2 with large overlap. One way is to set G1=H×Z/2Z and G2=H⋊φZ/2Z for some group H and some order-two automorphism φ∈Aut(H). Then, by construction, whenever k1=0∈Z/2Z,
$$(h_1, k_1) *_1 (h_2, k_2) \;=\; (h_1 h_2,\, k_1 k_2) \;=\; (h_1 \varphi^{k_1}(h_2),\, k_1 k_2) \;=\; (h_1, k_1) *_2 (h_2, k_2).$$

Thus, the overlap between G1 and G2 is at least 50%. (In fact, it is generally somewhat larger than 50%, owing to the fixed points of φ.) All examples we discuss will be of this form.
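As a sanity check on this construction, one can build both Cayley tables and measure the overlap directly. The standalone sketch below uses additive notation for Z/2Z and the cyclic example H = Z/48Z with φ: x↦25x that appears later in the post:

```python
import numpy as np

n = 48
phi = lambda h: (25 * h) % n        # an order-two automorphism of Z/48Z

def idx(h, k):                      # flatten the pair (h, k) to a single element index
    return h + n * k

def cayley_table(semidirect: bool) -> np.ndarray:
    T = np.empty((2 * n, 2 * n), dtype=np.int64)
    for h1 in range(n):
        for k1 in range(2):
            for h2 in range(n):
                for k2 in range(2):
                    h2_eff = phi(h2) if (semidirect and k1 == 1) else h2
                    T[idx(h1, k1), idx(h2, k2)] = idx((h1 + h2_eff) % n, (k1 + k2) % 2)
    return T

T1 = cayley_table(semidirect=False)   # Z/48Z x Z/2Z
T2 = cayley_table(semidirect=True)    # Z/48Z semidirect-product Z/2Z via phi
print((T1 == T2).mean())              # 0.75: rows with k1 = 0 agree everywhere,
                                      # rows with k1 = 1 agree exactly on the fixed points of phi
```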
Experiments
Ambiguous grokking
Grokking either group
We run the training setup described above with G1=S5=A5⋊Z/2Z and G2=A5×Z/2Z. The intersection between the two groups' multiplication tables has size 7920/120^2 = 55%. We use this entire intersection as the training set and train 100 models with varying initialization seeds, all other hyperparameters held fixed. In this setup, the vast majority of models do not fully grok either solution (~90%), and instead just memorize Xtrain. However, we do find both models that grok G1 (~4%) and G2 (~6%).[1]
Grokking the intersect
Although models often fail to grok either of the two groups, they always successfully grok the intersection: when we train models on an iid random 60% subset of the intersection (so 33% of the full input space), we find that they always attain full accuracy on the entire intersection and in some cases full accuracy on either G1 or G2.
Grokking only one group
When we set G1=Z/48Z×Z/2Z and G2=Z/48Z⋊φZ/2Z, where φ∈Aut(Z/48Z) maps x↦25x (in this case, the intersection size is 75%), the model only groks G1. Our intuition is that G1, being commutative, is much simpler than G2, which is not, and thus may be preferred by the model. This may just be a coincidence, however: we have not been able to find an intrinsic measure of group complexity such that models always prefer simpler groups.[2] See more discussion below.
No grokking
When we set G1=Z/100Z and G2=Z/50Z×Z/2Z (intersection size 51%), the model never groks either group. However, this example is a little unusual: the proportion of label classes in the intersect training set is non-uniform, and thus differs from the proportion over the entire input space (for either G1 or G2). We speculate that this class imbalance may be the reason for the lack of grokking.
Measuring complexity
Complexity of the grokked solution
Intuitively, in cases where the model may grok either of two generalizing solutions, it should prefer the simpler of the two. There are two distinct things that we could mean by a solution implemented by a model being simple: the solution could have low complexity in the sense of the local learning coefficient (LLC), or the circuits implementing it could be efficient, for instance as proxied by a small weight norm.
To test whether LLC or some measure of circuit efficiency is able to predict which group the model groks, we train 10,000 models on the intersection of G1=S4 and G2=A4×Z/2Z. In this setting, we again find that most models don't grok (93%). Among those that do grok there is a strong preference for S4 (6.5%) over A4×Z/2Z (0.5%). We then record the training loss, weight norm, and estimated LLC of all models at the end of training.
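One standard way to estimate the LLC (following Lau et al.) is via SGLD sampling localized around the trained parameters. The sketch below illustrates that estimator; the hyperparameter values and helper names are placeholders, and this is not necessarily the exact implementation used in our experiments:

```python
import copy
import math
import torch

def estimate_llc(model, loss_fn, X, Y, n_steps=1000, n_burnin=200,
                 eps=1e-4, gamma=100.0):
    """Estimate the local learning coefficient at the trained parameters w*.

    lambda_hat = n * beta * (E_w[L_n(w)] - L_n(w*)),  beta = 1 / log n,
    where the expectation is over SGLD samples tethered to w* by an
    L2 "localization" term of strength gamma.
    """
    n = len(X)
    beta = 1.0 / math.log(n)

    w_star = [p.detach().clone() for p in model.parameters()]
    sampler = copy.deepcopy(model)

    with torch.no_grad():
        init_loss = loss_fn(model(X), Y).item()   # L_n(w*)

    running, count = 0.0, 0
    for step in range(n_steps):
        loss = loss_fn(sampler(X), Y)
        grads = torch.autograd.grad(loss, list(sampler.parameters()))
        with torch.no_grad():
            for p, g, p0 in zip(sampler.parameters(), grads, w_star):
                drift = 0.5 * eps * (n * beta * g + gamma * (p - p0))
                p.add_(-drift + math.sqrt(eps) * torch.randn_like(p))
        if step >= n_burnin:
            running += loss.item()
            count += 1

    return n * beta * (running / count - init_loss)
```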
All three measures are somewhat predictive of what group is grokked, in a direction that aligns with our intuition that the model prefers simpler solutions. Models that grok S4 (6.5%) tend to have lower training loss, weight norm, and LLC, and models that grok A4×Z/2Z (0.5%) tend to have larger LLC and training loss, though still lower than what is typical for models that do not grok at all. However, we also find plenty of examples of models that do not grok either group yet still have low LLC and training loss. Possibly, these models are finding solutions that do not correspond to either of the two groups (or any group), yet still are "simple" in some appropriate sense.[3]
Surprisingly, however, we observe that final training loss and weight norm are moderately correlated, and that final training loss and LLC are highly correlated.
The correlation between training loss and LLC is especially unexpected to us, and we do not have a good explanation for its presence. Since (to our knowledge) this correlation has not been noted in other settings, we suspect that it is a quirk specific to 1-hidden layer MLPs trained on groups. In any case, while our results are suggestive that models prefer simpler solutions as measured by some combination of circuit efficiency and LLC, this correlation means that we cannot disentangle the two measures in the groups setting.
Complexity over time
Besides checking model complexity at the end of training, we also measure it across training time. For G1=S5 and G2=A5×Z/2Z, the LLC tracks grokking in the sense that, whenever the model groks either of the two groups, the LLC decreases sharply at the same time that the test accuracy rises. For models that do not grok either group, we observe both examples where the LLC stays large throughout training and examples where it falls. As mentioned above, we speculate that the latter cases correspond to the model learning some simple solution distinct from either of the two groups.
Are there any complexity measures that track grokking over time better than the LLC? We check those listed in Jiang et al. For simplicity, we measure these across checkpoints for a model trained on an iid subset of S5 instead of an intersect set -- these models consistently grok. We notice that the LLC estimates in the iid case tend to be smoother over time than in the intersection experiments.
From these plots, it appears that LLC (bottom left) best tracks the generalization error, measured as the difference between train accuracy and test accuracy. Inverse margin and training loss (last two plots) also do well (perhaps this is related to the high correlation between training loss and LLC at the end of training, discussed above), but they are both large at the start of training, when generalization error is low because both training and test accuracy are low. The LLC is correctly low both at the beginning of training, before memorization, and at the end, after grokking.
Determination and differentiation
Perturbation sensitivity
Somewhat separately from the previous investigations, one might wonder when the model "decides" which of the groups (if any) it will grok. In a literal sense, the answer is that the model's fate is determined at initialization, because in our experiments there is no stochasticity during training (we use full-batch gradient descent). However, this is not really a satisfying answer. Rather, when we say that the model has "decided" on a future solution, we should expect that, from the decision point onwards, its decision is robust: small perturbations in training should not be able to make the model "change its mind".
Hence, to measure the point at which a model decides its fate, we make small Gaussian perturbations to the model weights at evenly spaced checkpoints throughout training. We then continue training anew from these perturbed weights. We find evidence that the model makes a decision not at initialization but still well before its choice is apparent in its test-set behavior.[4]
The example above is particularly interesting in that the model briefly veers towards A5×Z/2Z around epoch 600, corresponding to a bump in perturbation sensitivity, before returning on its path towards S5.[5]
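Schematically, the perturbation experiment looks like the following (a sketch; `train_epoch` stands in for one step of the same full-batch training loop, and the noise scale sigma is illustrative):

```python
import copy
import torch

def perturb_and_resume(checkpoint, train_epoch, n_epochs, sigma=1e-2, n_seeds=8):
    """Add Gaussian noise to a mid-training checkpoint and retrain from there.

    Comparing which group (if any) each perturbed copy ends up grokking tells us
    whether the original run's fate was already "decided" at this checkpoint.
    """
    retrained = []
    for seed in range(n_seeds):
        torch.manual_seed(seed)
        model = copy.deepcopy(checkpoint)
        with torch.no_grad():
            for p in model.parameters():
                p.add_(sigma * torch.randn_like(p))
        for _ in range(n_epochs):
            train_epoch(model)        # continue the original full-batch training
        retrained.append(model)
    return retrained
```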
Total variation
For models trained on the intersection of two groups, we notice that those that grok either of the two groups tend to have more stable outputs late in training than those that grok neither. We quantify this by measuring total variation in logits over a fixed training interval:
$$\mathrm{TV}[i,j] \;:=\; |G|^{-2} \sum_{n=i}^{j-1} \sum_{x,y \in G} \mathbf{1}\!\left\{f_{\theta_n}(x,y) \neq f_{\theta_{n+1}}(x,y)\right\},$$

where θ_n are the model parameters at epoch n. For models trained on the intersection of G1=S4 and G2=A4×Z/2Z for 3000 epochs, we measure TV[2500, 3000].
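Concretely, reading f_θ(x,y) as the argmax prediction, TV can be computed from saved per-checkpoint predictions as in this sketch (`preds` is assumed to stack the predictions on all |G|^2 pairs for checkpoints i through j):

```python
import torch

def total_variation(preds: torch.Tensor) -> float:
    """preds has shape (num_checkpoints, |G|^2), with preds[n] holding the argmax
    outputs of f_{theta_n} on every input pair. Returns TV[i, j] as defined above."""
    changed = (preds[1:] != preds[:-1]).float()   # 1{f_{theta_n}(x,y) != f_{theta_{n+1}}(x,y)}
    return changed.sum(dim=0).mean().item()       # sum over n, average over the |G|^2 pairs
```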
Models that grok either of G1,G2 always "lock in" their solutions -- the functions they compute no longer change after grokking. Those that grok neither often continue to oscillate in logit space even late in training.[6] However, similarly to the case with the complexity measure experiments, there are many examples of models that grok neither and yet still have zero total variation, possibly corresponding to simple solutions distinct from both G1 and G2 that we are unaware of.
Note also that this is a purely test-set behavior. On the training set, all models have zero total variation by the end of training: once they attain perfect training accuracy, they no longer vary their training-set outputs.
Determination across distribution shift
As seen above, the models that are trained on an intersection dataset and eventually grok have their fate stably determined relatively early in the training process. Is this stability a property of all possible parameter settings that implement the same solution? We investigate this question by first training models on iid data subsampled from G1=S5 so that the models consistently grok; we then "transplant" these models to the intersection of G1=S5 and G2=A5×Z/2Z and resume training. Perhaps surprisingly, many model instances (7%) partially ungrok G1, while retaining perfect accuracy on the intersection. A few instances then proceed to regrok G1, returning to perfect test accuracy later in training.[7] Repeating the same experiment with the roles of G1 and G2 swapped results in the same behavior, with 8% of instances partially ungrokking.
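In code, the transplant experiment is just a change of training set mid-run. A sketch, reusing the `accuracy` helper from the Setup section, with `train_epoch` again standing in for one full-batch training step:

```python
def transplant(model, intersect_data, train_epoch, T1, T2, n_epochs=3000):
    """Resume training an already-grokked model (trained on iid G1 data) on the
    intersection of the two multiplication tables, tracking both test accuracies."""
    history = []
    for _ in range(n_epochs):
        train_epoch(model, intersect_data)
        # Watching accuracy on both full tables reveals (partial) ungrokking of G1
        # and, occasionally, later regrokking.
        history.append((accuracy(model, T1), accuracy(model, T2)))
    return history
```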
Training Jacobian
The training Jacobian (Belrose & Scherlis) is the Jacobian of the training map from initial model parameters to final model parameters.[8] Thus, if the model has m parameters, then the training map is of type R^m → R^m and the training Jacobian is a matrix in R^(m×m). In experiments with 25 epochs of training and without weight decay (hence no grokking), Belrose & Scherlis find that the training Jacobian preserves most directions in parameter space (corresponding to singular values ≈1 and left singular vectors approximately equal to right singular vectors). The subspace approximately preserved by the training Jacobian is referred to as the bulk.
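To make the definition concrete, here is a rough sketch (ours, not Belrose & Scherlis's code) of computing the training Jacobian for a small model by differentiating through an unrolled full-batch training loop; plain gradient descent with optional weight decay is used purely for illustration, and `model`, `X`, `Y` are placeholders:

```python
import torch
from torch.func import functional_call

def train_map(theta0, model, X, Y, lr=1e-2, wd=1e-3, n_epochs=100):
    """Map a flat vector of initial parameters to the flat vector of final parameters
    after n_epochs of full-batch gradient descent, kept differentiable in theta0."""
    params, offset = {}, 0
    for name, p in model.named_parameters():
        params[name] = theta0[offset:offset + p.numel()].reshape(p.shape)
        offset += p.numel()

    for _ in range(n_epochs):
        logits = functional_call(model, params, (X,))
        loss = torch.nn.functional.cross_entropy(logits, Y)
        grads = torch.autograd.grad(loss, list(params.values()), create_graph=True)
        params = {name: p - lr * (g + wd * p)
                  for (name, p), g in zip(params.items(), grads)}

    return torch.cat([p.reshape(-1) for p in params.values()])

# The m x m training Jacobian (only feasible for small m and short runs):
# theta0 = torch.cat([p.detach().reshape(-1) for p in model.parameters()])
# J = torch.autograd.functional.jacobian(lambda t: train_map(t, model, X, Y), theta0)
# svals = torch.linalg.svdvals(J)   # inspect the spectrum / the "bulk"
```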
We compute the training Jacobian in our groups setting and observe that
The results are similar across both models trained on group intersections and those trained on iid subsamples of one group.
Assuming the latter observation isn't an artifact of numerical instability (which is entirely possible), we think it suggests that the limiting infinite-training map, mapping initial parameters to the fully converged final parameters, is discontinuous as a function R^m → R^m. That is, arbitrarily small perturbations to the initial weights might cause non-negligible changes to the model's final weights.[9] When the training set is iid, the different resulting model parameters all lead to the same model behaviors -- neural networks are non-identifiable. When the training set is under-determined, these changes in weight space may manifest as changes in model behavior.
As an aside, we note that this high sensitivity to model initialization seems to somewhat contradict (a sufficiently strong version of) the lottery ticket hypothesis. If there really existed a subnetwork at initialization that is significantly more amenable to training than the rest of the model, we'd expect its prominence to be at least somewhat robust to perturbations.
Discussion
We anticipate that ambiguous out-of-distribution generalization is a phenomenon that may arise not only in toy algorithmic tasks but also in more realistic settings. Indeed, recent work (Qin et al., Mészáros et al., Reizinger et al.) finds that the ability of language models to apply grammatical rules generalizes unstably out-of-distribution. Our concern is that ambiguous generalization may pose serious obstacles to the safety and alignment of LLM-based AI systems. Safety training for large language models is typically performed with datasets much smaller than those used for pre-training. Thus, safety training may be significantly more prone to under-specification in a manner leading to ambiguous generalization. Such under-specification may manifest in what is called deceptive alignment: the model appears to its designers to be well-aligned within the training environment, but in reality learns a solution that is harmful in deployment. Heuristic counting arguments (Hubinger, Carlsmith) suggest that, in the presence of such ambiguity, true alignment may be vanishingly unlikely compared to deceptive alignment or scheming.
Our main motivation for studying models trained on the intersection of two groups was to 1) exhibit a crisp and concrete example of ambiguous generalization and 2) use the setting as a testbed for various hypotheses surrounding this phenomenon. Regarding 2), our hope was to relate some precise measure of a solution's complexity to the frequency with which it is learned, thus providing a quantitative version of the aforementioned counting arguments. While we were not able to fully attain this goal, we did find some evidence in favor of the local learning coefficient and circuit complexity. On the other hand, we were not able to disentangle these two measures in this toy setting, perhaps suggesting the need to move to other, more realistic experiments. Regardless, we continue to believe that ambiguous out-of-distribution generalization is an important, safety-relevant phenomenon that merits further study.
Acknowledgements: This work was mainly conducted during MATS 6.0 and 6.1. Many thanks to Jesse Hoogland, Daniel Filan, and Jacob Drori for helpful conversations. Wilson Wu was supported by an LTFF grant.
Since in this experiment the sample size is small and grokking is somewhat rare, the proportion estimates should be treated as fairly low-confidence. In particular, we don't claim that the model prefers to grok A5×Z/2Z over S5. Our experiments with S4 and A4×Z/2Z have a 100x larger sample size, and thus for those groups we are able to draw more confident conclusions about model preferences.
One hypothesis we had was that the model would prefer the group with lower c_s(G) := ∑_{ρ∈Irrep(G)} dim(ρ)^s for some s > 2. This quantity is larger for groups with larger irreps, and in particular is minimized for abelian groups. It also appears with s = 2.5 in the expression for maximum margin computed in Theorem 9 of Morwani et al.'s paper. However, in the limited experiments we ran, this hypothesis seemed not to pan out.
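For concreteness, c_s is easy to evaluate from the irrep dimensions, which are standard: S4 has irreps of dimensions 1, 1, 2, 3, 3, and the irreps of A4×Z/2Z are the A4 irreps (dimensions 1, 1, 1, 3) each tensored with the two one-dimensional irreps of Z/2Z. For that pair the two values come out quite close:

```python
def c_s(irrep_dims, s=2.5):
    """c_s(G) = sum over irreducible representations rho of dim(rho)^s."""
    return sum(d ** s for d in irrep_dims)

print(c_s([1, 1, 2, 3, 3]))       # S4          -> ~38.8
print(c_s([1, 1, 1, 3] * 2))      # A4 x Z/2Z   -> ~37.2
```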
The functions computed by non-grokking models on the full input space are diverse and, in general, not simply combinations of G1 and G2.
One might draw an analogy to cellular determination. "When a cell 'chooses' a particular fate, it is said to be determined, although it still 'looks' just like its undetermined neighbors. Determination implies a stable change -- the fate of determined cells does not change." (source)
A speculative, likely incorrect cartoon picture: around epoch 300, the model is funneled into a slide in the loss landscape whose final destination is at S5. Around epoch 600, there's a sharp bend in the slide (and/or it becomes narrower and/or shallower...). If the model takes the bend too quickly, it shoots over the edge and lands in A5×Z/2Z.
We use a constant learning rate schedule. Likely this effect would not have been as apparent if we had instead used, say, cosine annealing.
Performing some cursory mechanistic interpretability on these examples, we find that the original grokked parameters and the regrokked parameters generally differ, though they do tend to share many neurons in common (in the sense that the same neuron indices are supported in the same irreps). Since regrokking is a rare phenomenon that we only quickly spot-checked, we don't claim that this observation is representative.
It follows immediately from the chain rule that, assuming the final model parameters are at a local minimum, the training Jacobian must lie in the null space of the loss Hessian at the final parameters. That is, the training Jacobian is only nontrivial if the model is singular.
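Spelled out: write T for the idealized fully converged training map, J = ∂T/∂θ0 for the training Jacobian, and H for the loss Hessian. If T(θ0) is a critical point of the loss for every initialization near θ0, then differentiating the identity ∇L(T(θ0)) = 0 gives

$$0 \;=\; \frac{\partial}{\partial \theta_0}\,\nabla L\bigl(T(\theta_0)\bigr) \;=\; H\bigl(T(\theta_0)\bigr)\, J(\theta_0),$$

so the columns of J lie in the null space of H.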
On the other hand, it is probably not too hard to show that any finite number of iterations of SGD, Adam, etc., is continuous. Only in the infinite training limit can discontinuities appear.