Since people are talking about the NTK/GP hypothesis of neural nets again, I thought it might be worth bringing up some recent research in the area that casts doubt on their explanatory power. The upshot is: NTK/GP models of neural networks can't learn features. By 'feature learning' I mean the process where intermediate neurons come to represent task-relevant features such as curves, elements of grammar, or cats. Closely related to feature learning is transfer learning, the typical practice whereby a neural net is trained on one task, then 'fine-tuned' with a lower learning rate to fit another task, usually with less data than the first. This is often a powerful way to approach learning in the low-data regime, but NTK/GP models can't do it at all.

The reason for this is pretty simple. During training on the 'old task', NTK stays in the 'tangent space' of the network's initialization. This means that, to first order, none of the functions/derivatives computed by the individual neurons change at all; only the output function does.[1] Feature learning requires the intermediate neurons to adapt to structures in the data that are relevant to the task being learned, but in the NTK limit the intermediate neurons' functions don't change at all. Any meaningful function like a 'car detector' would need to be there at initialization -- extremely unlikely for functions of any complexity. This lack of feature learning implies a lack of meaningful transfer learning as well: since the NTK is just doing linear regression using an (infinite) fixed set of functions, the only 'transfer' that can occur is shifting where the regression starts in this space. This could potentially speed up convergence, but it wouldn't provide any benefits in terms of representation efficiency for tasks with few data points[2]. This property holds for the GP limit as well -- the distribution of functions computed by intermediate neurons doesn't change after conditioning on the outputs, so networks sampled from the GP posterior wouldn't be useful for transfer learning either.
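For concreteness, here is the standard linearized-training picture this argument relies on. It is a sketch that assumes squared loss ½‖f(X) − Y‖² on training inputs X with targets Y, gradient flow with learning rate η, and an invertible kernel matrix Θ(X, X):

```latex
\begin{aligned}
f^{\mathrm{lin}}_{\theta}(x) &= f_{\theta_0}(x) + \nabla_\theta f_{\theta_0}(x)^{\top}(\theta - \theta_0),
\qquad
\Theta(x, x') = \nabla_\theta f_{\theta_0}(x)^{\top}\,\nabla_\theta f_{\theta_0}(x') \\
f_t(x) &= f_0(x) - \Theta(x, X)\,\Theta(X, X)^{-1}\bigl(I - e^{-\eta\,\Theta(X, X)\,t}\bigr)\bigl(f_0(X) - Y\bigr) \\
f_\infty(x) &= f_0(x) + \Theta(x, X)\,\Theta(X, X)^{-1}\bigl(Y - f_0(X)\bigr)
\end{aligned}
```

The kernel Θ depends only on the initialization θ₀, so the features never change during training; pre-training on an old task can only alter the starting function f₀.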

This also makes me skeptical of the Mingard et al. result about SGD being equivalent to picking a random neural net with given performance, given that picking a random net is equivalent to running a GP regression in the wide-width limit. In particular, it makes me skeptical that this result will generalize to the complex models and tasks we care about. 'GP/NTK performs similarly to SGD on simple tasks' has been found before, but it tends to break down as the tasks become more complex.[3]

So are there any theoretical models of neural nets which are able to incorporate feature learning? Yes. In fact, there are a few candidate theories, of which I think Greg Yang's Tensor Programs is the best. I got all the above anti-NTK/GP talking points from him, specifically his paper Feature Learning in Infinite Width Neural Networks. The basic idea of this paper is pretty neat -- he derives a general framework for taking the 'infinite-width limit' of 'tensor programs', general computation graphs containing tensors with a width parameter. He then applies this framework to SGD itself -- the successive iterates of SGD can be represented as just another type of computation graph, so the limit can be taken straightforwardly, leading to an infinite-width limit distinct from the NTK/GP one, and one in which the features computed by intermediate neurons can change. He also shows that this limit outperforms both finite-width nets and NTK/GP, and learns non-trivial feature embeddings. Two caveats: this 'tensor program limit' is much more difficult to compute than NTK/GP, so he's only actually able to run experiments on networks with very few layers and/or linear activations. And the scaling used to take the limit is actually different from that used in practice. Nevertheless, I think this represents the best theoretical attempt yet to capture the non-kernel learning that seems to be going on in neural nets.

To be clear, I think that the NTK/GP models have been a great advance in our understanding of neural networks, and it's good to see people on LW discussing them. However, there are some important phenomena they fail to explain. They're a good first step, but a comprehensive theoretical account of neural nets has yet to be written.[4]


  1. You might be wondering how it's possible for the output function to change if none of the individual neurons' functions change. Basically, since the output is the sum of N things, each of them only needs to change by O(1/N) to change the output by O(1), so in the wide-width limit they don't change at all. (See also my discussion with johnswentworth in the comments.)

  2. Sort of. A more exact statement might be that the NTK can technically do transfer learning, but only trivially so, i.e. it can only 'transfer' to tasks to the extent that they are exactly the same as its original task. See this comment.

  3. In fairness to the NTK/GP, they also haven't been tried as much on more difficult problems because they scale worse than neural nets with the amount of data: the cost is D^2 × (kernel evaluation cost) for D data points, since you need to compute the kernel between all pairs of points. So it's possible that they could do better if people had the chance to try them out more, iterate on improved versions, and so on.

  4. I'll confess that I would personally find it kind of disappointing if neural nets were mostly just an efficient way to implement some fixed kernels, when it seems possible that they could be doing something much more interesting -- perhaps even implementing something like a simplicity prior over a large class of functions, which I'm pretty sure NTK/GP can't be.


(moved from LW to AF)

Meta: I'm going to start commenting this on posts I move from LW to AF just so there's a better record of what moderation actions I'm taking.

This seems like a great thing to do. Mild preference for calling it "added it to AF", just so that people don't get confused that this means content will disappear from LW.

Thx for doing this!

I'll confess that I would personally find it kind of disappointing if neural nets were mostly just an efficient way to implement some fixed kernels, when it seems possible that they could be doing something much more interesting -- perhaps even implementing something like a simplicity prior over a large class of functions, which I'm pretty sure NTK/GP can't be

Wait, why can't NTK/GP be implementing a simplicity prior over a large class of functions? They totally are, it's just that the prior comes from the measure in random initialization space, rather than from the gradient update process. As explained here. Right? No?

There's an important distinction[1] to be made between these two claims:

A) Every function with large volume in parameter-space is simple

B) Every simple function has a large volume in parameter space

For a method of inference to qualify as a 'simplicity prior', you want both claims to hold. This is what lets us derive bounds like 'Solomonoff induction matches the performance of any computable predictor', since all of the simple, computable predictors have relatively large volume in the Solomonoff measure, so they'll be picked out after boundedly many mistakes. In particular, you want there to be an implication like: if a function has complexity K, it will have parameter-volume at least 2^(-O(K)).

Now, the Mingard results, at least the ones that have mathematical proof, rely on the Levin bound. This only shows (A), which is the direction that is much easier to prove -- it automatically holds for any mapping from parameter-space to functions with bounded complexity. They also have some empirical results that show there is substantial 'clustering', that is, there are some simple functions that have large volumes. But this still doesn't show that all of them do, and indeed is compatible with the learnable function class being extremely limited. For instance, this could easily be the case even if NTK/GP was only able to learn linear functions. In reality the NTK/GP is capable of approximating arbitrary functions on finite-dimensional inputs but, as I argued in another comment, this is not the right notion of 'universality' for classification problems. I strongly suspect[2] that the NTK/GP can be shown to not be 'universally data-efficient' as I outlined there, but as far as I'm aware no one's looked into the issue formally yet. Empirically, I think the results we have so far suggest that the NTK/GP is a decent first-order approximation for simple tasks that tends to perform worse on the more difficult problems that require non-trivial feature learning/efficiency.
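Informally, writing P(f) for the parameter-space volume of a function f and K(f) for its Kolmogorov complexity, the two claims look roughly like this. It is a sketch rather than a precise theorem statement: the constants depend on the parameter-to-function map, and a Levin-style bound only delivers the first line.

```latex
\begin{aligned}
\text{(A)}\quad & P(f) \le 2^{-K(f) + O(1)} && \text{(large volume implies simple)} \\
\text{(B)}\quad & P(f) \ge 2^{-O(K(f))} && \text{(simple implies large volume)}
\end{aligned}
```

A Solomonoff-style guarantee needs (B), since that is what ensures every simple predictor starts with enough mass to be picked out after boundedly many mistakes.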


  1. I actually posted basically the same thing underneath another one of your comments a few weeks ago, but maybe you didn't see it because it was only posted on LW, not the alignment forum

  2. Basically, because in the NTK/GP limit the functions for all the neurons in a given layer are sampled from a single computable distribution, so I think you can show that the embedding is 'effectively finite' in some sense (although note it is a universal approximator for fixed input dimension).

Ah, OK. Interesting, thanks. Would you agree with the following view:

"The NTK/GP stuff has neural nets implementing a 'pseudosimplicity prior' which is maybe also a simplicity prior but might not be; the evidence is unclear. A pseudosimplicity prior is like a simplicity prior except that there are some important classes of Kolmogorov-simple functions that don't get high prior / high measure."

Which would you say is more likely: (a) the NTK/GP stuff is indeed not universally data efficient, and thus modern neural nets aren't either, or (b) the NTK/GP stuff is indeed not universally data efficient, and thus modern neural nets aren't well-characterized by the NTK/GP stuff?

Yeah, that summary sounds right.

I'd say (b) -- it seems quite unlikely to me that the NTK/GP are universally data-efficient, while neural nets might be (although that's mostly speculation on my part). I think the lack of feature learning is a stronger argument that NTK/GP don't characterize neural nets well.

During training on the 'old task', NTK stays in the 'tangent space' of the network's initialization. This means that, to first order, none of the functions/derivatives computed by the individual neurons change at all, only the output function does.

Eh? Why does this follow? Derivatives make sense; the derivatives staying approximately-constant is one of the assumptions underlying NTK to begin with. But the functions computed by individual neurons should be able to change for exactly the same reason the output function changes, assuming the network has more than one layer. What am I missing here?

The asymmetry between the output function and the intermediate neuron functions comes from backprop -- from the fact that the gradients are backpropagated through weight matrices with entries of magnitude O(1/√N). The gradient of the output w.r.t. itself is obviously 1; the gradient of the output w.r.t. each neuron in the preceding layer is then O(1/√N), since you're just multiplying by a vector with entries of that size. By induction, the gradients for all earlier layers are sums of N random terms of size O(1/N), and so are of size O(1/√N) again. So taking a step of backprop changes the output function by O(1) but the intermediate functions only by O(1/√N), which vanishes in the large-width limit.

(This is kind of an oversimplification since it is possible to have changing intermediate functions while doing backprop, as mentioned in the linked paper. But this is the essence of why it's possible in some limits to move around using backprop without changing the intermediate neurons)
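A minimal numerical sketch of this scaling argument, assuming a hypothetical one-hidden-layer tanh network in NTK parametrization, f(x) = (1/√N) Σ v_i tanh(w_i x) with w_i, v_i ~ N(0, 1), a scalar input, and a single gradient step with an O(1) learning rate. The function `one_step_changes` and its parameters are illustrative choices, not from any paper; the point is only that the output moves by O(1) while the typical hidden unit moves by roughly O(1/√N).

```python
import math
import random


def one_step_changes(width, lr=0.5, x0=1.0, seed=0):
    """Output change and mean |hidden-unit change| at x0 after one gradient step.

    Hypothetical toy net in NTK parametrization:
        f(x) = (1/sqrt(N)) * sum_i v_i * tanh(w_i * x),  w_i, v_i ~ N(0, 1).
    The step is gradient ascent on f(x0) with an O(1) learning rate.
    """
    rng = random.Random(seed)
    w = [rng.gauss(0.0, 1.0) for _ in range(width)]
    v = [rng.gauss(0.0, 1.0) for _ in range(width)]
    scale = 1.0 / math.sqrt(width)

    def forward(w, v):
        hidden = [math.tanh(wi * x0) for wi in w]
        return scale * sum(vi * hi for vi, hi in zip(v, hidden)), hidden

    f_before, h_before = forward(w, v)
    # Backprop: df/dv_i = h_i / sqrt(N), df/dw_i = v_i * tanh'(w_i x0) * x0 / sqrt(N).
    grad_v = [scale * hi for hi in h_before]
    grad_w = [scale * vi * (1.0 - hi * hi) * x0 for vi, hi in zip(v, h_before)]
    w_new = [wi + lr * g for wi, g in zip(w, grad_w)]
    v_new = [vi + lr * g for vi, g in zip(v, grad_v)]

    f_after, h_after = forward(w_new, v_new)
    mean_hidden_change = sum(abs(a - b) for a, b in zip(h_after, h_before)) / width
    return f_after - f_before, mean_hidden_change


for n in (100, 10_000):
    d_out, d_hidden = one_step_changes(n)
    print(f"N={n:>6}: output change {d_out:.3f}, mean hidden-unit change {d_hidden:.5f}")
```

Going from N = 100 to N = 10,000 should shrink the mean hidden-unit change by roughly a factor of 10 (that is, √100), while the output change stays of order 1.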

Ok, that's at least a plausible argument, although there are some big loopholes. Main problem which jumps out to me: what happens after one step of backprop is not the relevant question. One step of backprop is not enough to solve a set of linear equations (i.e. to achieve perfect prediction on the training set); the relevant question is what happens after one step of Newton's method, or after enough steps of gradient descent to achieve convergence.

What would convince me more is an empirical result - i.e. looking at the internals of an actual NTK model, trying the sort of tricks which work well for interpreting normal NNs, and seeing how well they work. Just relying on proofs makes it way too easy for an inaccurate assumption to sneak in - like the assumption that we're only using one step of backprop. If anyone has tried that sort of empirical work, I'd be interested to hear what it found.

The result that NTK does not learn features in the large N limit is not in dispute at all -- it's right there on page 15 of the original NTK paper, and indeed holds after arbitrarily many steps of backprop. I don't think that there's really much room for loopholes in the math here. See Greg Yang's paper for a lengthy proof that this holds for all architectures. Also worth noting that when people 'take the NTK limit' they often don't initialize an actual net at all, they instead use analytical expressions for what the inner product of the gradients would be at N=infinity to compute the kernel directly.
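As an illustration of computing the kernel directly, here is a sketch for one simple architecture, chosen for brevity: a one-hidden-layer ReLU network without biases, f(x) = (1/√N) Σ v_i relu(w_i · x) with w_i ~ N(0, I) and v_i ~ N(0, 1), both layers trained. Its infinite-width tangent kernel has a closed form via the arc-cosine integrals, Θ(x, x') = ‖x‖‖x'‖(sin θ + (π − θ) cos θ)/(2π) + (x · x')(π − θ)/(2π), where θ is the angle between x and x'. The sketch compares it with the gradient inner product of one finite random network; the function names are hypothetical.

```python
import math
import random


def analytic_relu_ntk(x, y):
    """Infinite-width NTK of f(x) = (1/sqrt(N)) * sum_i v_i * relu(w_i . x),
    with w_i ~ N(0, I), v_i ~ N(0, 1) and both layers trained (no biases)."""
    norm_x = math.sqrt(sum(a * a for a in x))
    norm_y = math.sqrt(sum(b * b for b in y))
    dot = sum(a * b for a, b in zip(x, y))
    cos_t = max(-1.0, min(1.0, dot / (norm_x * norm_y)))
    theta = math.acos(cos_t)
    # Output-weight part: E[relu(w.x) relu(w.y)] (arc-cosine kernel of order 1).
    readout = norm_x * norm_y * (math.sin(theta) + (math.pi - theta) * cos_t) / (2 * math.pi)
    # Hidden-weight part: (x.y) * E[v^2] * P(w.x > 0 and w.y > 0).
    hidden = dot * (math.pi - theta) / (2 * math.pi)
    return readout + hidden


def empirical_relu_ntk(x, y, width, seed=0):
    """Inner product of parameter gradients at x and y for one random finite net."""
    rng = random.Random(seed)
    dot = sum(a * b for a, b in zip(x, y))
    total = 0.0
    for _ in range(width):
        w = [rng.gauss(0.0, 1.0) for _ in x]
        v = rng.gauss(0.0, 1.0)
        ux = sum(a * b for a, b in zip(w, x))
        uy = sum(a * b for a, b in zip(w, y))
        total += max(ux, 0.0) * max(uy, 0.0)  # df/dv_i terms
        if ux > 0 and uy > 0:
            total += v * v * dot  # df/dw_i terms
    return total / width  # the 1/sqrt(N) in f contributes a 1/N here


x, y = (1.0, 0.0), (0.6, 0.8)
print(analytic_relu_ntk(x, y), empirical_relu_ntk(x, y, width=50_000))
```

For unit-norm inputs the closed form gives Θ(x, x) = 1, and the finite-width estimate should agree with the closed form up to Monte Carlo error of order 1/√N.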

Alright, I buy the argument on page 15 of the original NTK paper.

I'm still very skeptical of the interpretation of this as "NTK models can't learn features". In general, when someone proves some interesting result which seems to contradict some combination of empirical results, my default assumption is that the proven result is being interpreted incorrectly, so I have a high prior that that's what's happening here. In this case, it could be that e.g. the "features" relevant to things like transfer learning are not individual neuron activations - e.g. IIRC much of the circuit interpretability work involves linear combinations of activations, which would indeed circumvent this theorem.

This whole class of concerns would be ruled out by empirical results - e.g. experimental evidence on transfer learning with NTKs, or someone applying the same circuit interpretability techniques to NTKs which are applied to standard nets.

I don't think taking linear combinations will help, because adding terms to the linear combination will also increase the magnitude of the original activation vector -- e.g. if you add together N units, the magnitude of the sum of their original activations will with high probability be of order √N, dwarfing the O(1) change due to change in the activations. But regardless, it can't help with transfer learning at all, since the tangent kernel (which determines learning in this regime) doesn't change by definition.

What empirical results do you think are being contradicted? As far as I can tell, the empirical results we have are 'NTK/GP have similar performance to neural nets on some, but not all, tasks'. I don't think transfer/feature learning is addressed at all. You might say these results are suggestive evidence that NTK/GP captures everything important about neural nets, but this is precisely what is being disputed with the transfer learning arguments.

I can imagine doing an experiment where we find the 'empirical tangent kernel' of some finite neural net at initialization, solve the linear system, and then analyze the activations of the resulting network. But it's worth noting that this is not what is usually meant by 'NTK' -- that usually includes taking the infinite-width limit at the same time. And to the extent that we expect the activations to change at all, we no longer have reason to think that this linear system is a good approximation of SGD. That's what the above mathematical results mean -- the same mathematical analysis that implies that network training is like solving a linear system, also implies that the activations don't change at all.
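The experiment described here (empirical tangent kernel at initialization, then solving the linear system) can be sketched at toy scale. Everything below is an assumption for illustration: a one-hidden-layer tanh network with biases in NTK parametrization, width 500, four scalar training points with targets sin(2x), and a hand-written Gaussian elimination solver. The predictor is f₀(x) + Θ(x, X) Θ(X, X)⁻¹ (Y − f₀(X)), the function that gradient descent on the linearized network converges to under squared loss; the parameters are never updated, so the hidden-unit functions are exactly those at initialization.

```python
import math
import random


def init_params(width, seed=0):
    """Each hidden unit is (w, b, v): input weight, bias, output weight, all ~ N(0, 1)."""
    rng = random.Random(seed)
    return [(rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0))
            for _ in range(width)]


def output(params, x):
    """f(x) = (1/sqrt(N)) * sum_i v_i * tanh(w_i * x + b_i)."""
    return sum(v * math.tanh(w * x + b) for w, b, v in params) / math.sqrt(len(params))


def tangent_features(params, x):
    """Gradient of f(x) with respect to every parameter: the fixed NTK features."""
    scale = 1.0 / math.sqrt(len(params))
    feats = []
    for w, b, v in params:
        h = math.tanh(w * x + b)
        s = 1.0 - h * h  # tanh'(w * x + b)
        feats.extend([scale * v * s * x, scale * v * s, scale * h])  # d/dw, d/db, d/dv
    return feats


def empirical_ntk(params, xa, xb):
    return sum(p * q for p, q in zip(tangent_features(params, xa),
                                     tangent_features(params, xb)))


def solve(matrix, rhs):
    """Solve matrix @ x = rhs by Gaussian elimination with partial pivoting."""
    n = len(rhs)
    aug = [list(row) + [r] for row, r in zip(matrix, rhs)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(aug[r][col]))
        aug[col], aug[pivot] = aug[pivot], aug[col]
        for r in range(col + 1, n):
            factor = aug[r][col] / aug[col][col]
            for c in range(col, n + 1):
                aug[r][c] -= factor * aug[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        tail = sum(aug[r][c] * x[c] for c in range(r + 1, n))
        x[r] = (aug[r][n] - tail) / aug[r][r]
    return x


def ntk_predictor(params, xs, ys):
    """Converged linearized model: f0(x) + K(x, X) K(X, X)^-1 (Y - f0(X))."""
    gram = [[empirical_ntk(params, a, b) for b in xs] for a in xs]
    alpha = solve(gram, [y - output(params, x) for x, y in zip(xs, ys)])
    return lambda x: output(params, x) + sum(
        a * empirical_ntk(params, x, xi) for a, xi in zip(alpha, xs))


params = init_params(width=500)
train_x = [-1.0, -0.2, 0.5, 1.3]
train_y = [math.sin(2 * x) for x in train_x]
predict = ntk_predictor(params, train_x, train_y)
for x, y in zip(train_x, train_y):
    print(f"x={x:+.1f}  target={y:+.4f}  prediction={predict(x):+.4f}")
```

The predictor interpolates the training targets, but any 'features' one might look for in its activations are by construction the random ones from initialization.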

They wouldn't be random linear combinations, so the central limit theorem estimate wouldn't directly apply. E.g. this circuit transparency work basically ran PCA on activations. It's not immediately obvious to me what the right big-O estimate would be, but intuitively, I'd expect the PCA to pick out exactly those components dominated by change in activations - since those will be the components which involve large correlations in the activation patterns across data points (at least that's my intuition).

I think this claim is basically wrong:

And to the extent that we expect the activations to change at all, we no longer have reason to think that this linear system is a good approximation of SGD.

There's a very big difference between "no change to first/second order" and "no change". Even in the limit, we do expect most linear combinations of the activations to change. And those are exactly the changes which would potentially be useful for transfer learning. And the tangent kernel not changing does not imply that transfer learning won't work, for two reasons: starting at a better point can accelerate convergence, and (probably more relevant) the starting point can influence the solution chosen when the linear system is underdetermined (which it is, if I understand things correctly).

I do think the empirical results pretty strongly suggest that the NTK/GP model captures everything important about neural nets, at least in terms of their performance on the original task. If that's true, and NTKs can't be used for transfer learning, then that would imply that transfer learning in normal nets works for completely different reasons from good performance on the original task, and that good performance on the original task has nothing to do with learning features. Those both strike me as less plausible than these proofs about "NTK not learning features" being misinterpreted.

(I also did a quick google search for transfer learning with NTKs. I only found one directly-relevant study, which is on way too small and simple a system for me to draw much of a conclusion from it, but it does seem to have worked.)

BTW, thanks for humoring me throughout this thread. This is really useful, and my understanding is updating considerably.

Hmm, so regarding the linear combinations, it's true that there are some linear combinations that will change by O(1) in the large-width limit -- just use the vector of partial derivatives of the output at some particular input; this sum will change by the amount that the output function moves during the regression. Indeed, I suspect (but don't have a proof) that these particular combinations span the space of linear combinations that change non-trivially during training. I would dispute "we expect most linear combinations to change", though -- the CLT argument implies that we should expect almost all combinations to not appreciably change. I'm not sure what effect this would have on the PCA, and I still think it's plausible that it doesn't change at all (actually, I think Greg Yang states that it doesn't change in section 9 of his paper, though I haven't read that part super carefully).
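A toy check of this CLT argument, as a minimal sketch under stated assumptions: a hypothetical one-hidden-layer tanh network f(x) = (1/√N) Σ v_i tanh(w_i x) in NTK parametrization, and one gradient step on the hidden weights. It compares the change in the readout whose coefficients are the partial derivatives ∂f/∂h_i = v_i/√N with the change in random readouts of the same scale. The function `readout_changes` and its parameters are illustrative only.

```python
import math
import random


def readout_changes(width, lr=0.5, x0=1.0, n_random=50, seed=0):
    """Changes in fixed linear readouts of the hidden layer after one step.

    Hypothetical toy net: f(x) = (1/sqrt(N)) * sum_i v_i * tanh(w_i * x),
    with w_i, v_i ~ N(0, 1). Only the hidden weights w_i affect the hidden
    activations, so only their gradient-ascent step on f(x0) is applied here.
    Returns (change of the readout along df/dh, mean |change| of random readouts).
    """
    rng = random.Random(seed)
    w = [rng.gauss(0.0, 1.0) for _ in range(width)]
    v = [rng.gauss(0.0, 1.0) for _ in range(width)]
    scale = 1.0 / math.sqrt(width)

    h_before = [math.tanh(wi * x0) for wi in w]
    w_new = [wi + lr * scale * vi * (1.0 - hi * hi) * x0
             for wi, vi, hi in zip(w, v, h_before)]
    dh = [math.tanh(wi * x0) - hi for wi, hi in zip(w_new, h_before)]

    # Coefficients df/dh_i = v_i / sqrt(N): the output's own readout direction.
    aligned = scale * sum(vi * d for vi, d in zip(v, dh))

    # Random readouts with i.i.d. N(0, 1) / sqrt(N) coefficients (same scale).
    random_total = 0.0
    for _ in range(n_random):
        random_total += abs(scale * sum(rng.gauss(0.0, 1.0) * d for d in dh))
    return aligned, random_total / n_random


for n in (100, 10_000):
    aligned, rand = readout_changes(n)
    print(f"N={n:>6}: aligned readout change {aligned:.4f}, random readouts {rand:.5f}")
```

The aligned readout should change by an amount of order 1 at every width, while the random readouts' change should shrink roughly like 1/√N.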

And the tangent kernel not changing does not imply that transfer learning won’t work

So I think I was a bit careless in saying that the NTK can't do transfer learning at all -- a more exact statement might be "the NTK does the minimal amount of transfer learning possible". What I mean by this is that any learning algorithm can do transfer learning if the task we are 'transferring' to is sufficiently similar to the original task -- for instance, if it's just the exact same task but with a different data sample. I claim that the 'transfer learning' the NTK does is of this sort. As you say, since the tangent kernel doesn't change at all, the net effect is to move where the network starts in the tangent space. Disregarding convergence speed, the impact this has on generalization is determined by the values set by the old function on axes of the NTK outside of the span of the partial derivatives at the new function's data points. This means that, for the NTK to transfer anything from one task to another, it's not enough for both tasks to feature, for instance, eyes; the eyes have to correlate with the output in the exact same way in both tasks. Indeed, the transfer learning could actually hurt generalization. Nor is its effect invariant under simple transformations like flipping the sign of the target function (this would turn beneficial transfer into harmful transfer). By default, for functions that aren't simple multiples of each other, I expect the linear correlation between values on different axes to be about 0, even if the functions share many meaningful features. So while the NTK can do 'transfer learning' in a sense, it's about as weak as possible, and I strongly doubt that this sort of transfer is sufficient to explain transfer learning's successes in practice (though I don't have empirical proof).
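The sign-flip point can be illustrated with any fixed kernel, since fine-tuning in this regime is kernel regression that starts from the old function. This sketch makes several assumptions: a Gaussian RBF kernel stands in for a fixed tangent kernel (it is not an NTK), a small ridge term is added for numerical stability (a choice made here), the 'old task' is sin(3x) fit on 41 points, and the 'new task' is seen at only 4 points and is either the same function or its negation. All names are hypothetical.

```python
import math


def rbf(a, b, lengthscale=0.3):
    """Gaussian kernel: a stand-in for a fixed, never-updated tangent kernel."""
    return math.exp(-((a - b) ** 2) / (2.0 * lengthscale ** 2))


def solve(matrix, rhs):
    """Solve matrix @ x = rhs by Gaussian elimination with partial pivoting."""
    n = len(rhs)
    aug = [list(row) + [r] for row, r in zip(matrix, rhs)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(aug[r][col]))
        aug[col], aug[pivot] = aug[pivot], aug[col]
        for r in range(col + 1, n):
            factor = aug[r][col] / aug[col][col]
            for c in range(col, n + 1):
                aug[r][c] -= factor * aug[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        tail = sum(aug[r][c] * x[c] for c in range(r + 1, n))
        x[r] = (aug[r][n] - tail) / aug[r][r]
    return x


def fit_from(start, xs, ys, ridge=1e-4):
    """Kernel regression starting from function `start`; the kernel is the same for every task."""
    gram = [[rbf(a, b) + (ridge if i == j else 0.0) for j, b in enumerate(xs)]
            for i, a in enumerate(xs)]
    alpha = solve(gram, [y - start(x) for x, y in zip(xs, ys)])
    return lambda x: start(x) + sum(a * rbf(x, xi) for a, xi in zip(alpha, xs))


def mse(f, target, xs):
    return sum((f(x) - target(x)) ** 2 for x in xs) / len(xs)


def old_task(x):
    return math.sin(3.0 * x)


def flipped_task(x):
    return -old_task(x)


def zero(x):
    return 0.0


old_xs = [-1.0 + 2.0 * i / 40 for i in range(41)]
new_xs = [-0.75, -0.25, 0.25, 0.75]
test_xs = [-1.0 + 2.0 * i / 200 for i in range(201)]

pretrained = fit_from(zero, old_xs, [old_task(x) for x in old_xs])

results = {}
for name, task in (("same", old_task), ("flipped", flipped_task)):
    new_ys = [task(x) for x in new_xs]
    scratch = mse(fit_from(zero, new_xs, new_ys), task, test_xs)
    transfer = mse(fit_from(pretrained, new_xs, new_ys), task, test_xs)
    results[name] = (scratch, transfer)
    print(f"{name:>7}: from scratch {scratch:.4f}, from pretrained {transfer:.4f}")
```

Starting from the old function should beat starting from zero on the identical task and lose on the flipped one. In this linear setting the pointwise error on the flipped task is, up to the small error of the pretrained fit, exactly double the from-scratch error, so its mean squared error comes out at about four times larger.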

I do think the empirical results pretty strongly suggest that the NTK/GP model captures everything important about neural nets, at least in terms of their performance on the original task.

It's true that NTK/GP perform pretty similarly to finite nets on the tasks we've tried them on so far, but those tasks are pretty simple and we already had decent non-NN solutions. Generally the pattern is "GP matches NNs on really simple tasks, NTK on somewhat harder ones". I think the data we have is consistent with this breaking down as we move to the harder problems that have no good non-NN solutions. I would be very interested in seeing an experiment with NTK on, say, ImageNet for this reason, but as far as I know no one's done so because of the prohibitive computational cost.

I only found one directly-relevant study, which is on way too small and simple a system for me to draw much of a conclusion from it, but it does seem to have worked.

Thanks for the link -- will read this tomorrow.

BTW, thanks for humoring me throughout this thread. This is really useful, and my understanding is updating considerably.

And thank you for engaging in detail -- I have also found this very helpful in forcing me to clarify (partially to myself) what my actual beliefs are.

So I read through the Maddox et al. study, and it definitely does not show that the NTK can do transfer learning. They pre-train using SGD on a single task, then use the NTK computed on the trained network to do Bayesian inference on some other tasks. They say in a footnote on page 9, "Note that in theory, there is no need to train the network at all. We found that it is practically useful to train the network to learn good representations." This makes me suspect that they tried using the NTK to learn the transfer parameters but it didn't work.

Regarding the empirical results about the NTK explaining the performance of neural nets, I found this study interesting. They computed the 'empirical NTK' on some finite-width networks and compared the performance of the solution found by SGD to that found by solving the NTK. For standard widths, the NTK solution performed substantially worse (up to a 20% drop in accuracy). The gap closed to some extent, but not completely, upon making the network much wider. The size of the gap also correlated with the complexity of the task (0.5% gap for MNIST, 13% for CIFAR, 18% for a subset of ImageNet). The trajectory of the weights also diverged substantially from the NTK prediction, even on MNIST. All of this seems consistent with the NTK being a decent first-order approximation that breaks down on the really hard tasks that require the networks to do non-trivial feature learning.

Ah, that is interesting. This definitely updates me moderately toward the "NTKs don't learn features" hypothesis.

BTW, does this hypothesis also mean that feature learning should break down in ordinary nets as they scale up? Or does increasing the data alongside the parameter count counteract that?

I think nets are usually increased in depth as well as width when they are 'scaled up', so the NTK limit doesn't apply -- the convergence to NTK is controlled by the ratio of depth to width, only approaching a deterministic kernel if this ratio approaches 0.

[Deleted]

NTK doesn’t learn features because the feature class at initialization is a universal class

I've never heard of any result suggesting this; what's your argument? I suspect the opposite -- by the central limit theorem, the partial derivatives and activations at each layer tend toward samples from a fixed distribution (differing per layer but fixed across neurons). I think this means that the NTK embedding is 'essentially finite' and actually not universal (though I'm not sure). Note that to show universality it's not enough to show that all embeddings can be found; you'll also need an argument showing that their density in the NTK embedding is bounded away from zero.

[Deleted]

There's a big difference between 'universal learner' and 'fits any smooth function on a fixed input space'. The 'universal learner' property is about data efficiency: do you have bounded regret compared to any learning algorithm in some wide class? Solomonoff induction has this property with respect to computable predictors on binary strings, for instance. There are lots of learning algorithms able to fit any finite binary sequence but which are not universal. I haven't seen a good formalism for this in the neural net case, but I think it would involve letting the input dimension increase with the number of data points, and comparing the asymptotic performance of various algorithms.
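A finite toy version of the 'bounded regret' property, as a sketch: a Bayesian mixture over a handful of fixed Bernoulli predictors, with prior weight proportional to 2^(-k) on the k-th one (a stand-in for a simplicity prior). Because the mixture assigns every sequence at least w_k times the probability that predictor k assigns it, its cumulative log-loss exceeds predictor k's by at most log2(1/w_k) bits, on every sequence. Solomonoff induction is the analogous construction over all computable predictors with weights of about 2^(-K). The predictor class, prior, and data here are illustrative assumptions.

```python
import math
import random


def expert_log_loss(p_one, bits):
    """Cumulative log-loss, in bits, of a fixed Bernoulli(p_one) predictor."""
    return -sum(math.log2(p_one if b else 1.0 - p_one) for b in bits)


def mixture_log_loss(p_ones, prior, bits):
    """Cumulative log-loss, in bits, of the Bayesian mixture, updated online."""
    weights = list(prior)
    total = 0.0
    for b in bits:
        likelihoods = [p if b else 1.0 - p for p in p_ones]
        p_obs = sum(w * lik for w, lik in zip(weights, likelihoods))
        total -= math.log2(p_obs)
        # Bayes' rule: posterior weight is proportional to prior weight * likelihood.
        weights = [w * lik / p_obs for w, lik in zip(weights, likelihoods)]
    return total


p_ones = [0.1, 0.3, 0.5, 0.7, 0.9]
raw = [2.0 ** -(k + 1) for k in range(len(p_ones))]
prior = [r / sum(raw) for r in raw]

rng = random.Random(0)
bits = [rng.random() < 0.7 for _ in range(500)]

mixture = mixture_log_loss(p_ones, prior, bits)
for p, w in zip(p_ones, prior):
    regret = mixture - expert_log_loss(p, bits)
    print(f"expert p={p}: regret {regret:7.2f} bits <= bound {math.log2(1 / w):.2f}")
```

The bound holds for every sequence, not just on average, which is the sense of 'bounded regret' relevant to universal learning; the open question raised here is whether any comparable guarantee holds for the NTK over a rich class of classification problems.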

[Deleted]

Ah, rereading your original comment more carefully I see that you indeed didn't say anything about 'universal learning'. You're quite right that the NTK is a universal function approximator. My apologies.

However, I still disagree that this is the reason that the NTK doesn't learn features. I think that 'universal function approximation' and 'feature learning' are basically unrelated dimensions along which a learning algorithm can vary. That is, it's quite possible to imagine a learning algorithm which constructs a sequence of different embeddings, all of which are universal approximators. The paper by Greg Yang I linked gives an example of such an algorithm (I don't think he explicitly proves this, but I'm pretty sure it's true).

What I was trying to get at with the 'universal learning' remarks is that, although the NTK does indeed contain all finite embeddings, I believe that it does not do so in a very efficient way -- it might require disproportionately many training points to pick out what are, intuitively, fairly simple embeddings. I believe this is what is behind the poor performance of empirical NTKs compared to SGD-trained nets, as I brought up in this comment, and ultimately explains why algorithms that do 'feature learning' can outperform those that don't -- the feature learning algorithms are able to find more efficient embeddings for a given set of inputs (of course, it's possible to imagine a fixed embedding that's 'optimally efficient' in some way, but as far as I'm aware the NTK has no such property). This issue of 'embedding efficiency' seems only loosely related to the universal approximation property. To formalize this, it would be nice to develop a theory of universal inference in the setting of classification problems akin to Solomonoff induction. To effectively model this in an asymptotic theory, I think it might be necessary to increase the dimension of the model input along with the number of data points, since otherwise all universal approximators for a given dimension will have asymptotically the same performance. Everything in this paragraph is just my personal speculation though; as far as I'm aware there's no existing theory of universal inference in classification problems, so if you found my remarks confusing that's pretty understandable :)

[Deleted]

By universal approximation, these features will be sufficient for any downstream learning task

Right, but trying to fit an unknown function with linear combinations of those features might be extremely data-inefficient, such that it is basically unusable for difficult tasks. Of course you could do better if you're not restricted to linear combinations -- for instance, if the map is injective you could invert back to the original space and apply whatever algorithm you wanted. But at that point you're not really using the Fourier features at all. In particular, the NTK always learns a linear combination of its features, so it's the efficiency of linear combinations that's relevant here.

I agree that there is no learning taking place and that such a method may be inefficient. However, that goes beyond my original objection.

You originally said that the NTK doesn't learn features because its feature class already has a good representation at initialization. What I was trying to convey (rather unclearly, admittedly) in response is:

A) There exist learning algorithms that have universal-approximating embeddings at initialization yet learn features. If we have an example of P and !Q, P-->Q cannot hold in general, so I don't think it's right to say that the NTK's lack of feature learning is due to its universal-approximating property.

B) Although the NTK's representation may be capable of approximating arbitrary functions, it will probably be very slow at learning some of them, perhaps so slow that using it is infeasible. So I would dispute that it already has 'good' representations. While it's universal in one sense, there might be some other sense of 'universal efficiency' in which it's lacking, and where feature-learning algorithms can outperform it.

This is not a trivial question

I agree that in practice there's likely to be some relationship between universal approximation and efficiency, I just think it's worth distinguishing them conceptually. Thanks for the paper link BTW, it looks interesting.

Feature learning requires the intermediate neurons to adapt to structures in the data that are relevant to the task being learned, but in the NTK limit the intermediate neurons' functions don't change at all.
Any meaningful function like a 'car detector' would need to be there at initialization -- extremely unlikely for functions of any complexity.

I used to think it would be extremely unlikely for a randomly initialized neural net to contain a subnetwork that performs just as well as the entire neural net does after training. But the multi-prize lottery ticket results seem to show just that. So now I don't know what to think about which sorts of things are likely or unlikely when it comes to this stuff. In particular, is it really so unlikely that 'car detector' functions really do exist somewhere in the random jumble of a sufficiently big randomly initialized NN? Or maybe they don't exist right away, but with very slight tweaks they do?

They would exist in a sufficiently big random NN, but their density would be extremely low, I think. Like, if you train a normal neural net with 15000 neurons and then there's a car detector, the density of car detectors is now 1/15000. Whereas I think the density at initialization is probably more like 1/2^50 or something like that (numbers completely made up), so they'd have a negligible effect on the NTK's learning ability ('slight tweaks' can't happen in the NTK regime, since no intermediate functions change by definition).

A difference from the pruning case is that the number of possible prunings grows exponentially with the number of neurons, whereas the number of neurons (candidate car detectors) grows only linearly. My take on the LTH is that pruning is basically just a weird way of doing optimization, so it's not that surprising you can get good performance.

My take on the LTH is that pruning is basically just a weird way of doing optimization so it's not that surprising you can get good performance.

+1 to this in particular; I think this is the main point Daniel (and many people like Daniel) are missing here. There's a very big difference between "car detector functions exist somewhere in the random jumble of a sufficiently big randomly initialized NN" vs "the net can be pruned to yield a car detector function", and the LTH papers show the latter.

I think I get this distinction; I realize the NN papers show the latter; I guess our disagreement is about how big a deal / how surprising this is.

Can we therefore model fine-tuning as moving around in the parameter tangent space around the pre-trained network?

Yes, and indeed in the NTK limit we can model ordinary training that way.

There is an extensive discussion about feature learning in relation to the aforementioned Mingard et al. result in the comments of this post. The conclusion of the discussion was that feature learning is uncoupled from inductive bias for infinite-width (and actually also finite-width, under further conditions) neural networks when trained by a random-sampling process (essentially how NNGPs work).

The open question is whether the probability distribution over functions after each layer is the same whether you train with SGD or random sampling. Given how similar the posteriors of optimiser-trained NNs are to NNGPs, I think it is sensible to assume that they are similar. However, the important question is still whether this scales to large architectures and datasets, which become computationally much harder to test (as the NNGP kernel becomes harder and harder to compute with the size of the dataset).