Anthropic's recent mechanistic interpretability paper, Toy Models of Superposition, helps to demonstrate the conceptual richness of very small feedforward neural networks. Even when such networks are trained on synthetic, hand-coded data to reconstruct a very straightforward function (the identity map), there appears to be non-trivial mathematics at play, and the analysis of these small networks seems to provide an interesting playground for mechanistic interpretability.

While trying to understand their work and train my own toy models, I ended up making various notes on the underlying mathematics. This post is a slightly neatened-up version of those notes, but is still quite rough and un-edited and is a far-from-optimal presentation of the material. In particular, these notes may contain errors, which are my responsibility.

1. Directly Analyzing the Critical Points of a Linear Toy Model

Throughout we will be considering feedforward neural networks with one hidden layer. The input and output layers will be of the same size and the hidden layer is smaller. We will only be considering the autoencoding problem, which means that our networks are being trained to reconstruct the data. The first couple of subsections here are largely taken from the Appendix to the paper "Neural networks and principal component analysis: Learning from examples without local minima." by Pierre Baldi and Kurt Hornik. (Neural networks 2.1 (1989): 53-58).

Consider to begin with a completely linear model, i.e. one without any activation functions or biases. Suppose the input and output layers have $n$ neurons and that the middle layer has $p$ neurons. This means that the function that the model is implementing is of the form $x \mapsto ABx$, where $B$ is a $p \times n$ matrix, and $A$ is an $n \times p$ matrix. That is, the matrix $B$ contains the weights of the connections between the input layer and the hidden layer, and the matrix $A$ contains the weights of the connections between the hidden layer and the output layer. It is important to realise that even though - for a given set of weights - the function that is being implemented here is linear, the mathematics of this model and the dynamics of the training are not completely linear.

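The error on a given input $x$ will be measured by $\|x - ABx\|^2$ and on the data set $\{x_1, \dots, x_T\}$, the total loss is

$$E(A, B) = \sum_{t=1}^{T} \|x_t - ABx_t\|^2.$$

Define $\Sigma$ to be the $n \times n$ matrix whose $(i,j)$ entry $\sigma_{ij}$ is given by

$$\sigma_{ij} = \sum_{t=1}^{T} (x_t)_i \, (x_t)_j, \qquad \text{i.e.} \qquad \Sigma = \sum_{t=1}^{T} x_t x_t^T.$$

Clearly this matrix is symmetric.

Assumption. We will assume that the data is such that a) $\Sigma$ is invertible and b) $\Sigma$ has distinct eigenvalues.

Let $\lambda_1 > \lambda_2 > \dots > \lambda_n > 0$ be the eigenvalues of $\Sigma$.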

 

1.1 The Global Minimum


Proposition 1. (Characterization of Critical Points) Fix the dataset and consider $E$ to be a function of the two matrix variables $A$ and $B$. For any critical point $(A, B)$ of $E$, there is a subset $\mathcal{I} \subset \{1, \dots, n\}$ of size $p$ for which

  1. $AB$ is an orthogonal projection onto a $p$-dimensional subspace spanned by orthonormal eigenvectors of $\Sigma$ corresponding to the eigenvalues $\{\lambda_i\}_{i \in \mathcal{I}}$; and
  2. $E(A, B) = \sum_{i \notin \mathcal{I}} \lambda_i$.

Corollary 2. (Characterization of the Minimum) The loss has a unique minimum value that is attained when $\mathcal{I} = \{1, \dots, p\}$, which corresponds to the situation when $AB$ is an orthogonal projection onto the $p$-dimensional subspace spanned by the eigendirections of $\Sigma$ that have the largest eigenvalues.

Remarks. We won't try to spell out all of the various connections to other closely related things, but for those who want some more keywords to go away and investigate further, we just remark that the minimization problem being studied here is about finding a low-rank approximation to the identity and is closely related to Principal Component Analysis. See also the Eckart–Young–Mirsky Theorem.
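The statement of Corollary 2 can be sanity-checked numerically. The following is a minimal sketch, not part of the original notes: the Gaussian data, the per-coordinate scales and the helper name `loss` are all assumptions made for illustration. It compares the loss at the projection onto the top eigendirections of the matrix of second moments with the sum of the discarded eigenvalues, and with a few random rank-limited competitors.

```python
import numpy as np

rng = np.random.default_rng(0)
n, p, T = 6, 2, 500

# Hypothetical data: Gaussian with a different scale per coordinate, so that
# Sigma is invertible with distinct eigenvalues (the standing Assumption).
scales = np.array([3.0, 2.5, 2.0, 1.5, 1.0, 0.5])
X = rng.normal(size=(T, n)) * scales        # rows are the data points x_t
Sigma = X.T @ X                             # Sigma = sum_t x_t x_t^T


def loss(A, B):
    """Total squared reconstruction error sum_t ||x_t - A B x_t||^2."""
    residual = X.T - A @ B @ X.T
    return float(np.sum(residual ** 2))


# Eigen-decomposition, sorted so that lam[0] > lam[1] > ... > lam[n-1].
eigvals, eigvecs = np.linalg.eigh(Sigma)
order = np.argsort(eigvals)[::-1]
lam, U = eigvals[order], eigvecs[:, order]

# Candidate minimiser: project onto the top-p eigendirections.
A_opt = U[:, :p]
B_opt = A_opt.T
min_loss = loss(A_opt, B_opt)
predicted = float(lam[p:].sum())            # sum of the discarded eigenvalues

# Random competitors: a random A, with B taken to be its pseudo-inverse.
random_losses = []
for _ in range(200):
    A = rng.normal(size=(n, p))
    random_losses.append(loss(A, np.linalg.pinv(A)))

print(f"loss at PCA projection : {min_loss:.4f}")
print(f"sum of discarded eigs  : {predicted:.4f}")
print(f"best random competitor : {min(random_losses):.4f}")
```

The random competitors are only a spot check; the actual guarantee is the one given by the proposition (or by the Eckart–Young–Mirsky Theorem).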

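We begin by directly differentiating $E$ with respect to the entries of $A = (a_{ij})$ and $B = (b_{ij})$. Using summation convention on repeated indices, we first take the derivative with respect to $b_{ij}$:

$$\frac{\partial E}{\partial b_{ij}} = -2\, a_{ki} \left( \delta_{kl} - a_{km} b_{ml} \right) \sigma_{lj}. \tag{1}$$

Setting this equal to zero and interpreting this equation for all $i$ and $j$ gives us that

$$A^T (I_n - AB)\, \Sigma = 0.$$

Then, separately, we differentiate $E$ with respect to $a_{ij}$:

$$\frac{\partial E}{\partial a_{ij}} = -2 \left( \delta_{ik} - a_{im} b_{mk} \right) \sigma_{kl}\, b_{jl}. \tag{2}$$

Setting this equation equal to zero for every $i$ and $j$ we have that:

$$(I_n - AB)\, \Sigma B^T = 0.$$

Thus

$$A^T A B \Sigma = A^T \Sigma \qquad \text{and} \qquad AB \Sigma B^T = \Sigma B^T. \tag{3}$$

Since we have assumed that $\Sigma$ is invertible, the first equation immediately implies that $A^T A B = A^T$. If we assume in addition that $A$ has full rank (a reasonable assumption in any case of practical interest), then $A^T A$ is invertible and we have that

$$B = (A^T A)^{-1} A^T, \tag{4}$$

which in turn implies that

$$AB = A (A^T A)^{-1} A^T = P_A, \tag{5}$$

where we have written $P_A$ to denote the orthogonal projection on to the column space of $A$.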
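As a quick illustration of this step, the sketch below (with a randomly chosen, hypothetical full-rank encoder-side matrix) checks that choosing the second matrix as the left pseudo-inverse of the first makes their product symmetric, idempotent, and the identity on the column space, i.e. an orthogonal projection of rank equal to the hidden dimension.

```python
import numpy as np

rng = np.random.default_rng(1)
n, p = 5, 2
A = rng.normal(size=(n, p))                 # full column rank almost surely

B = np.linalg.inv(A.T @ A) @ A.T            # B = (A^T A)^{-1} A^T
P = A @ B                                   # candidate orthogonal projection

is_symmetric = np.allclose(P, P.T)          # P^T = P
is_idempotent = np.allclose(P @ P, P)       # P^2 = P
fixes_columns = np.allclose(P @ A, A)       # P acts as the identity on col(A)
rank = int(round(np.trace(P)))              # trace of a projection = its rank

print(is_symmetric, is_idempotent, fixes_columns, rank)
```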

Claim. We next claim that $P_A$ commutes with $\Sigma$.

Proof of claim. Plugging (5) into (3), we have:

$$\Sigma B^T = P_A \Sigma B^T. \tag{6}$$

Then, right-multiply by $A^T$ and use the fact that $B^T A^T = (AB)^T = P_A^T = P_A$ to get:

$$\Sigma P_A = P_A \Sigma P_A. \tag{7}$$

The right-hand side is manifestly a symmetric matrix, so we deduce that $\Sigma P_A$ is symmetric. If the product of two symmetric matrices is symmetric then they commute, so this indeed shows that $P_A$ commutes with $\Sigma$ and completes the proof of the claim.

 

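Now let $U$ be the orthogonal matrix which diagonalizes $\Sigma$, i.e. the matrix for which

$$\Sigma = U \Lambda U^T, \tag{8}$$

where $\Lambda$ is a diagonal matrix with entries $\lambda_1 > \lambda_2 > \dots > \lambda_n$.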

 

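Claim. We next claim that $P_{U^T A} = U^T P_A U$ and that $P_{U^T A}$ is diagonal.

Proof of Claim. Firstly, using the standard formula for orthogonal projections, we have

$$P_{U^T A} = U^T A \left( A^T U U^T A \right)^{-1} A^T U = U^T A (A^T A)^{-1} A^T U = U^T P_A U,$$

which implies that

$$P_A = U P_{U^T A} U^T. \tag{9}$$

To show that $P_{U^T A}$ is diagonal, we show that it commutes with the diagonal matrix $\Lambda$ (any matrix that commutes with a diagonal matrix whose diagonal entries are distinct must itself be diagonal; this is where part b) of the Assumption is used). Starting from $P_{U^T A} \Lambda$, we first insert the identity matrix in the form $U^T U$, and then use (8) and (9) thus:

$$P_{U^T A} \Lambda = P_{U^T A} \Lambda U^T U = U^T P_A U \Lambda U^T U = U^T P_A \Sigma U.$$

Then recall that we have already established that $P_A$ commutes with $\Sigma$. So we can swap them and then perform the same trick in reverse:

$$U^T P_A \Sigma U = U^T \Sigma P_A U = U^T \Sigma U U^T P_A U = \Lambda P_{U^T A}.$$

This shows that $P_{U^T A}$ commutes with $\Lambda$ and completes the proof of the claim.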

 

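So, given that $P_{U^T A}$ is an orthogonal projection of rank $p$ and is diagonal, there exists a set of indices $\mathcal{I} \subset \{1, \dots, n\}$ with $|\mathcal{I}| = p$ such that the $(i,i)$ entry of $P_{U^T A}$ is zero if $i \notin \mathcal{I}$ and 1 if $i \in \mathcal{I}$. And since $P_A = U P_{U^T A} U^T$, we see that

$$P_A = U_{\mathcal{I}} U_{\mathcal{I}}^T, \tag{10}$$

where $U_{\mathcal{I}}$ is formed from $U$ by simply setting to zero the $j$th column if $j \notin \mathcal{I}$. This is manifestly an orthogonal projection onto the span of $\{u_i\}_{i \in \mathcal{I}}$, where $u_1, \dots, u_n$ is an orthonormal basis of eigenvectors of $\Sigma$ (and indeed the columns of $U$). Combining these observations with (5), we have that

$$AB = U_{\mathcal{I}} U_{\mathcal{I}}^T. \tag{11}$$

This proves the first claim of the proposition.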

 

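To prove the second part, write $\|v\|^2 = \operatorname{tr}(v v^T)$ and compute thus:

$$E(A, B) = \sum_{t=1}^{T} \operatorname{tr}\!\left( (I_n - AB)\, x_t x_t^T (I_n - AB)^T \right) = \operatorname{tr}\Sigma - 2 \operatorname{tr}(AB\Sigma) + \operatorname{tr}(AB \Sigma B^T A^T). \tag{12}$$

But we know from (7) and (11) that $AB \Sigma B^T A^T = P_A \Sigma P_A = \Sigma P_A$ and so this last line is actually just equal to

$$\operatorname{tr}\Sigma - 2 \operatorname{tr}(AB\Sigma) + \operatorname{tr}(\Sigma AB) = \operatorname{tr}\Sigma - \operatorname{tr}(AB\Sigma).$$

Focussing on the second term and using (11), then (9) and (8), then cancelling $U^T U$, and then - to reach the last line - cyclically permuting the matrices inside the trace operator to produce another $U^T U$ cancellation, we have:

$$\operatorname{tr}(AB\Sigma) = \operatorname{tr}(P_A \Sigma) = \operatorname{tr}\!\left( U P_{U^T A} U^T U \Lambda U^T \right) = \operatorname{tr}\!\left( U P_{U^T A} \Lambda U^T \right) = \operatorname{tr}\!\left( P_{U^T A} \Lambda \right).$$

The diagonal form of $P_{U^T A}$ means that this final expression is equal to $\sum_{i \in \mathcal{I}} \lambda_i$, meaning that

$$E(A, B) = \operatorname{tr}\Sigma - \sum_{i \in \mathcal{I}} \lambda_i.$$

Since $\operatorname{tr}\Sigma = \sum_{i=1}^{n} \lambda_i$ (the trace is always equal to the sum of the eigenvalues), this completes the proof of the proposition.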
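The proposition can also be checked numerically on a small example. The sketch below is an illustration under assumptions of my choosing (hypothetical Gaussian data, and the simplest representative of each critical point, namely a decoder whose columns are a chosen set of eigenvectors and an encoder equal to its transpose). It uses the matrix form of the gradients, $-2(I_n - AB)\Sigma B^T$ and $-2A^T(I_n - AB)\Sigma$, and loops over every index set of the right size.

```python
import itertools
import numpy as np

rng = np.random.default_rng(2)
n, p, T = 5, 2, 400
X = rng.normal(size=(T, n)) * np.array([2.5, 2.0, 1.5, 1.0, 0.5])
Sigma = X.T @ X

eigvals, eigvecs = np.linalg.eigh(Sigma)
order = np.argsort(eigvals)[::-1]           # lam[0] > lam[1] > ...
lam, U = eigvals[order], eigvecs[:, order]
identity = np.eye(n)


def loss(A, B):
    """E(A, B) = tr((I - AB) Sigma (I - AB)^T)."""
    M = identity - A @ B
    return float(np.trace(M @ Sigma @ M.T))


def gradients(A, B):
    """Matrix form of dE/dA and dE/dB."""
    M = identity - A @ B
    return -2 * M @ Sigma @ B.T, -2 * A.T @ M @ Sigma


results = {}
for index_set in itertools.combinations(range(n), p):
    A = U[:, list(index_set)]               # orthonormal eigenvectors as columns
    B = A.T
    grad_A, grad_B = gradients(A, B)
    grad_norm = max(np.abs(grad_A).max(), np.abs(grad_B).max())
    excluded = [i for i in range(n) if i not in index_set]
    results[index_set] = (grad_norm, loss(A, B), float(lam[excluded].sum()))

for index_set, (g, observed, predicted) in results.items():
    print(index_set, f"grad={g:.1e}", f"loss={observed:.3f}", f"predicted={predicted:.3f}")
```

Indices in the code are 0-based, so the index set `(0, 1)` corresponds to the two largest eigenvalues.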

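Remarks. Equation (10) above tells us that $P_A = P_{U_{\mathcal{I}}}$, which means that there exists an invertible matrix $C$ with $A = U_{\mathcal{I}} C$. (Here we regard $U_{\mathcal{I}}$ as an $n \times p$ matrix by discarding its zero columns, so that $U_{\mathcal{I}}^T U_{\mathcal{I}} = I_p$ and $C$ is $p \times p$.) Then, using (4), we compute that

$$B = (A^T A)^{-1} A^T = \left( C^T U_{\mathcal{I}}^T U_{\mathcal{I}} C \right)^{-1} C^T U_{\mathcal{I}}^T = C^{-1} (C^T)^{-1} C^T U_{\mathcal{I}}^T = C^{-1} U_{\mathcal{I}}^T.$$

So we have:

$$A = U_{\mathcal{I}} C, \tag{13}$$

$$B = C^{-1} U_{\mathcal{I}}^T. \tag{14}$$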

1.2 Characterizing Other Critical Points

This subsection is something of an aside, but it is included for completeness.

Proposition 3. (Other Critical Points are Saddle Points.) Fix the dataset and consider $E$ to be a function of the two matrix variables $A$ and $B$. Every other critical point is a saddle point, i.e. if $(A, B)$ is a critical point but not equal to the unique minimum, then there exist $\tilde{A}$ and $\tilde{B}$ which are arbitrarily close to $A$ and $B$ respectively and at which a lower loss is achieved.

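Proof. Since $(A, B)$ is not the unique global minimum, we know from Corollary 2 that $\mathcal{I} \neq \{1, \dots, p\}$. This means that there are distinct indices $j$ and $k$ for which $j \in \mathcal{I}$ and $k \notin \mathcal{I}$, with $k < j$. In particular, bear in mind that

$$\lambda_k > \lambda_j.$$

Now, given any $\varepsilon > 0$, put

$$\tilde{u}_j = (1 + \varepsilon^2)^{-1/2} \left( u_j + \varepsilon u_k \right).$$

And let us form the new matrix $\tilde{U}_{\mathcal{I}}$ by starting with $U_{\mathcal{I}}$ and replacing the column $u_j$ with $\tilde{u}_j$. Write

$$\tilde{A} = \tilde{U}_{\mathcal{I}} C, \qquad \tilde{B} = C^{-1} \tilde{U}_{\mathcal{I}}^T.$$

We want to calculate the loss of the model at $(\tilde{A}, \tilde{B})$. We ought to bear in mind that it is not a critical point, so we cannot assume the intermediate results in the proof of Proposition 1, but it turns out that the bits that are most useful for this computation rely only on algebra and (13), (14). We start from the equivalent of line (11) which is that $\tilde{A}\tilde{B} = \tilde{U}_{\mathcal{I}} \tilde{U}_{\mathcal{I}}^T =: \tilde{P}$ which implies that $\tilde{A}\tilde{B} \Sigma \tilde{B}^T \tilde{A}^T = \tilde{P} \Sigma \tilde{P}$. And so just as in (12) above, we have

$$E(\tilde{A}, \tilde{B}) = \operatorname{tr}\Sigma - 2 \operatorname{tr}(\tilde{P} \Sigma) + \operatorname{tr}(\tilde{P} \Sigma \tilde{P}). \tag{15}$$

Now, looking at the final term on the right-hand side, we have $\tilde{P} \Sigma \tilde{P} = \tilde{U}_{\mathcal{I}} \tilde{U}_{\mathcal{I}}^T \Sigma\, \tilde{U}_{\mathcal{I}} \tilde{U}_{\mathcal{I}}^T$ and (by cyclic permutation) $\operatorname{tr}(\tilde{U}_{\mathcal{I}} \tilde{U}_{\mathcal{I}}^T \Sigma\, \tilde{U}_{\mathcal{I}} \tilde{U}_{\mathcal{I}}^T) = \operatorname{tr}(\tilde{U}_{\mathcal{I}}^T \Sigma\, \tilde{U}_{\mathcal{I}} \tilde{U}_{\mathcal{I}}^T \tilde{U}_{\mathcal{I}})$. And since the non-zero columns of $\tilde{U}_{\mathcal{I}}$ are orthonormal (because $\tilde{u}_j$ is a unit vector and $k \notin \mathcal{I}$), so that

$$\tilde{U}_{\mathcal{I}} \tilde{U}_{\mathcal{I}}^T \tilde{U}_{\mathcal{I}} = \tilde{U}_{\mathcal{I}}, \tag{16}$$

we have:

$$\operatorname{tr}(\tilde{P} \Sigma \tilde{P}) = \operatorname{tr}(\tilde{U}_{\mathcal{I}}^T \Sigma\, \tilde{U}_{\mathcal{I}}).$$

We also use $\tilde{P} = \tilde{U}_{\mathcal{I}} \tilde{U}_{\mathcal{I}}^T$ and cyclic permutation on the second term on the right-hand side of (15) to ultimately arrive at:

$$E(\tilde{A}, \tilde{B}) = \operatorname{tr}\Sigma - \operatorname{tr}(\tilde{U}_{\mathcal{I}}^T \Sigma\, \tilde{U}_{\mathcal{I}}).$$

So we are interested in computing the diagonal elements of $\tilde{U}_{\mathcal{I}}^T \Sigma\, \tilde{U}_{\mathcal{I}}$. Fix $i \in \mathcal{I}$. The $i$th diagonal entry is given by:

$$\tilde{u}_i^T \Sigma\, \tilde{u}_i, \qquad \text{where } \tilde{u}_i = u_i \text{ for } i \neq j.$$

This can be computed directly from the definition of $\tilde{u}_j$ to give that the $i$th entry on the diagonal is equal to

$$\lambda_i \ \text{ if } i \neq j, \qquad \frac{\lambda_j + \varepsilon^2 \lambda_k}{1 + \varepsilon^2} \ \text{ if } i = j.$$

Therefore

$$E(\tilde{A}, \tilde{B}) = \operatorname{tr}\Sigma - \sum_{i \in \mathcal{I},\, i \neq j} \lambda_i - \frac{\lambda_j + \varepsilon^2 \lambda_k}{1 + \varepsilon^2} = E(A, B) - \frac{\varepsilon^2 (\lambda_k - \lambda_j)}{1 + \varepsilon^2}.$$

Since $\lambda_k > \lambda_j$ this shows that in an arbitrarily small neighbourhood of the critical point $(A, B)$ we can find a point $(\tilde{A}, \tilde{B})$ where smaller loss is achieved. We will not bother doing so here, but one can also check that $(A, B)$ is not a local maximum by using the fact that for fixed (full rank) $A$, the function $B \mapsto E(A, B)$ is convex.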
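The perturbation used in this proof is easy to try out numerically. The following sketch is only an illustration: the data, the particular non-minimal index set, and the simplification of taking the invertible change-of-basis matrix to be the identity are all assumptions made here. It tilts one retained eigenvector slightly towards a discarded eigenvector with a larger eigenvalue and compares the observed change in loss with $-\varepsilon^2(\lambda_k - \lambda_j)/(1 + \varepsilon^2)$.

```python
import numpy as np

rng = np.random.default_rng(3)
n, p, T = 5, 2, 400
X = rng.normal(size=(T, n)) * np.array([2.5, 2.0, 1.5, 1.0, 0.5])
Sigma = X.T @ X

eigvals, eigvecs = np.linalg.eigh(Sigma)
order = np.argsort(eigvals)[::-1]           # lam[0] > lam[1] > ...
lam, U = eigvals[order], eigvecs[:, order]


def loss(A, B):
    M = np.eye(n) - A @ B
    return float(np.trace(M @ Sigma @ M.T))


# A non-minimal critical point (0-based indices): keep eigenvectors 0 and 3.
# j = 3 is retained, k = 1 is discarded, and lam[k] > lam[j].
index_set, j, k = [0, 3], 3, 1
A = U[:, index_set]
B = A.T
base_loss = loss(A, B)

changes = {}
for eps in [1e-1, 1e-2, 1e-3]:
    u_tilde = (U[:, j] + eps * U[:, k]) / np.sqrt(1 + eps ** 2)
    A_tilde = A.copy()
    A_tilde[:, index_set.index(j)] = u_tilde    # replace column u_j
    B_tilde = A_tilde.T
    observed = loss(A_tilde, B_tilde) - base_loss
    predicted = -eps ** 2 * (lam[k] - lam[j]) / (1 + eps ** 2)
    changes[eps] = (observed, predicted)
    print(f"eps={eps:g}: observed {observed:.6e}, predicted {predicted:.6e}")
```

The loss decreases for every perturbation size tried, while the perturbed matrices approach the critical point as the perturbation shrinks, which is the saddle behaviour described in the proposition.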

 

2. Sparse Data, Weight Tying, and Gradients

Abstractly analyzing critical points is not at all the same as training real models. In this section we start to think about data and the optimization process.

2.1 Sparse Synthetic Data

Here we describe the kind of training data used in Anthropic's toy experiments.

Fix a number $S \in [0, 1)$. This parameter is the sparsity of the data. We will typically be most interested in the case where $S$ is close to 1.

Let $(Z_i^t)$ be an independent and identically distributed family of Bernoulli random variables with parameter $1 - S$. And let $(V_i^t)$ be an IID family of $\mathrm{Uniform}[0,1]$ random variables, independent of the first family. Write $X_i^t = Z_i^t V_i^t$ and $X^t = (X_1^t, \dots, X_n^t)$. Our datasets $\{x_1, \dots, x_T\}$ will be drawn from the IID family $(X^t)_{t=1}^{T}$. Notice that

  • Independently, for each data point $x_t$ and for each $i$, we will have $(x_t)_i = 0$ with probability $S$.
  • So, the expected number of non-zero entries for each data point is $n(1 - S)$. To bring this in line with the way people say things like "$k$-sparse", we can say that the data is, on average, $n(1-S)$-sparse.
  • .
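A minimal sketch of a sampler for data of this kind is given below. The function name `sample_sparse_data` and the specific parameter values are hypothetical; the sampling rule (each coordinate is zero with probability equal to the sparsity, and otherwise uniform on the unit interval) is my reading of the description above and of the Anthropic setup.

```python
import numpy as np


def sample_sparse_data(num_points, n, sparsity, rng):
    """Return an array of shape (num_points, n).

    Each entry is 0 with probability `sparsity` and otherwise drawn
    uniformly from [0, 1), independently of all other entries.
    """
    mask = rng.random((num_points, n)) >= sparsity   # True with prob. 1 - S
    values = rng.random((num_points, n))             # uniform on [0, 1)
    return mask * values


rng = np.random.default_rng(4)
S, n, T = 0.9, 20, 50_000
X = sample_sparse_data(T, n, S, rng)

zero_fraction = float(np.mean(X == 0.0))
mean_nonzeros = float(np.mean(np.count_nonzero(X, axis=1)))

print(f"fraction of zero entries   : {zero_fraction:.4f} (sparsity S = {S})")
print(f"mean non-zeros per point   : {mean_nonzeros:.3f} (n(1-S) = {n * (1 - S):.3f})")
```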

Remark. Judging from some of the existing literature on the linear model that we analyze in Section 1 (e.g. "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks." by Andrew M. Saxe, James L. McClelland and Surya Ganguli), it seems like it's tempting to make an assumption/simplification/approximation that $\Sigma = I_n$. I still don't feel like I understand how justifiable that is - for me this question is a potential 'jumping-off' point for further analysis of the whole problem. Recall that the matrix $\Sigma$ is equal to $\sum_t x_t x_t^T$. Certainly the probability that an off-diagonal entry of a single summand $x_t x_t^T$ is equal to zero is $1 - (1-S)^2$ whereas for the diagonal entries it is just $S$. And note that  if  and .   But the diagonal entries are still independent and I'm not sure why thinking of them as equal makes sense.
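One way to get a feel for this question is to look at the empirical second moments directly. The sketch below is a sanity check rather than a justification, and it assumes the sampling rule described earlier (zero with probability equal to the sparsity, otherwise uniform on the unit interval). Under that assumption the expected values are $\mathbb{E}[x_i^2] = (1-S)/3$ on the diagonal and $\mathbb{E}[x_i x_j] = (1-S)^2/4$ off the diagonal, so the ratio of off-diagonal to diagonal entries is about $3(1-S)/4$, which shrinks as the data gets sparser.

```python
import numpy as np

rng = np.random.default_rng(5)
n, T = 8, 200_000


def empirical_second_moment(S):
    """Return (1/T) * sum_t x_t x_t^T for freshly sampled sparse data."""
    mask = rng.random((T, n)) >= S
    X = mask * rng.random((T, n))
    return X.T @ X / T


summary = {}
for S in [0.5, 0.9, 0.99]:
    M = empirical_second_moment(S)
    diag_mean = float(np.mean(np.diag(M)))
    off_mean = float((M.sum() - np.trace(M)) / (n * (n - 1)))
    summary[S] = (diag_mean, off_mean, (1 - S) / 3, (1 - S) ** 2 / 4)
    print(f"S={S}: diag {diag_mean:.5f} (expected {(1 - S) / 3:.5f}), "
          f"off-diag {off_mean:.6f} (expected {(1 - S) ** 2 / 4:.6f})")
```

So, for this synthetic data, the normalized matrix looks more and more like a multiple of the identity as the sparsity approaches 1, although for any finite dataset the individual diagonal entries still fluctuate around their common expected value.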

The data (and the loss) are meant to model two main ideas: Firstly, that the coordinate directions of the input space act as a natural set of features for the data. And secondly, when $S$ is close to 1, the sparsity of the data is supposed to capture the fact that features really do often tend to be sparse in real-world data, i.e. we see that for any given object or any given word/idea that appears in a language, it is the case that most images don't contain that object and most sentences don't contain that word or idea.

2.2 Weight Tying and The Gradient Flow

In practice, when we train an autoencoder like this, we do so with weight tying. Roughly speaking, this means that we only consider the case where $B = A^T$. Proposition 1 does indeed allow for a global minimum in which $B = A^T$: This is achieved by essentially taking $C = I_p$ in the equations at the end of Section 1.1, i.e. we have:

$$A = U_{\mathcal{I}}, \qquad B = U_{\mathcal{I}}^T, \qquad \mathcal{I} = \{1, \dots, p\}.$$

But note that we don't actually want to try to repeat the analysis of Section 1 on a loss of the form $\sum_t \|x_t - A A^T x_t\|^2$. This would be a higher-order polynomial function of the entries of $A$ and so it's genuinely a different and potentially more complicated functional. The way that weight-tying is done in practice is more similar to saying that we insist during training that updates are made that preserve the equality $B = A^T$.

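Equations (1) and (2) in subsection 1.1 are obtained as a direct result of differentiating the loss with respect to individual entries of the matrices (or individual 'weights' if we interpret this model as a feedforward neural network without activations). Our computations show that:

$$\frac{\partial E}{\partial B} = -2 A^T (I_n - AB)\, \Sigma, \qquad \frac{\partial E}{\partial A} = -2 (I_n - AB)\, \Sigma B^T.$$

In an appropriate continuous time limit, if we set the learning rate to 1, the weights during training evolve according to the differential equations:

$$\frac{dB}{dt} = 2 A^T (I_n - AB)\, \Sigma, \qquad \frac{dA}{dt} = 2 (I_n - AB)\, \Sigma B^T.$$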

Remarks. Notice that there is a certain deliberate sloppiness here: One doesn't really have a fixed matrix $\Sigma$ and then run this gradient flow for all time; the matrix $\Sigma$ is a function of (a batch of) training data. So we need to be careful about any further manipulations or interpretations of these equations.

Those caveats having been noted, if we additionally add in the weight-tying constraint $B = A^T$, we get:

We can even make the substitution  to introduce the form:

 In components (and without summation convention) the equation reads 

Let $W_1, \dots, W_n$ denote the set of columns of $W$ so that (24) becomes:

Expanding the brackets and executing the sum over  gives:

Then the sum over  further simplifies the first term to give: 

Finally, just peel off the  term from the remaining summation to arrive at the equation 

Remarks. (cf. the previous two Remarks)  If we assume that $\Sigma = I_n$, then $\sigma_{ij} = \delta_{ij}$ and the equation above arises as gradient descent on the energy functional 

It's plausible that a reasonable line of argument to justify this is that since no particular directions in the data are special, it means that over time, on average, the effects of different eigenvalues of $\Sigma$ just somehow 'average out'. But I don't endorse or understand how that argument would actually go. Regardless, if we just assume this for now, as is explained in the Anthropic paper, we can think of the two terms in (28) as being in competition. The first term suggests that the model 'wants' to learn the $i$th feature by arranging $\|W_i\| = 1$. However, as it tries to do so, it incurs a penalty - given by $\sum_{j \neq i} (W_i \cdot W_j)^2$ - that can reasonably be interpreted as the extent to which the hidden representation $W_i$ of that feature interferes with its attempts to represent and reconstruct the other features.
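To see this competition play out, here is a small sketch under explicit assumptions that go beyond the text: the matrix of second moments is taken to be the identity, the tied model is $x \mapsto W^T W x$ with $W$ a $p \times n$ matrix whose columns are the hidden representations of the features, and the energy is taken to be $\|I_n - W^T W\|_F^2$ (the functional in the text may differ from this by constant factors). This energy splits exactly into a 'feature benefit' part $\sum_i (1 - \|W_i\|^2)^2$ and an 'interference' part $\sum_{i \neq j} (W_i \cdot W_j)^2$, and plain gradient descent on it is run below.

```python
import numpy as np

rng = np.random.default_rng(6)
n, p = 5, 2
W = 0.1 * rng.normal(size=(p, n))       # column W[:, i] represents feature i


def energy(W):
    """||I - W^T W||_F^2, the assumed stand-in energy functional."""
    gram = W.T @ W                       # gram[i, j] = W_i . W_j
    return float(np.sum((np.eye(n) - gram) ** 2))


def energy_by_terms(W):
    """The same quantity, split into benefit and interference terms."""
    gram = W.T @ W
    benefit = np.sum((1 - np.diag(gram)) ** 2)
    interference = np.sum(gram ** 2) - np.sum(np.diag(gram) ** 2)
    return float(benefit + interference)


step = 0.01
history = [energy(W)]
for _ in range(5000):
    grad = -4 * W @ (np.eye(n) - W.T @ W)    # gradient of the energy in W
    W = W - step * grad
    history.append(energy(W))

print(f"initial energy: {history[0]:.4f}")
print(f"final energy  : {history[-1]:.4f} (n - p = {n - p})")
print("squared norms of the columns:", np.round(np.sum(W ** 2, axis=0), 3))
```

With fewer hidden dimensions than features, the energy cannot reach zero: the Gram matrix has rank at most $p$, so the best achievable value for this particular functional is $n - p$, and the columns end up with norms below 1 and non-zero overlaps.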

3. The ReLU Output Model 

3.1 The Distribution of the Data and the Integral Loss

Perhaps a better way to try to incorporate information about the distribution of the data into the analysis here is to directly let $\mu$ be the distribution (i.e. in the proper measure-theoretic sense) of the data on $\mathbb{R}^n$ and to consider

In the Anthropic paper and in my own work, we are ultimately more interested in a model with biases and ReLUs at the output layer. 

Performing an analysis anything like that done in Section 1 seems much harder for this model, but perhaps more progress can be made studying the integral above. 

The synthetic data we described in the previous section is all contained in the cube $[0,1]^n$. In the sparse regime, i.e. with $S$ close to 1, the vast majority of the data is concentrated around the lower-dimensional skeletons of the cube. For $k = 0, 1, \dots, n$, if we write $Q_k$ for the set of points in the cube with exactly $k$ non-zero entries, i.e.

$$Q_k = \left\{ x \in [0,1]^n : \#\{ i : x_i \neq 0 \} = k \right\},$$

then $[0,1]^n$ is the disjoint union $\bigcup_{k=0}^{n} Q_k$.

Without a closer analysis of binomial tail bounds I can't immediately tell how well-justified it is to, say, ignore $\bigcup_{k \geq 2} Q_k$ and focus the analysis just on 1-sparse vectors in the dataset, i.e. you might want to say that the probability of $\bigcup_{k \geq 2} Q_k$ is sufficiently small that that region contributes only negligibly to the integral. Then you can start to work with more manageable expressions. To my mind this is another concrete potential 'jumping-off' point if one were to do more investigation. In particular, it is in the direction of the observations made in Toy Models of Superposition that suggest a link between this problem and the 'Thomson Problem'.
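As a first rough look at the sizes involved, the number of non-zero entries of a data point is binomially distributed under the sampling described in Section 2.1 (assuming each coordinate is independently zero with probability equal to the sparsity). The sketch below, with a hypothetical helper `skeleton_mass` and an arbitrary choice of dimension, computes the probability that a point has exactly zero, exactly one, or at least two non-zero entries.

```python
from math import comb


def skeleton_mass(n, S, k):
    """Probability that a data point has exactly k non-zero entries."""
    return comb(n, k) * (1 - S) ** k * S ** (n - k)


n = 20
table = {}
for S in [0.9, 0.99, 0.999]:
    masses = [skeleton_mass(n, S, k) for k in range(n + 1)]
    table[S] = (masses[0], masses[1], sum(masses[2:]), sum(masses))
    print(f"S={S}: P(k=0)={masses[0]:.4f}, P(k=1)={masses[1]:.4f}, "
          f"P(k>=2)={sum(masses[2:]):.6f}")
```

These are only probabilities of the regions. How much each region contributes to the integral also depends on the size of the integrand there, so this does not by itself settle whether the higher-dimensional skeletons can be ignored.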


Indeed, the integrals in the sparse case aren't so bad: https://arxiv.org/abs/2310.06301. I don't think the analogy to the Thomson problem is correct; it's similar but qualitatively different (there is a large literature on tight frames that is arguably more relevant).