Introduction
Last week's paper roundup focused (more or less by accident) on the path dependence of deep learning and the order of feature learning. Going forward, I've decided to have an explicit focus for each week's roundup. This week's focus is on the structure and redundancy of trained models, as well as linear interpolations through parameter space.
I've also decided to publish each roundup on Monday morning (Edit: or try to, at any rate).
Papers
Residual Networks Behave Like Ensembles of Relatively Shallow Networks
My opinion:
This paper suggests that neural nets are redundant by default, which gives some intuition for why it's often possible to prune large fractions of a network's parameters without much impact on test performance. It also suggests the mechanism by which residual connections allow deeper networks to be trained: residual connections let shallow subnetworks communicate directly with the input/output space, so a deep net can effectively be built by ensembling shallow nets.
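As a toy illustration of the "ensemble of paths" view (my own sketch, not from the paper, and using linear blocks so the expansion is exact): unrolling two stacked residual blocks gives a sum over all 2^2 paths through the blocks.

```python
import numpy as np

# Toy example: with *linear* residual blocks, two stacked blocks unroll
# exactly into a sum over 2^2 paths. Real ResNet blocks are nonlinear, so
# this is only an analogy for the paper's path-counting picture.
rng = np.random.default_rng(0)
d = 4
x = rng.normal(size=d)
W1 = 0.1 * rng.normal(size=(d, d))
W2 = 0.1 * rng.normal(size=(d, d))

# Standard residual forward pass: (I + W2)(I + W1) x
h1 = x + W1 @ x
out = h1 + W2 @ h1

# Unrolled "ensemble of paths" view: the identity path, the two
# single-block paths, and the path through both blocks.
paths = x + W1 @ x + W2 @ x + W2 @ (W1 @ x)

print(np.allclose(out, paths))  # True
```

With nonlinear blocks the decomposition is no longer an exact sum, but the same path-counting picture applies.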
I think it also points away from neural nets implementing a Kolmogorov or circuit simplicity prior.
On the Effect of Dropping Layers of Pre-trained Transformer Models
My opinion:
(see below)
Of Non-Linearity and Commutativity in BERT
My opinion:
These two papers tell us similar things about the function of transformer layers. The layers are fairly redundant, especially the later ones, and adjacent layers are more similar to each other than they are to more distant layers. This suggests transformers are also fairly ensembly.
Though the ROME paper implies that there's an asymmetry between what the early layers do (store and look up key-value memories) and what the later layers do (query the early layers' value outputs through their self-attention connections). Perhaps that's why pruning the later layers causes a smaller drop in performance?
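For reference, here is a minimal sketch of one common way to quantify how similar two layers' representations are: linear centered kernel alignment (CKA) on their activations over the same batch. (This is my own illustration, not necessarily the metric these papers use.)

```python
import numpy as np

def linear_cka(X, Y):
    """Linear CKA between two activation matrices of shape (n_examples, dim)."""
    X = X - X.mean(axis=0, keepdims=True)
    Y = Y - Y.mean(axis=0, keepdims=True)
    cross = np.linalg.norm(Y.T @ X, "fro") ** 2
    return cross / (np.linalg.norm(X.T @ X, "fro") * np.linalg.norm(Y.T @ Y, "fro"))

# Example: compare hidden states from two (hypothetical) transformer layers
# on the same batch of tokens.
rng = np.random.default_rng(0)
layer_a = rng.normal(size=(256, 768))
layer_b = layer_a + 0.3 * rng.normal(size=(256, 768))  # a "nearby" layer
print(linear_cka(layer_a, layer_b))  # close to 1 => highly similar representations
```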
Code here.
The Role of Permutation Invariance in Linear Mode Connectivity of Neural Networks
My opinion:
Past work demonstrated that image models with different random initializations converged to minima that have simple, but not linear, paths of constant loss between them. This paper argues that these paths become linear if you account for possible permutations of the model's weights. I think this is plausible for the sorts of wide image models trained on very limited data that people typically do mode connectivity experiments on, but I doubt it will hold for large language models or other systems that could plausibly scale to AGI.
The best parts of this paper are the empirical experiments. They find that the loss barrier on linear interpolations between trained models decreases with width and increases with both depth and the difficulty of the training data.
These results make sense to me. Linear mode connectivity requires that every network along the linear path between the two solutions also be an effective network. The wider the network, the more degrees of freedom there are in its hidden representations, and so the more room to avoid destructive interference between the hidden representations of the two networks. The harder the dataset, the more of that representational capacity each network's hidden representations will take up.
A network's forward pass essentially sends a "message" from the input layer to the final layer, using an encoding defined by the weights. Linear mode connectivity between two networks implies that the encoding schemes of the two networks have little interference, such that both their "messages" can be sent simultaneously through the same channel. Note that word embeddings are often very anisotropic (they lie in a narrow cone rather than spreading across the full representation space), and that neural net embeddings often have large first principal components that vary across initializations. Speculatively, these two patterns might contribute to the relative lack of interference between models during interpolation.
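To make the "loss barrier" above concrete, here is a minimal sketch of how it is typically measured along the linear path between two trained networks (PyTorch; `model_a`, `model_b`, and `eval_loss` are hypothetical stand-ins, and all state-dict entries are assumed to be float tensors):

```python
import copy
import torch

def interpolate(model_a, model_b, alpha):
    """Return a copy of model_a whose weights are (1 - alpha) * A + alpha * B."""
    model = copy.deepcopy(model_a)
    sd_a, sd_b = model_a.state_dict(), model_b.state_dict()
    model.load_state_dict({k: (1 - alpha) * sd_a[k] + alpha * sd_b[k] for k in sd_a})
    return model

def loss_barrier(model_a, model_b, eval_loss, steps=11):
    """Largest excess of the interpolated loss over the linear interpolation
    of the two endpoint losses (one common definition of the barrier)."""
    loss_a, loss_b = eval_loss(model_a), eval_loss(model_b)
    barrier = 0.0
    for i in range(steps):
        alpha = i / (steps - 1)
        interp_loss = eval_loss(interpolate(model_a, model_b, alpha))
        barrier = max(barrier, interp_loss - ((1 - alpha) * loss_a + alpha * loss_b))
    return barrier
```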
This paper's "theoretical result" is for MLPs with one hidden layer at random initialization (as in, no training at all). I was expecting them to use the neural tangent kernel (NTK) to approximate the training process, but it's just a probabilistic consequence of the number of possible hidden-unit permutations growing factorially with layer width.
If anything, the fact that there's linear mode connectivity between untrained MLPs suggests that two models having linear mode connectivity tells you relatively little about their true degree of functional similarity.
This paper also investigates the impact of weight permutations on linear mode connectivity. Its empirical investigations are less extensive, but it provides an effective algorithm for finding good weight permutations to support linear interpolation.
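As a sketch of what "accounting for permutations" can look like in practice, here is a simplified weight-matching step for a single-hidden-layer MLP (my own simplification, not either paper's full algorithm): find the hidden-unit permutation of model B that best aligns its weights with model A's, then apply it before interpolating.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def align_one_hidden_layer(W1_a, b1_a, W2_a, W1_b, b1_b, W2_b):
    """Permute model B's hidden units to best match model A's (weight matching).

    Shapes: W1 is (hidden, in), b1 is (hidden,), W2 is (out, hidden). This is
    simplified to one hidden layer; deeper nets need one permutation per layer,
    solved jointly or iteratively.
    """
    # Similarity of each A-unit to each B-unit, based on incoming and outgoing weights.
    cost = W1_a @ W1_b.T + np.outer(b1_a, b1_b) + W2_a.T @ W2_b
    rows, cols = linear_sum_assignment(-cost)  # maximize total similarity
    perm = cols[np.argsort(rows)]              # perm[i] = B-unit matched to A-unit i
    return W1_b[perm], b1_b[perm], W2_b[:, perm]
```

After permuting B this way, the loss barrier would be measured exactly as in the earlier sketch.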
Note that neither paper uses Adam, which Analyzing Monotonic Linear Interpolation in Neural Network Loss Landscapes (below) finds can break a different type of linear interpolation property.
Code here.
Loss Surface Simplexes for Mode Connecting Volumes and Fast Ensembling
My opinion:
This paper presents an interesting method for estimating the geometry of the low-loss solution manifold found by SGD. Starting at a solution found by SGD, they essentially grow a maximum-dimensional simplex whose vertices are points in parameter space constrained to have low loss. They then repeat this with many SGD solutions to build a collection of simplexes that approximates the low-loss manifold's geometry. This lets them lower-bound the dimensionality of the low-loss manifold at 10, though the authors are unable to create simplexes of more than 10 dimensions that have non-trivial hypervolume.
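For concreteness, here is a minimal sketch (my own, assuming the simplex vertices are already-found low-loss parameter vectors) of how to sample a parameter vector uniformly from inside such a simplex, e.g. to check that its loss stays low or to ensemble predictions from several samples:

```python
import numpy as np

def sample_from_simplex(vertices, rng):
    """Sample a parameter vector uniformly from the simplex spanned by `vertices`.

    `vertices` is a (k, n_params) array: k flattened parameter vectors forming
    the simplex's corners. A Dirichlet(1, ..., 1) draw gives uniform barycentric
    coordinates, and the convex combination maps them into parameter space.
    """
    k = vertices.shape[0]
    weights = rng.dirichlet(np.ones(k))
    return weights @ vertices

# Hypothetical usage: evaluate loss at sampled points, or ensemble the
# predictions of several networks built from sampled parameter vectors.
rng = np.random.default_rng(0)
vertices = rng.normal(size=(4, 1000))  # stand-in for 4 low-loss solutions
theta = sample_from_simplex(vertices, rng)
```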
The paper also finds that averaging over the vertices of simplexes can improve model robustness, which is further evidence that mode-connected networks can still implement different functions with different generalization behaviors (averaging would provide no benefit if the networks all behaved identically).
Code here.
Linear Connectivity Reveals Generalization Strategies
My opinion:
The previous papers indicated that linearly connected basins could contain models implementing different functions with different generalization behavior. In contrast, this paper trains BERT models and finds they fall into two basins without linear mode connectivity between them, and that these basins correspond to different functional solutions which fit the training data using very different strategies.
This paper does not check for linear mode connectivity under weight permutations, but I wouldn't be surprised if the two basins remained unconnected even after allowing for permutations.
Code here.
Analyzing Monotonic Linear Interpolation in Neural Network Loss Landscapes
My opinion:
Prior work indicates that the linear path from initialization to the converged solution has monotonically decreasing loss. This paper tests monotonic linear interpolation (MLI) across various training configurations, finding that it often fails to hold for networks trained with Adam or with large SGD learning rates. They find that networks that move further from initialization tend to have more curved optimization trajectories, and that the MLI property is less likely to hold for these networks.
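A minimal sketch of how the MLI property can be checked (my own; `eval_loss` is a hypothetical helper that loads a state dict of float tensors and returns the training loss):

```python
def is_monotonic_mli(theta_init, theta_final, eval_loss, steps=21, tol=1e-4):
    """Check the MLI property: the loss should be non-increasing along the
    straight line in parameter space from the initial to the final weights."""
    losses = []
    for i in range(steps):
        alpha = i / (steps - 1)
        theta = {k: (1 - alpha) * theta_init[k] + alpha * theta_final[k]
                 for k in theta_init}
        losses.append(eval_loss(theta))
    # Allow a small tolerance for evaluation noise.
    return all(b <= a + tol for a, b in zip(losses, losses[1:]))
```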
Code here.
Revisiting Model Stitching to Compare Neural Representations
My opinion:
This paper presents an interesting tool for comparing the features extracted by different models. They take the bottom of one trained model and the top of another, freeze their parameters, then "stitch" them together by learning a linear transformation between the two models' embedding spaces. They find model stitching works well across architectures, datasets, and training processes.
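A minimal sketch of the setup (PyTorch; `bottom_a` and `top_b` are hypothetical frozen halves of two trained models, cut at layers producing `d_a`- and `d_b`-dimensional features): only the linear stitching layer is trained.

```python
import torch
import torch.nn as nn

class StitchedModel(nn.Module):
    """Frozen bottom of model A -> trainable linear stitch -> frozen top of model B."""

    def __init__(self, bottom_a, top_b, d_a, d_b):
        super().__init__()
        self.bottom_a, self.top_b = bottom_a, top_b
        for p in self.bottom_a.parameters():
            p.requires_grad_(False)
        for p in self.top_b.parameters():
            p.requires_grad_(False)
        # The only trainable parameters. (For convolutional feature maps a
        # 1x1 convolution typically plays this role instead of nn.Linear.)
        self.stitch = nn.Linear(d_a, d_b)

    def forward(self, x):
        return self.top_b(self.stitch(self.bottom_a(x)))

# Hypothetical usage: train only the stitching layer on the task's training set.
# model = StitchedModel(bottom_a, top_b, d_a=512, d_b=768)
# optimizer = torch.optim.Adam(model.stitch.parameters(), lr=1e-3)
```

Good stitched accuracy then suggests the two models' intermediate representations are translatable into each other by a linear map.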
The fact that so many models can be stitched together lends support to feature universality and natural abstractions, and suggests different architectures are reasonably consistent in the types of features they extract. Similarly, the fact that models at different points in the training process can be stitched together suggests a degree of stability in the model's representations throughout training.
Also, this paper suggests GPT-2 word embeddings and human neurological activations during language processing can be similarly stitched together with a linear transform.
BERTs of a feather do not generalize together: Large variability in generalization across models with similar test set performance
My opinion:
This paper compares the generalization behavior of different BERT finetuning runs, where only the classification head initialization and training data order are varied (though they do backprop through the entire BERT model during finetuning). They find that different training runs have very different generalization behaviors when evaluated on probing data specifically crafted to highlight which linguistic structures the models rely on to make classifications, but very similar in-distribution behavior.
This implies that test data from the same distribution as the training data is not enough to pick up on differences in how different trained networks generalize. If the authors had only used in-distribution test data to evaluate the finetuned models, they'd have concluded that there was very little variability between them.
(Thanks to Zac Hatfield-Dodds for bringing this paper to my attention)
Conclusion
My impression after reading these and similar papers is that results from "toy" settings often do not reflect those found in more realistic settings. A lot of these "optimization geometry" style investigations use very wide models that are massively undertrained on relatively easy datasets (usually of images). The consequence is that they operate in regimes where NTK and neural network Gaussian processes give good approximations of network training dynamics. However, these approximations do not hold well for larger models solving harder problems.
This raises the question of whether the empirical results we get from large networks such as GPT-3 will extend to whatever networks eventually implement AGI-level capabilities. I am hopeful that the differences in learning trajectories between toy networks and more powerful systems represent a single "phase transition", primarily caused by moving out of the NTK regime.
I do expect there is another "phase transition" in the inductive biases of AGI learning trajectories when the AI system becomes capable of actively trying to refine its own abstractions to improve its future thinking. Though, I think we can still study low-capability networks with such dynamics by studying active learning for GPT-3-level systems.
Anyways, I hope readers find these papers useful for their own research. Please feel free to discuss the listed papers in the comments or recommend additional papers to me.
Future
For next week's roundup, I'm thinking the focus will be on using interpretability tools to guide a neural net's learning process. There's apparently a fair bit of work in this space. E.g., this review paper.
My other candidate focuses are:
Let me know if there are any topics you're particularly interested in.