Since people are talking about the NTK/GP hypothesis of neural nets again, I thought it might be worth bringing up some recent research in the area that casts doubt on their explanatory power. The upshot is: NTK/GP models of neural networks can't learn features. By 'feature learning' I mean the process where intermediate neurons come to represent task-relevant features such as curves, elements of grammar, or cats. Closely related to feature learning is transfer learning, the typical practice whereby a neural net is trained on one task, then 'fine-tuned' with a lower learning rate to fit another task, usually with less data than the first. This is often a powerful way to approach learning in the low-data regime, but NTK/GP models can't do it at all.
The reason for this is pretty simple. During training on the 'old task', NTK stays in the 'tangent space' of the network's initialization. This means that, to first order, none of the functions/derivatives computed by the individual neurons change at all; only the output function does.[1] Feature learning requires the intermediate neurons to adapt to structures in the data that are relevant to the task being learned, but in the NTK limit the intermediate neurons' functions don't change at all. Any meaningful function like a 'car detector' would need to be there at initialization -- extremely unlikely for functions of any complexity. This lack of feature learning implies a lack of meaningful transfer learning as well: since the NTK is just doing linear regression using an (infinite) fixed set of functions, the only 'transfer' that can occur is shifting where the regression starts in this space. This could potentially speed up convergence, but it wouldn't provide any benefits in terms of representation efficiency for tasks with few data points[2]. This property holds for the GP limit as well -- the distribution of functions computed by intermediate neurons doesn't change after conditioning on the outputs, so networks sampled from the GP posterior wouldn't be useful for transfer learning either.
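For concreteness, the 'tangent space' picture can be written out as the standard first-order expansion around initialization. The symbols here ($\theta_0$ for the initial parameters, $f_{\mathrm{lin}}$ for the linearized network, $\Theta$ for the kernel) are notation introduced for illustration, not taken from the papers discussed:

```latex
% Linearization of the network around its initialization \theta_0
f_{\mathrm{lin}}(x;\theta) = f(x;\theta_0) + \nabla_\theta f(x;\theta_0)^{\top} (\theta - \theta_0)

% The tangent kernel is the inner product of these fixed gradient features
\Theta(x, x') = \nabla_\theta f(x;\theta_0)^{\top} \, \nabla_\theta f(x';\theta_0)
```

The features $\nabla_\theta f(x;\theta_0)$ depend only on the initialization, so training only moves the regression coefficients $\theta - \theta_0$; nothing in this model lets the features themselves adapt to the data.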
This also makes me skeptical of the Mingard et al. result about SGD being equivalent to picking a random neural net with given performance, given that picking a random net is equivalent to running a GP regression in the wide-width limit. In particular, it makes me skeptical that this result will generalize to the complex models and tasks we care about. 'GP/NTK performs similarly to SGD on simple tasks' has been found before, but it tends to break down as the tasks become more complex.[3]
So are there any theoretical models of neural nets which are able to incorporate feature learning? Yes. In fact, there are a few candidate theories, of which I think Greg Yang's Tensor Programs is the best. I got all the above anti-NTK/GP talking points from him, specifically his paper Feature Learning in Infinite Width Neural Networks. The basic idea of this paper is pretty neat -- he derives a general framework for taking the 'infinite-width limit' of 'tensor programs', general computation graphs containing tensors with a width parameter. He then applies this framework to SGD itself -- the successive iterates of SGD can be represented as just another type of computation graph, so the limit can be taken straightforwardly, leading to an infinite-width limit distinct from the NTK/GP one, and one in which the features computed by intermediate neurons can change. He also shows that this limit outperforms both finite-width nets and NTK/GP, and learns non-trivial feature embeddings. Two caveats: this 'tensor program limit' is much more difficult to compute than NTK/GP, so he's only actually able to run experiments on networks with very few layers and/or linear activations. And the scaling used to take the limit is actually different from that used in practice. Nevertheless, I think this represents the best theoretical attempt yet to capture the non-kernel learning that seems to be going on in neural nets.
To be clear, I think that the NTK/GP models have been a great advance in our understanding of neural networks, and it's good to see people on LW discussing them. However, there are some important phenomena they fail to explain. They're a good first step, but a comprehensive theoretical account of neural nets has yet to be written.[4]
You might be wondering how it's possible for the output function to change if none of the individual neurons' functions change. Basically, since the output is the sum of N things, each of them only needs to change by O(1/N) to change the output by O(1), so they don't change at all in the wide-width limit. (See also my discussion with johnswentworth in the comments.) ↩︎
Sort of. A more exact statement might be that the NTK can technically do transfer learning, but only trivially so, i.e. it can only 'transfer' to tasks to the extent that they are exactly the same as its original task. See this comment. ↩︎
In fairness to the NTK/GP, they also haven't been tried as much on more difficult problems because they scale worse than neural nets in terms of data: the cost is D^2 * (kernel eval cost) for D data points, since you need to compute the kernel between all pairs of points. So it's possible that they could do better if people had the chance to try them out more, iterate improved versions, and so on. ↩︎
I'll confess that I would personally find it kind of disappointing if neural nets were mostly just an efficient way to implement some fixed kernels, when it seems possible that they could be doing something much more interesting -- perhaps even implementing something like a simplicity prior over a large class of functions, which I'm pretty sure NTK/GP can't be. ↩︎
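The width argument from the first footnote can be checked numerically on a toy model. The sketch below is only an illustration under stated assumptions: a two-layer tanh net with scalar input, NTK-style 1/sqrt(N) output scaling, a single training point, and one full-batch gradient step with learning rate 1. The function name `one_step_changes` and all of its parameters are hypothetical.

```python
import math
import random


def one_step_changes(width, seed=0, lr=1.0, x=1.0):
    """Take one gradient step on f(x) = sum_i a_i * tanh(w_i * x) / sqrt(width).

    Returns (largest change in any hidden activation, change in the output).
    Assumption: the target is set to f0 + 1, so the initial residual is exactly -1.
    """
    rng = random.Random(seed)
    a = [rng.gauss(0.0, 1.0) for _ in range(width)]  # output weights
    w = [rng.gauss(0.0, 1.0) for _ in range(width)]  # input weights
    scale = 1.0 / math.sqrt(width)

    def hidden(ws):
        return [math.tanh(wi * x) for wi in ws]

    def output(as_, ws):
        return scale * sum(ai * hi for ai, hi in zip(as_, hidden(ws)))

    h0 = hidden(w)
    f0 = output(a, w)
    residual = -1.0  # f0 - y, with y = f0 + 1

    # Gradients of the squared loss 0.5 * (f - y)^2 with respect to a and w.
    grad_a = [residual * scale * hi for hi in h0]
    grad_w = [residual * scale * ai * (1.0 - hi * hi) * x for ai, hi in zip(a, h0)]

    a1 = [ai - lr * g for ai, g in zip(a, grad_a)]
    w1 = [wi - lr * g for wi, g in zip(w, grad_w)]

    h1 = hidden(w1)
    f1 = output(a1, w1)
    max_hidden_change = max(abs(p - q) for p, q in zip(h1, h0))
    return max_hidden_change, abs(f1 - f0)


if __name__ == "__main__":
    for n in (100, 10000):
        hidden_change, output_change = one_step_changes(n)
        print(n, hidden_change, output_change)
```

In this setup the largest per-neuron change should shrink as the width grows, while the output still moves by an amount of order 1, because the N small per-neuron changes all push the output in the same direction.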
Hmm, so regarding the linear combinations, it's true that there are some linear combinations that will change by Θ(1) in the large-width limit -- just use the vector of partial derivatives of the output at some particular input; this sum will change by the amount that the output function moves during the regression. Indeed, I suspect (but don't have a proof) that these particular combinations will span the space of linear combinations that change non-trivially during training. I would dispute "we expect most linear combinations to change" though -- the CLT argument implies that we should expect almost all combinations to not appreciably change. I'm not sure what effect this would have on the PCA, and I still think it's plausible that it doesn't change at all (actually, I think Greg Yang states that it doesn't change in section 9 of his paper, though I haven't read that part super carefully).
So I think I was a bit careless in saying that the NTK can't do transfer learning at all -- a more exact statement might be "the NTK does the minimal amount of transfer learning possible". What I mean by this is, any learning algorithm can do transfer learning if the task we are 'transferring' to is sufficiently similar to the original task -- for instance, if it's just the exact same task but with a different data sample. I claim that the 'transfer learning' the NTK does is of this sort. As you say, since the tangent kernel doesn't change at all, the net effect is to move where the network starts in the tangent space. Disregarding convergence speed, the impact this has on generalization is determined by the values set by the old function on axes of the NTK outside of the span of the partial derivatives at the new function's data points. This means that, for the NTK to transfer anything from one task to another, it's not enough for the tasks to both feature, for instance, eyes. The eyes have to correlate with the output in the exact same way in both tasks. Indeed, the transfer learning could actually hurt the generalization. Nor is its effect invariant under simple transformations like flipping the sign of the target function (this would change beneficial transfer to harmful). By default, for functions that aren't simple multiples, I expect the linear correlation between values on different axes to be about 0, even if the functions share many meaningful features. So while the NTK can do 'transfer learning' in a sense, it's about as weak as possible, and I strongly doubt that this sort of transfer is sufficient to explain transfer learning's successes in practice (but I don't have empirical proof).
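Here is a minimal sketch of that point, assuming a toy linear model with two fixed features and a single new-task data point. The names (`fit_from_start`, `phi_train`, `phi_test`, `theta_old`) and the numbers are all made up for illustration. Gradient descent on a linear least-squares problem converges to the interpolating solution closest to its starting point, so whatever the old task left on axes orthogonal to the new data simply persists.

```python
def dot(u, v):
    return sum(ui * vi for ui, vi in zip(u, v))


def fit_from_start(theta_start, phi, y):
    """Limit of gradient descent on 0.5 * (phi . theta - y)^2 from theta_start.

    With a single data point this is the projection of theta_start onto the
    set of parameters that fit the point exactly (closed form, no iteration).
    """
    step = (y - dot(phi, theta_start)) / dot(phi, phi)
    return [t + step * p for t, p in zip(theta_start, phi)]


# Hypothetical fixed features: the new task's data only touches axis 0,
# while the test input only touches axis 1.
phi_train = [1.0, 0.0]
phi_test = [0.0, 1.0]

theta_old = [0.5, 2.0]                   # parameters left behind by the 'old task'
theta_flipped = [-t for t in theta_old]  # same old task with the target's sign flipped

pred_scratch = dot(phi_test, fit_from_start([0.0, 0.0], phi_train, 1.0))
pred_transfer = dot(phi_test, fit_from_start(theta_old, phi_train, 1.0))
pred_flipped = dot(phi_test, fit_from_start(theta_flipped, phi_train, 1.0))

print(pred_scratch, pred_transfer, pred_flipped)  # 0.0 2.0 -2.0
```

The test prediction is inherited entirely from the starting point, and flipping the sign of the old function flips it too, which is why this kind of 'transfer' only helps when the old values on those axes happen to line up with the new target.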
It's true that NTK/GP perform pretty closely to finite nets on the tasks we've tried them on so far, but those tasks are pretty simple and we already had decent non-NN solutions. Generally the pattern is 'GP matches NNs on really simple tasks, NTK on somewhat harder ones'. I think the data we have is consistent with this breaking down as we move to the harder problems that have no good non-NN solutions. I would be very interested in seeing an experiment with NTK on, say, ImageNet for this reason, but as far as I know no one's done so because of the prohibitive computational cost.
Thanks for the link -- will read this tomorrow.
And thank you for engaging in detail -- I have also found this very helpful in forcing me to clarify (partially to myself) what my actual beliefs are.