Thanks to Dan Roberts and Sho Yaida for comments on a draft of this post.
In this post, I would like to draw attention to the book Principles of Deep Learning Theory (PDLT), which I think represents a significant advance in our understanding of how neural networks work [1]. Among other things, this book explains how to write a closed-form formula for the function learned by a realistic, finite-width neural network at the end of training [2] to an order of approximation that suffices to describe representation learning, and how that formula can be interpreted as the solution to a regression model. This makes manifest the intuition that NNs are doing something like regression, but where they learn the features appropriate for a given dataset rather than having them be hand-engineered from the start.
I've condensed some main points from the 400-page book into an 8-page summary here:
Review of select results from PDLT
(Other good places to learn about the book, though perhaps with less of a focus on AI-safety-relevant parts, include this series of five lectures given by the authors at a deep learning summer school or this one-hour lecture for a non-expert audience.)
For those who have been following the discussions of ML theory on this forum, the method used in the book is to go to the next-to-leading order in a 1/width expansion. It thus builds on recent studies of infinitely wide NNs that were reviewed in the AF post Recent Progress in the Theory of Neural Networks [3]. However, by going beyond the leading order, the authors of PDLT are able to get around a key qualitative shortcoming of the earlier work, namely that infinitely wide NNs can't learn features. The next-to-leading order formula also introduces a sum over many steps of gradient descent, getting around an objection [4] that the NTK/infinite-width limit may not be applicable to realistic models, since in that limit we can land on the fully trained model after just one fine-tuned training step.
I think that this work could have significant implications for AGI forecasting and safety (via interpretability), and deserves to be better appreciated in this community. For example,
- In AGI forecasting, an important open question is whether the strong scaling hypothesis holds for any modern architectures. (For example, the forecasts in Ajeya Cotra's Bio-Anchors report are conditioned on assuming that 2020 algorithms can scale to TAI.) A longstanding challenge for this field is that as long as we treat neural networks as black boxes or random program search, it's hard to reason about this question in a principled way. But I think that by identifying a space of functions that realistic NNs end up learning in practice (<< the space of all neural networks with finely tuned weights!), the approach of PDLT gives us a way to start to reason about it. For example, despite the existence of the universal approximation theorem, I think the results of PDLT can be used to rule out the (strawmannish) hypothesis that feedforward MLPs can scale to AGI (see my review of the Bio-Anchors report for more on this point). As such, it could be really interesting to generalize PDLT to other architectures.
- In mechanistic interpretability, a basic open question is what the fundamental degrees of freedom are that we should be trying to interpret. A lot of work has been done under the assumption that we should look at the activations of individual neurons, but there's naively no reason that semantically meaningful properties of a dataset must align with individual neurons after training, and even some interesting counterexamples [5]. By finding a dual description of a trained NN as a trained regression model, PDLT seems to hint that a (related, but) different set of degrees of freedom -- the effective features (32) in the above-linked note -- may be more natural objects to look at. It would be really interesting to see if this turns out to be the case [6].
- More generally, a dream for interpretability research would be if we could reverse-engineer our future AI systems into human-understandable code. If we take this dream seriously, it may be helpful to split it into two parts: first understanding what "programming language" an architecture + learning algorithm will end up using at the end of training, and then what "program" a particular training regimen will lead to in that language [7]. It seems to me that by focusing on specific trained models, most interpretability research discussed here is of the second type. But by constructing an effective theory for an entire class of architecture that's agnostic to the choice of dataset, PDLT is a rare example of the first type. So it could be not only useful but also totally complementary to other agendas to try to develop it further and/or generalize it to new architectures as they come along.
- ^
See here for an earlier, though shorter, discussion of this book on LW.
- ^
As a function of its architecture + weight initialization + training set + learning algorithm. The formula is (.154) on page 1 of the note that I link to below.
- ^
See also yesterday's post Neural Tangent Kernel Distillation.
- ^
As raised e.g. by Paul Christiano here.
- ^
See the recent Anthropic paper Toy Models of Superposition for a discussion of this and related issues.
- ^
For example, one could try to see which dataset examples or synthetic data would maximally activate the effective features, generalizing the experiments done on neurons in the Circuits thread. (However, a caveat for this project idea is that there are a huge number of effective features in PDLT, scaling as the number of weights instead of the number of neurons! So one might have to start with a really small toy model, or be clever about picking a subset / combination of features to visualize.)
It could also be interesting to understand how the PDLT effective theory fits conceptually with other ideas from the interpretability and broader "science of ML" literature. For example, how should we think of compositional circuits in the dual frame, which seems to put all effective features on the same footing instead of having some be built from others? (Perhaps this isn't a meaningful question to ask since the compositional circuits in vision models are themselves a fuzzy emergent description. But in that case, can we generalize PDLT to transformers, and then understand crisp emergent circuits like modular arithmetic circuits in the dual frame?) Or since the dual model has a huge number of effective features, might there be a lottery ticket hypothesis at the level of the features?
- ^
Or to put it another way, if we want to understand what cognition a trained AI system is performing at inference time, there's both a kinematic aspect of claiming that some degrees of freedom in the system (or the way that they activate or something) approximately encode some human-understandable concepts in a Platonic latent space (e.g. paperclips), and a dynamical aspect of how those concepts get put together (e.g. into a plan to turn us into paperclips). The space of allowed dynamics is what I'm calling the "programming language" per Chris's metaphor.
Thank you for the discussion!
Let us start by stressing that, of course, the maximal-update parametrization is definitely an intriguing recent development, and it would be very interesting to find tools to be able to understand the strongly-coupled regime in which it resides.
Now, it seems like there are two different issues tangled in this discussion: (i) is one parametrization "better" than another in practice? and (ii) is our effective theory analysis useful in practically interesting regimes?
We'd also like to emphasize that, even if you are against the NTK parametrization in practice and don't think it's relevant at all -- a position we don't hold, but one might -- our work still provides a simple solvable model of representation learning, from which we might learn some general principles that may be applicable to safety and interpretability.
With that said, let us respond to your comments point by point.
We aren't sure if that's accurate: empirically, as nicely described in Jennifer's 8-page summary (in Sec. 1.5), many practical networks -- from a simple MLP to the not-very-simple GPT-3 -- seem to perform well in a regime where the depth-to-width aspect ratio is small (like 0.01 or at most 0.1). So the leading-order perturbative description would be fairly accurate for describing these practically useful networks.
Moreover, one of the takeaways from "effective theory" descriptions is that we understand the truncation error: in particular, the errors from truncation will be of order (depth-to-width aspect ratio)^2. So this means we can estimate what we would miss by truncating the series and learn that sometimes -- if not most of the time -- we really don't have to compute these extra terms.
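To make the truncation-error estimate above concrete, here is a minimal sketch of the arithmetic. It only computes the order of magnitude (aspect ratio)^2; the unknown O(1) prefactor is not modeled, and the function names and the example depth and width are hypothetical choices for illustration, not values taken from the book.

```python
def aspect_ratio(depth: int, width: int) -> float:
    """Depth-to-width aspect ratio of a network."""
    return depth / width


def truncation_error_scale(depth: int, width: int) -> float:
    """Order-of-magnitude size of the terms dropped by truncating at leading
    order in the aspect ratio: (depth / width) ** 2. The O(1) prefactor is
    unknown and deliberately left out (assumption: it is not large)."""
    return aspect_ratio(depth, width) ** 2


if __name__ == "__main__":
    # Hypothetical architectures spanning the 0.01 to 0.1 range quoted above.
    for depth, width in [(10, 1000), (10, 100)]:
        r = aspect_ratio(depth, width)
        print(f"depth={depth} width={width} ratio={r:g} "
              f"truncation scale ~ {truncation_error_scale(depth, width):g}")
```

For an aspect ratio of 0.01 the dropped terms are of order 10^-4, and for 0.1 they are of order 10^-2, which is the sense in which the extra terms can often be skipped.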
It is true that decreasing the depth-to-width aspect ratio reduces the representation-learning capability of the network and -- to the extent that representation learning is useful for the task -- doing so would degrade the performance. But (i) let us reiterate that, as alluded to above, empirically networks seem to operate well in the perturbative regime where the aspect ratio is small and (ii) the converse is not true (i.e., it is not beneficial to keep increasing the aspect ratio indefinitely), as we illustrate in responding to the following point.
Actually, that last point is not always the case. One of the results from our book is that while increasing the depth-to-width ratio leads to more representation learning, it also leads to more fluctuations in gradients from random seed to random seed. Thus, the deeper your network is for fixed width, the harder it is to train, in the sense that different realizations will not only behave differently, but also will likely not be critical (i.e., they will not be on what is sometimes referred to as the "edge of chaos" and they will suffer from exploding/vanishing gradients). And this last observation is true for both the NTK parametrization and the maximal-update parametrization, so by your logic, we would be screwing up no matter which parametrization we use. :)
As it turns out, this tradeoff between the benefit of representation learning and the cost of seed-to-seed fluctuations leads to the concept of the optimal aspect ratio where networks should perform the best. Empirical results indirectly indicate that this optimal aspect ratio may be in the perturbative regime; in the Appendix of our book, we also did a calculation using tools from information theory that gives evidence that the optimal depth-to-width ratio is in the perturbative regime.
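The growth of seed-to-seed fluctuations with depth at fixed width can be seen in a toy experiment. The sketch below is not the book's calculation: it assumes a deep linear network at initialization (no nonlinearity, no training), measures fluctuations of the forward signal rather than of gradients, and uses hypothetical function names. For this toy model with Gaussian weights of variance 1/width, the variance of the normalized output norm across seeds is (1 + 2/width)^depth - 1, which is roughly 2 * depth/width when the aspect ratio is small.

```python
import random
import statistics


def normalized_output_norm(width: int, depth: int, rng: random.Random) -> float:
    """Return |z|^2 / width after `depth` random linear layers applied to the
    all-ones input (toy assumption: no nonlinearity, weights ~ N(0, 1/width))."""
    z = [1.0] * width            # |z|^2 / width == 1 at the input
    std = (1.0 / width) ** 0.5   # keeps the norm fixed on average across seeds
    for _ in range(depth):
        # Each output component is a fresh random row dotted with z.
        z = [sum(rng.gauss(0.0, std) * zj for zj in z) for _ in range(width)]
    return sum(zi * zi for zi in z) / width


def seed_to_seed_variance(width: int, depth: int, num_seeds: int = 200) -> float:
    """Sample variance of the normalized output norm over independent seeds."""
    samples = [
        normalized_output_norm(width, depth, random.Random(seed))
        for seed in range(num_seeds)
    ]
    return statistics.variance(samples)


if __name__ == "__main__":
    width = 16
    for depth in (2, 16):
        var = seed_to_seed_variance(width, depth)
        print(f"depth/width = {depth / width:g}: variance across seeds = {var:.3f}")
```

Pure-Python loops are used only to keep the sketch dependency-free; an array library would be the natural choice for larger widths.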
We don't think this is the case. Both NTK and maximal-update parametrizations can avoid converging to kernel limits and can allow features to evolve: for the NTK parametrization, we need to keep increasing the depth in proportion to the width; for the maximal-update parametrization, we need to keep the depth fixed while increasing the width.
Sho and Dan