Indeed the integrals in the sparse case aren't so bad https://arxiv.org/abs/2310.06301. I don't think the analogy to the Thompson problem is correct, it's similar but qualitatively different (there is a large literature on tight frames that is arguably more relevant).
Anthropic's recent mechanistic interpretability paper, Toy Models of Superposition, helps to demonstrate the conceptual richness of very small feedforward neural networks. Even when such networks are trained on synthetic, hand-coded data to reconstruct a very straightforward function (the identity map), there appears to be non-trivial mathematics at play, and the analysis of these small networks seems to provide an interesting playground for mechanistic interpretability.
While trying to understand their work and train my own toy models, I ended up making various notes on the underlying mathematics. This post is a slightly neatened-up version of those notes, but is still quite rough and un-edited and is a far-from-optimal presentation of the material. In particular, these notes may contain errors, which are my responsibility.
1. Directly Analyzing the Critical Points of a Linear Toy Model
Throughout we will be considering feedforward neural networks with one hidden layer. The input and output layers will be of the same size and the hidden layer is smaller. We will only be considering the autoencoding problem, which means that our networks are being trained to reconstruct the data. The first couple of subsections here are largely taken from the Appendix to the paper "Neural networks and principal component analysis: Learning from examples without local minima." by Pierre Baldi and Kurt Hornik. (Neural networks 2.1 (1989): 53-58).
Consider, to begin with, a completely linear model, i.e. one without any activation functions or biases. Suppose the input and output layers have $D$ neurons and that the middle layer has $d<D$ neurons. This means that the function that the model is implementing is of the form $x \mapsto ABx$, where $x \in \mathbb{R}^D$, $B$ is a $d \times D$ matrix, and $A$ is a $D \times d$ matrix. That is, the matrix $B$ contains the weights of the connections between the input layer and the hidden layer, and the matrix $A$ contains the weights of the connections between the hidden layer and the output layer. It is important to realise that even though the function being implemented is linear for any given set of weights, the mathematics of this model and the dynamics of the training are not completely linear.
The error on a given input $x$ will be measured by $\|x - ABx\|^2$, and on the data set $\{x^t\}_{t=1}^T$ the total loss is
$$L = L(A,B,\{x^t\}_{t=1}^T) := \sum_{t=1}^T \|x^t - ABx^t\|^2 = \sum_{t=1}^T \sum_{i=1}^D \Big(x^t_i - \sum_{j=1}^d \sum_{k=1}^D a_{ij} b_{jk} x^t_k\Big)^2.$$
Define $\Sigma$ to be the $D \times D$ matrix whose $(i,j)$th entry $\sigma_{ij}$ is given by
$$\sigma_{ij} = \sum_{t=1}^T x^t_i x^t_j.$$
Clearly this matrix is symmetric.
Assumption. We will assume that the data is such that a) Σ is invertible and b) Σ has distinct eigenvalues.
Let $\lambda_1 > \dots > \lambda_D$ be the eigenvalues of $\Sigma$.
1.1 The Global Minimum
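Proposition 1. (Characterization of Critical Points) Fix the dataset and consider $L$ to be a function of the two matrix variables $A$ and $B$. For any critical point $(A,B)$ of $L$ with $A$ of full rank, there is a subset $I \subset \{1,\dots,D\}$ of size $d$ for which
$$AB = P_{U_I} \qquad \text{and} \qquad L(A,B) = \operatorname{tr}\Sigma - \sum_{i \in I} \lambda_i,$$
where $P_{U_I}$ denotes the orthogonal projection onto the span of the eigenvectors of $\Sigma$ whose eigenvalues are $\lambda_i$, $i \in I$.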
Corollary 2. (Characterization of the Minimum) The loss has a unique minimum value that is attained when I={1,…,d}, which corresponds to the situation when AB is an orthogonal projection onto the d-dimensional subspace spanned by the eigendirections of Σ that have the largest eigenvalues.
Remarks. We won't try to spell out all of the various connections to other closely related things, but for those who want some more keywords to go away and investigate further, we just remark that the minimization problem being studied here is about finding a low-rank approximation to the identity and is closely related to Principal Component Analysis. See also the Eckart–Young–Mirsky Theorem.
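A quick numerical sanity check of Corollary 2 can be sketched as follows. This is only an illustration under stated assumptions: the Gaussian toy data, the sizes `D`, `d`, `T`, and all function names are hypothetical choices and not part of the original notes. It compares the loss of the projection onto the top-$d$ eigendirections with the value $\operatorname{tr}\Sigma - \sum_{i \in I}\lambda_i$, and with another admissible index set.

```python
import numpy as np

rng = np.random.default_rng(0)
D, d, T = 6, 2, 200                                  # illustrative sizes (assumptions)
X = rng.normal(size=(T, D)) * np.arange(1, D + 1)    # rows are the data points x^t

Sigma = X.T @ X                                      # sigma_ij = sum_t x^t_i x^t_j
lam, U = np.linalg.eigh(Sigma)                       # eigh returns ascending eigenvalues
lam, U = lam[::-1], U[:, ::-1]                       # reorder so lam[0] > lam[1] > ...

def loss(P):
    """Total loss sum_t ||x^t - P x^t||^2 for a D x D matrix P (playing the role of AB)."""
    R = X - X @ P.T
    return float(np.sum(R ** 2))

def projection(I):
    """Orthogonal projection onto span{u_i : i in I} (0-based indices)."""
    U_I = U[:, list(I)]
    return U_I @ U_I.T

loss_top = loss(projection(range(d)))                # I = {1, ..., d} in the text's numbering
loss_other = loss(projection([0, 3]))                # some other index set of size d

predicted_top = lam[d:].sum()                        # tr(Sigma) - sum of the top-d eigenvalues
predicted_other = lam.sum() - lam[[0, 3]].sum()

print(loss_top, predicted_top)
print(loss_other, predicted_other)
```

Under these assumptions the two printed pairs should agree up to floating-point error, and the first loss should be the smaller one.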
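We begin by directly differentiating $L$ with respect to the entries of $A$ and $B$. Using the summation convention on repeated indices, we first take the derivative with respect to $b_{j'k'}$:
$$\frac{\partial L}{\partial b_{j'k'}} = \sum_{t=1}^T \sum_{i=1}^D -2\,(x^t_i - a_{ij} b_{jk} x^t_k)\, a_{il}\, \delta_{lj'}\, \delta_{qk'}\, x^t_q = -2 \sum_{t=1}^T \big( x^t_i a_{ij'} x^t_{k'} - a_{ij} b_{jk} x^t_k\, a_{ij'} x^t_{k'} \big).$$
Setting this equal to zero and interpreting this equation for all $j' = 1,\dots,d$ and $k' = 1,\dots,D$ gives us that
$$A^T \Sigma = A^T A B \Sigma. \tag{1}$$
Then, separately, we differentiate $L$ with respect to $a_{i'j'}$:
$$\frac{\partial L}{\partial a_{i'j'}} = \sum_{t=1}^T \sum_{i=1}^D -2\,(x^t_i - a_{ij} b_{jk} x^t_k)\, \delta_{ii'}\, \delta_{pj'}\, b_{pq} x^t_q = -2 \sum_{t=1}^T \big( x^t_{i'} b_{j'q} x^t_q - a_{i'j} b_{jk} x^t_k\, b_{j'q} x^t_q \big).$$
Setting this equation equal to zero for every $i' = 1,\dots,D$ and $j' = 1,\dots,d$ we have that:
$$\Sigma B^T = A B \Sigma B^T. \tag{2}$$
Thus
$$\nabla L(A,B) = 0 \iff \begin{cases} A^T \Sigma = A^T A B \Sigma \\ \Sigma B^T = A B \Sigma B^T. \end{cases} \tag{3}$$
Since we have assumed that $\Sigma$ is invertible, the first equation immediately implies that $A^T = A^T A B$. If we assume in addition that $A$ has full rank (a reasonable assumption in any case of practical interest), then $A^T A$ is invertible and we have that
$$(A^T A)^{-1} A^T = B, \tag{4}$$
which in turn implies that
$$AB = A (A^T A)^{-1} A^T = P_A, \tag{5}$$
where we have written $P_A$ to denote the orthogonal projection onto the column space of $A$.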
Claim. We next claim that $\Sigma$ commutes with $P_A$.
Proof of claim. Plugging (5) into the second equation of (3), we have:
$$\Sigma B^T = P_A \Sigma B^T. \tag{6}$$
Then, right-multiply by $A^T$ and use the fact that $B^T A^T = (AB)^T = P_A^T = P_A$ to get:
$$\Sigma P_A = P_A \Sigma P_A. \tag{7}$$
The right-hand side is manifestly a symmetric matrix, so we deduce that $\Sigma P_A$ is symmetric. If the product of two symmetric matrices is symmetric then they commute, so this indeed shows that $\Sigma$ commutes with $P_A$ and completes the proof of the claim.
Now let $U$ be the orthogonal matrix which diagonalizes $\Sigma$, i.e. the matrix for which
$$\Sigma = U \Lambda U^T, \tag{8}$$
where $\Lambda$ is a diagonal matrix with entries $\lambda_1 > \lambda_2 > \dots > \lambda_D > 0$.
Claim. We next claim that $P_A = U P_{U^T A} U^T$ and that $P_{U^T A}$ is diagonal.
Proof of Claim. Firstly, using the standard formula for orthogonal projections, we have
$$P_{U^T A} = U^T A (A^T U U^T A)^{-1} A^T U = U^T A (A^T A)^{-1} A^T U = U^T P_A U,$$
which implies that
$$P_A = U P_{U^T A} U^T. \tag{9}$$
To show that $P_{U^T A}$ is diagonal, we show that it commutes with the diagonal matrix $\Lambda$ (any matrix that commutes with a diagonal matrix whose diagonal entries are distinct must itself be diagonal). Starting from $P_{U^T A} \Lambda$, we first insert the identity matrix in the form $U^T U$, and then use (8) and (9) thus:
$$P_{U^T A} \Lambda = U^T U P_{U^T A} U^T U \Lambda U^T U = U^T P_A \Sigma U.$$
Then recall that we have already established that $P_A$ commutes with $\Sigma$. So we can swap them and then perform the same trick in reverse:
$$U^T P_A \Sigma U = U^T \Sigma P_A U = U^T U \Lambda U^T U P_{U^T A} U^T U = \Lambda P_{U^T A}.$$
This shows that $P_{U^T A}$ commutes with $\Lambda$ and completes the proof of the claim.
So, given that $P_{U^T A}$ is an orthogonal projection of rank $d$ and is diagonal, there exists a set of indices $I = \{i_1,\dots,i_d\}$ with $1 \le i_1 < i_2 < \dots < i_d \le D$ such that the $(i,j)$th entry of $P_{U^T A}$ is $1$ if $i = j$ and $i \in I$, and zero otherwise. And since $P_A = U P_{U^T A} U^T$, we see that
$$P_A = U_I U_I^T, \tag{10}$$
where $U_I$ is the $D \times d$ matrix formed from $U$ by keeping only the columns $u_j$ with $j \in I$. This is manifestly an orthogonal projection onto the span of $\{u_{i_1},\dots,u_{i_d}\}$, where $u_1, u_2, \dots, u_D$ is an orthonormal basis of eigenvectors of $\Sigma$ (and indeed the columns of $U$). Combining these observations with (5), we have that
$$AB = P_A = U_I U_I^T = P_{U_I}. \tag{11}$$
This proves the first claim of the proposition.
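To prove the second part, write $AB = [p_{ij}]$ and compute thus:
$$\sum_{t=1}^T \|x^t - ABx^t\|^2 = \sum_{t=1}^T \sum_{i=1}^D (x^t_i - p_{ij} x^t_j)^2 = \sum_{t=1}^T \big( x^t_i x^t_i - 2 x^t_i p_{ij} x^t_j + p_{ij} x^t_j p_{ik} x^t_k \big) = \operatorname{tr}\Sigma - 2\operatorname{tr}(P_{U_I}\Sigma) + \operatorname{tr}(P_{U_I} \Sigma P_{U_I}^T). \tag{12}$$
But we know from (7) and (11) that $P_{U_I} \Sigma P_{U_I} = \Sigma P_{U_I}$ and so this last line is actually just equal to
$$\operatorname{tr}\Sigma - \operatorname{tr}(P_{U_I}\Sigma).$$
Focussing on the second term and using (11), then (9) and (8), then cancelling $U^T U = I$, and then - to reach the last line - cyclically permuting the matrices inside the trace operator to produce another $U^T U$ cancellation, we have:
$$\operatorname{tr}(P_{U_I}\Sigma) = \operatorname{tr}(P_A \Sigma) = \operatorname{tr}(U P_{U^T A} U^T U \Lambda U^T) = \operatorname{tr}(U P_{U^T A} \Lambda U^T) = \operatorname{tr}(P_{U^T A} \Lambda).$$
The diagonal form of $P_{U^T A}$ means that this final expression is equal to $\sum_{i \in I} \lambda_i$, meaning that
$$\sum_{t=1}^T \|x^t - ABx^t\|^2 = \operatorname{tr}\Sigma - \sum_{i \in I} \lambda_i.$$
Since $\operatorname{tr}\Sigma = \sum_{i=1}^D \lambda_i$ (the trace is always equal to the sum of the eigenvalues), this completes the proof of the proposition. □
Remarks. Equation (10) above tells us that $\operatorname{col}(A) = \operatorname{span}\langle u_{i_1},\dots,u_{i_d}\rangle$, which means that there exists an invertible matrix $C$ with $A = U_I C$. Then, using (4), we compute that
$$B = (A^T A)^{-1} A^T = (C^T U_I^T U_I C)^{-1} (U_I C)^T = C^{-1} (C^T)^{-1} C^T U_I^T = C^{-1} U_I^T.$$
So we have:
$$\begin{cases} A = U_I C \\ B = C^{-1} U_I^T \end{cases} \tag{$\star$}$$
1.2 Characterizing Other Critical Points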
This subsection is something of an aside, but it is included for completeness.
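Proposition 3. (Other Critical Points are Saddle Points.) Fix the dataset and consider $L$ to be a function of the two matrix variables $A$ and $B$. Every other critical point is a saddle point, i.e. if $(A,B)$ is a critical point that is not a global minimum, then there exist $\tilde A$ and $\tilde B$ which are arbitrarily close to $A$ and $B$ respectively and at which a lower loss is achieved.
Proof. Since $(A,B)$ is not a global minimum, we know from Corollary 2 that $I \ne \{1,\dots,d\}$. This means that there are distinct indices $j$ and $k$ for which $j \in I$, $k \notin I$ and $k < j$. In particular, bear in mind that $\lambda_k > \lambda_j$.
Now, given any $\epsilon > 0$, put
$$\tilde u_j := \frac{u_j + \epsilon u_k}{\sqrt{1 + \epsilon^2}}.$$
And let us form the new matrix $\tilde U_I$ by starting with $U_I$ and replacing the column $u_j$ with $\tilde u_j$. Write
$$\tilde A = \tilde U_I C, \tag{13}$$
$$\tilde B = C^{-1} \tilde U_I^T. \tag{14}$$
We want to calculate the loss of the model at $(\tilde A, \tilde B)$. We ought to bear in mind that it is not a critical point, so we cannot assume the intermediate results in the proof of Proposition 1, but it turns out that the bits that are most useful for this computation rely only on algebra and (13), (14). We start from the equivalent of line (11), which is that $\tilde A \tilde B = \tilde U_I \tilde U_I^T = P_{\tilde U_I}$, which implies that $\tilde A \tilde B = P_{\tilde A}$. And so just as in (12) above, we have
$$\sum_{t=1}^T \|x^t - \tilde A \tilde B x^t\|^2 = \operatorname{tr}\Sigma - 2\operatorname{tr}(P_{\tilde U_I}\Sigma) + \operatorname{tr}(P_{\tilde U_I} \Sigma P_{\tilde U_I}^T). \tag{15}$$
Now, looking at the final term on the right-hand side, we have $P_{\tilde U_I} \Sigma P_{\tilde U_I}^T = P_{\tilde A} \Sigma P_{\tilde A}$ and (by cyclic permutation, together with $P_{\tilde A}^2 = P_{\tilde A}$) $\operatorname{tr}(P_{\tilde A} \Sigma P_{\tilde A}) = \operatorname{tr}(P_{\tilde A} \Sigma)$. And since
$$P_{U^T \tilde A} = U^T \tilde A (\tilde A^T U U^T \tilde A)^{-1} \tilde A^T U = U^T P_{\tilde A} U,$$
we have:
$$\operatorname{tr}(P_{\tilde A}\Sigma) = \operatorname{tr}(U P_{U^T \tilde A} U^T \Sigma) = \operatorname{tr}(P_{U^T \tilde A} \Lambda). \tag{16}$$
We also use $\operatorname{tr}(P_{\tilde U_I}\Sigma) = \operatorname{tr}(P_{\tilde A}\Sigma)$ and (16) on the second term on the right-hand side of (15) to ultimately arrive at:
$$\sum_{t=1}^T \|x^t - \tilde A \tilde B x^t\|^2 = \operatorname{tr}\Sigma - \operatorname{tr}(P_{U^T \tilde A} \Lambda). \tag{17}$$
So we are interested in computing the diagonal elements of $P_{U^T \tilde A}$. Fix $i \in \{1,\dots,D\}$. The $i$th diagonal entry is given by:
$$e_i^T P_{U^T \tilde A} e_i = e_i^T U^T P_{\tilde A} U e_i = u_i^T \tilde U_I \tilde U_I^T u_i.$$
This can be computed directly from the definition of $\tilde U_I$ to give that the $i$th entry on the diagonal is equal to
$$\begin{cases} 0 & \text{if } i \notin I \cup \{k\} \\ 1 & \text{if } i \in I \setminus \{j\} \\ 1/(1+\epsilon^2) & \text{if } i = j \\ \epsilon^2/(1+\epsilon^2) & \text{if } i = k. \end{cases} \tag{18}$$
Therefore
$$\begin{aligned} \sum_{t=1}^T \|x^t - \tilde A \tilde B x^t\|^2 &= \operatorname{tr}\Sigma - \Big[ \sum_{i \in I \setminus \{j\}} \lambda_i + \frac{\lambda_j}{1+\epsilon^2} + \frac{\epsilon^2 \lambda_k}{1+\epsilon^2} \Big] \\ &= \operatorname{tr}\Sigma - \sum_{i \in I} \lambda_i - \frac{\epsilon^2}{1+\epsilon^2} (\lambda_k - \lambda_j) \\ &= \sum_{t=1}^T \|x^t - ABx^t\|^2 - \frac{\epsilon^2}{1+\epsilon^2} (\lambda_k - \lambda_j). \end{aligned}$$
Since $\lambda_k > \lambda_j$ this shows that in an arbitrarily small neighbourhood of the critical point $(A,B)$ we can find a point $(\tilde A, \tilde B)$ where smaller loss is achieved. We will not bother doing so here, but one can also check that $(A,B)$ is not a local maximum by using the fact that for fixed (full rank) $A$, the function $z \mapsto \|x - Az\|^2$ is convex. □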
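The perturbation used in this proof can be checked numerically. The sketch below is illustrative only: the toy data, the particular index set, the value of `eps`, and all names are hypothetical choices (indices are 0-based, and $C = I$ is assumed so that the model is just the projection). It rotates one column of $U_I$ slightly towards a skipped eigenvector with a larger eigenvalue and compares the resulting drop in loss with $\frac{\epsilon^2}{1+\epsilon^2}(\lambda_k - \lambda_j)$.

```python
import numpy as np

rng = np.random.default_rng(1)
D, T = 5, 100                                        # illustrative sizes (assumptions)
X = rng.normal(size=(T, D)) * np.arange(1, D + 1)    # rows are the data points x^t
Sigma = X.T @ X
lam, U = np.linalg.eigh(Sigma)
lam, U = lam[::-1], U[:, ::-1]                       # descending eigenvalues

def loss_of_frame(V):
    """sum_t ||x^t - V V^T x^t||^2 for a D x d matrix V with orthonormal columns."""
    R = X - X @ V @ V.T
    return float(np.sum(R ** 2))

I = [0, 2]          # 0-based index set of a non-minimal critical point (skips index 1)
j, k = 2, 1         # j in I, k not in I, k < j, hence lam[k] > lam[j]
eps = 1e-2

U_I = U[:, I]
U_tilde = U_I.copy()
U_tilde[:, I.index(j)] = (U[:, j] + eps * U[:, k]) / np.sqrt(1 + eps ** 2)

drop = loss_of_frame(U_I) - loss_of_frame(U_tilde)
predicted_drop = eps ** 2 / (1 + eps ** 2) * (lam[k] - lam[j])

print(drop, predicted_drop)
```

The drop is positive but of order $\epsilon^2$, which is what one expects near a critical point: the first-order change vanishes and the saddle direction only shows up at second order.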
2. Sparse Data, Weight Tying, and Gradients
Abstractly analyzing critical points is not at all the same as training real models. In this section we start to think about data and the optimization process.
2.1 Sparse Synthetic Data
Here we describe the kind of training data used in Anthropic's toy experiments.
Fix a number $S \in [0,1]$. This parameter is the sparsity of the data. We will typically be most interested in the case where $S$ is close to 1.
Let $\{B^t_i\}$, for $t = 1, 2, \dots$ and $i = 1,\dots,D$, be an independent and identically distributed family of Bernoulli random variables with parameter $(1-S)$. And let $\{U^t_i\}$, over the same index range, be an IID family of Uniform$([0,1])$ random variables (independent of the $B^t_i$). Write $X^t_i = B^t_i U^t_i$ and $X^t := (X^t_1,\dots,X^t_D)$. Our datasets $\{x^t\}_{t=1}^T \subset \mathbb{R}^D$ will be drawn from the IID family $\{X^t\}_{t=1}^\infty$. Notice that
Remark. Judging from some of the existing literature on the linear model that we analyze in Section 1 (e.g. "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks." by Andrew M. Saxe, James L. McClelland and Surya Ganguli), it seems tempting to make the assumption/simplification/approximation that $\Sigma = I$. I still don't feel like I understand how justifiable that is - for me this question is a potential 'jumping-off' point for further analysis of the whole problem. Recall that the matrix $\Sigma$ is equal to $\sum_{t=1}^T x^t \otimes x^t$. Certainly the probability that an off-diagonal entry of $x^t \otimes x^t$ is equal to zero is $1 - (1-S)^2$ whereas for the diagonal entries it is just $S$. And note that $\mathbb{E}(x^t_i x^t_j) = \tfrac{1}{4}(1-S)^2$ if $i \ne j$ and $\mathbb{E}((x^t_i)^2) = \tfrac{1}{3}(1-S)$. But the diagonal entries are still independent and I'm not sure why thinking of them as equal makes sense.
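The sampling scheme and the two moment formulas in the Remark can be sketched in a few lines of NumPy. This is a minimal illustration, not the code used in the Anthropic paper: the function name `sample_sparse_batch` and the particular values of `S`, `D`, `T` are hypothetical, and the Bernoulli and uniform draws are assumed independent.

```python
import numpy as np

def sample_sparse_batch(T, D, S, rng):
    """Draw T points with coordinates B * U, B ~ Bernoulli(1 - S), U ~ Uniform[0, 1)."""
    active = rng.random((T, D)) < (1.0 - S)   # Bernoulli(1 - S) mask: which features are present
    values = rng.random((T, D))               # uniform magnitudes of the present features
    return active * values

rng = np.random.default_rng(0)
S, D, T = 0.9, 4, 400_000                     # illustrative values (assumptions)
X = sample_sparse_batch(T, D, S, rng)

second_moments = X.T @ X / T                  # empirical Sigma / T
diag_mean = float(np.diag(second_moments).mean())
offdiag_mean = float(second_moments[~np.eye(D, dtype=bool)].mean())

expected_diag = (1 - S) / 3                   # E[(x_i)^2]
expected_offdiag = (1 - S) ** 2 / 4           # E[x_i x_j] for i != j

print(diag_mean, expected_diag)
print(offdiag_mean, expected_offdiag)
```

Note that in expectation $\Sigma / T$ has a constant diagonal and a constant (smaller, but non-zero) off-diagonal, so it is close to a multiple of the identity only to the extent that $(1-S)^2/4$ is negligible compared with $(1-S)/3$.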
The data (and the loss) are meant to model two main ideas. Firstly, the coordinate directions of the input space act as a natural set of features for the data. Secondly, when $S$ is close to 1, the sparsity of the data is supposed to capture the fact that features really do often tend to be sparse in real-world data, i.e. for any given object or any given word/idea that appears in a language, it is the case that most images don't contain that object and most sentences don't contain that word or idea.
2.2 Weight Tying and The Gradient Flow
In practice, when we train an autoencoder like this, we do so with weight tying. Roughly speaking, this means that we only consider the case where $A = B^T$. Proposition 1 does indeed allow for a global minimum in which $A = B^T$: This is achieved by essentially taking $C = I$ in the equations ($\star$) at the end of Section 1.1, i.e. we have:
$$A = U_I, \qquad B = U_I^T.$$
But note that we don't actually want to try to repeat the analysis of Section 1 on a loss of the form $\sum_t \|x^t - W^T W x^t\|^2$. This would be a higher-order polynomial function of the entries of $W$ and so it's genuinely a different and potentially more complicated functional. The way that weight-tying is done in practice is more similar to saying that we insist during training that updates are made that preserve the equality $A = B^T$.
Equations (1) and (2) in subsection 1.1 are obtained as a direct result of differentiating the loss with respect to individual entries of the matrices (or individual 'weights' if we interpret this model as a feedforward neural network without activations). Our computations show that:
$$\nabla L(A,B) = -2 \begin{pmatrix} A^T \Sigma - A^T A B \Sigma \\ \Sigma B^T - A B \Sigma B^T \end{pmatrix}. \tag{19}$$
In an appropriate continuous time limit, and absorbing the factor of 2 into the learning rate (i.e. into the time scale), the weights during training evolve according to the differential equations:
$$\frac{d}{dt} B = A^T (\Sigma - AB\Sigma), \tag{20}$$
$$\frac{d}{dt} A = (\Sigma - AB\Sigma) B^T. \tag{21}$$
Remarks. Notice that there is a certain deliberate sloppiness here: One doesn't really have a fixed matrix $\Sigma$ and then run this gradient flow for all time; the matrix $\Sigma$ is a function of (a batch of) training data. So we need to be careful about any further manipulations or interpretations of these equations.
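The gradient formula (19) and a crude discretisation of the flow can be checked numerically. The following is a sketch under assumptions: the Gaussian toy data, the sizes, the step size `dt`, and a fixed $\Sigma$ (i.e. full-batch training on one dataset) are all illustrative choices, and the factor of 2 is kept explicitly rather than absorbed into the time scale.

```python
import numpy as np

rng = np.random.default_rng(0)
D, d, T = 5, 2, 50                            # illustrative sizes (assumptions)
X = rng.normal(size=(T, D))                   # rows are the data points x^t
Sigma = X.T @ X

def loss(A, B):
    """sum_t ||x^t - A B x^t||^2."""
    R = X - X @ (A @ B).T
    return float(np.sum(R ** 2))

def grad(A, B):
    """The two blocks of (19), returned as (dL/dA, dL/dB)."""
    M = Sigma - A @ B @ Sigma
    return -2 * M @ B.T, -2 * A.T @ M

A = rng.normal(size=(D, d)) * 0.1
B = rng.normal(size=(d, D)) * 0.1

# Central finite-difference check of one entry of each gradient block.
h = 1e-6
E_A = np.zeros_like(A); E_A[3, 1] = h
E_B = np.zeros_like(B); E_B[0, 2] = h
fd_A = (loss(A + E_A, B) - loss(A - E_A, B)) / (2 * h)
fd_B = (loss(A, B + E_B) - loss(A, B - E_B)) / (2 * h)
gA, gB = grad(A, B)

# Forward-Euler discretisation of dA/dt = -dL/dA, dB/dt = -dL/dB.
dt = 1e-4
losses = [loss(A, B)]
for _ in range(200):
    gA_step, gB_step = grad(A, B)
    A, B = A - dt * gA_step, B - dt * gB_step
    losses.append(loss(A, B))

print(fd_A, gA[3, 1])
print(fd_B, gB[0, 2])
print(losses[0], losses[-1])
```

Forward Euler with a small fixed step is just ordinary full-batch gradient descent; it is used here only because it is the simplest discretisation of (20) and (21).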
Those caveats having been noted, if we additionally add in the weight-tying constraint $A = B^T$ and write $W := B = A^T$ (a $d \times D$ matrix), we get:
$$\frac{d}{dt} W = W (\Sigma - W^T W \Sigma). \tag{22}$$
We can even make the substitution $W = \overline{W} U^T$ to introduce the form:
$$\frac{d}{dt} \overline{W} = \overline{W} (I - \overline{W}^T \overline{W}) \Lambda. \tag{23}$$
In components (and without summation convention) the equation reads
$$\frac{d}{dt} \bar w_{ij} = \sum_{k,l=1}^D \bar w_{ik} \Big( \delta_{kl} - \sum_{m=1}^d \bar w_{mk} \bar w_{ml} \Big) \delta_{lj} \lambda_j. \tag{24}$$
Let $\{\bar w_i\}_{i=1}^D$ denote the set of columns of $\overline{W}$ so that (24) becomes:
$$\frac{d}{dt} \bar w_j = \sum_{k,l=1}^D \big( \delta_{kl} - \langle \bar w_k, \bar w_l \rangle \big) \delta_{lj} \lambda_j \bar w_k. \tag{25}$$
Expanding the brackets and executing the sum over $l$ gives:
$$\frac{d}{dt} \bar w_j = \sum_{k=1}^D \delta_{kj} \lambda_j \bar w_k - \sum_{k=1}^D \langle \bar w_k, \bar w_j \rangle \lambda_j \bar w_k. \tag{26}$$
Then the sum over $k$ further simplifies the first term to give:
$$\frac{d}{dt} \bar w_j = \lambda_j \bar w_j - \sum_{k=1}^D \langle \bar w_k, \bar w_j \rangle \lambda_j \bar w_k. \tag{27}$$
Finally, just peel off the $k = j$ term from the remaining summation to arrive at the equation
$$\frac{d}{dt} \bar w_j = (1 - |\bar w_j|^2) \lambda_j \bar w_j - \sum_{k \ne j} \langle \bar w_k, \bar w_j \rangle \lambda_j \bar w_k. \tag{28}$$
Remarks. (cf. the previous two Remarks) If we assume that $\Sigma = \lambda I$ (in particular if $\Sigma = I$), then we may take $W = \overline{W}$ and the equation above arises as gradient descent on the energy functional
$$E = \frac{1}{4} \sum_{i=1}^D \lambda (1 - |w_i|^2)^2 + \frac{1}{2} \sum_{i<j} \lambda\, |\langle w_i, w_j \rangle|^2, \tag{29}$$
where the second sum counts each unordered pair of distinct indices once. It's plausible that a reasonable line of argument to justify this is that since no particular directions in the data are special, it means that over time, on average, the effects of different eigenvalues of $\Sigma$ just somehow 'average out'. But I don't endorse or understand how that argument would actually go. Regardless, if we just assume this for now, as is explained in the Anthropic paper, we can think of the two terms in (29) as being in competition. The first term suggests that the model 'wants' to learn the $k$th feature by arranging $|w_k| = 1$. However, as it tries to do so, it incurs a penalty - given by $\sum_{i \ne k} \lambda\, |\langle w_i, w_k \rangle|^2$ - that can reasonably be interpreted as the extent to which the hidden representation $w_k \in \mathbb{R}^d$ of that feature interferes with its attempts to represent and reconstruct the other features.
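The claim that the flow is gradient descent on the energy functional can be checked with finite differences. The sketch below makes two assumptions that should be read carefully: all eigenvalues are taken equal to a single `lam`, and the interference sum in the energy is taken over unordered pairs $i < j$ (with a sum over ordered pairs the prefactor $\tfrac12$ would have to become $\tfrac14$ for the gradient to match). The sizes and names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
d, D, lam = 3, 5, 1.0                         # illustrative sizes; lam is the common eigenvalue
W = rng.normal(size=(d, D)) * 0.5             # columns w_1, ..., w_D live in R^d

def energy(W):
    """(lam/4) sum_i (1 - |w_i|^2)^2 + (lam/2) sum_{i<j} <w_i, w_j>^2."""
    G = W.T @ W                               # Gram matrix of inner products <w_i, w_j>
    norms_term = 0.25 * lam * np.sum((1 - np.diag(G)) ** 2)
    upper = np.triu(G, k=1)                   # keeps each unordered pair i < j exactly once
    interference_term = 0.5 * lam * np.sum(upper ** 2)
    return float(norms_term + interference_term)

def flow_rhs(W):
    """Right-hand side of the tied flow when Lambda = lam * I: lam * W (I - W^T W)."""
    return lam * W @ (np.eye(D) - W.T @ W)

# Central finite-difference gradient of the energy, entry by entry.
h = 1e-6
num_grad = np.zeros_like(W)
for a in range(d):
    for j in range(D):
        E_aj = np.zeros_like(W)
        E_aj[a, j] = h
        num_grad[a, j] = (energy(W + E_aj) - energy(W - E_aj)) / (2 * h)

max_mismatch = float(np.max(np.abs(-num_grad - flow_rhs(W))))
print(max_mismatch)
```

If the two assumptions above hold, `max_mismatch` should be at the level of finite-difference error, i.e. the flow's right-hand side coincides with $-\nabla E$.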
3. The ReLU Output Model
3.1 The Distribution of the Data and the Integral Loss
Perhaps a better way to try to incorporate information about the distribution of the data into the analysis here is to directly let $\mu$ be the distribution (i.e. in the proper measure-theoretic sense) of $X^t$ on $\mathbb{R}^D$ and to consider
$$L = \int_{\mathbb{R}^D} \|x - W^T W x\|^2 \, d\mu(x).$$
In the Anthropic paper and in my own work, we are ultimately more interested in a model with biases and ReLUs at the output layer:
$$L = \int_{\mathbb{R}^D} \|x - \mathrm{ReLU}(W^T W x - b)\|^2 \, d\mu(x).$$
Performing an analysis anything like that done in Section 1 seems much harder for this model, but perhaps more progress can be made studying the integral above.
The synthetic data we described in the previous section is all contained in the cube $Q := [0,1]^D \subset \mathbb{R}^D$. In the sparse regime, i.e. with $S$ close to 1, the vast majority of the data is concentrated around the lower-dimensional skeletons of the cube. For $l = 0, 1, \dots, D$, if we write $Q_l$ for the set of points in the cube with exactly $l$ non-zero entries, i.e.
$$Q_l := \{x \in Q : \#\{j : x_j \ne 0\} = l\},$$
then $Q$ is the disjoint union
$$Q = \bigcup_{l=0}^D Q_l.$$
Without a closer analysis of binomial tail bounds I can't immediately tell how well-justified it is to, say, ignore $\bigcup_{l \ge 2} Q_l$ and focus the analysis just on 1-sparse vectors in the dataset, i.e. you might want to say that $\mu\big(\bigcup_{l \ge 2} Q_l\big)$ is sufficiently small that this region contributes only negligibly to the integral. Then you can start to work with more manageable expressions. To my mind this is another concrete potential 'jumping-off' point if one were to do more investigation. In particular, it points in the direction of the observations made in Toy Models of Superposition that suggest a link between this problem and the 'Thomson problem'.
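The mass that $\mu$ assigns to each $Q_l$ is easy to compute, since under the sampling scheme of Section 2.1 the number of non-zero coordinates is Binomial$(D, 1-S)$ (a uniform draw is non-zero with probability one). The sketch below is a minimal illustration; the function names and the example values of `D` and `S` are hypothetical.

```python
from math import comb

def mass_of_Q(l, D, S):
    """mu(Q_l): probability that exactly l of the D coordinates are non-zero."""
    return comb(D, l) * (1 - S) ** l * S ** (D - l)

def mass_at_least_two(D, S):
    """mu of the union of Q_l over l >= 2, i.e. 1 - mu(Q_0) - mu(Q_1)."""
    return 1.0 - mass_of_Q(0, D, S) - mass_of_Q(1, D, S)

D, S = 20, 0.99                               # illustrative values (assumptions)
masses = [mass_of_Q(l, D, S) for l in range(D + 1)]
tail = mass_at_least_two(D, S)

print(masses[:3])
print(tail)
```

When $D(1-S)$ is small the tail is roughly $(D(1-S))^2/2$, so how negligible it is depends on the product $D(1-S)$ and not on $S$ alone. A small mass also does not by itself bound the contribution to the integral, because the integrand can be larger on points with several active features.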