A big bottleneck in interpretability is that neural networks are non-local. That is, given the layer setup

if we change a small bit of the original activations, then a large bit of the new activations are affected.

This is an impediment to finding the circuit-structure of networks. It is difficult to figure out how something works when changing one thing affects everything.

The project I'm currently working on aims to fix this issue, without affecting the training dynamics of networks or the function which the network is implementing[1]. The idea is to find a rotation matrix and insert it with its inverse like below, then group together the rotation with the original activations, and the inverse with the weights and nonlinear function.
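A minimal NumPy sketch of why inserting a rotation and its inverse leaves the function unchanged. Everything here is an illustrative assumption rather than the project's code: the names `W`, `x`, `R`, the ReLU nonlinearity, the layer sizes, and the helper `random_rotation` are all hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

def relu(z):
    return np.maximum(z, 0.0)

def random_rotation(n, rng):
    # QR of a Gaussian matrix gives an orthogonal Q.
    # Flip one column if needed so that det(Q) = +1 (a proper rotation).
    q, _ = np.linalg.qr(rng.normal(size=(n, n)))
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q

n_in, n_out = 5, 4
W = rng.normal(size=(n_out, n_in))   # hypothetical layer weights
x = rng.normal(size=n_in)            # hypothetical input activations
R = random_rotation(n_in, rng)

original = relu(W @ x)

x_rot = R @ x       # rotation grouped with the activations
W_rot = W @ R.T     # inverse (R^-1 = R^T) grouped with the weights
rotated = relu(W_rot @ x_rot)

function_preserved = np.allclose(original, rotated)
```

Because `W_rot @ x_rot` equals `W @ R.T @ R @ x`, the output is the same for every input; only the intermediate representation `x_rot` changes.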

We then can optimize the rotation matrix and its inverse so that local changes in the rotated activation matrix have local effects on the outputted activations. This locality is measured by the average sparsity of the jacobian across all the training inputs.

We do this because the jacobian is a representation of how each of the inputs affects each of the outputs. Large entries represent large effects. Small entries represent small effects. So if many entries are zero, this means that fewer inputs have an effect on fewer outputs. I.e. local changes to the input cause local changes to the output.
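To make the "jacobian as locality measure" idea concrete, here is a small sketch for a single ReLU layer. The near-zero fraction used here is a stand-in chosen for illustration, not the measure used in the project (note that for this stand-in, higher means more zeros); `layer`, `layer_jacobian`, and the toy weights are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

def layer(W, x):
    return np.maximum(W @ x, 0.0)

def layer_jacobian(W, x):
    # d relu(Wx) / dx = diag(Wx > 0) @ W
    active = (W @ x > 0).astype(float)
    return active[:, None] * W

def near_zero_fraction(J, tol=1e-6):
    # Stand-in locality score: share of entries that are numerically zero.
    return float(np.mean(np.abs(J) < tol))

def average_near_zero_fraction(W, inputs, tol=1e-6):
    # Average the score over a batch of inputs, one jacobian per input.
    scores = [near_zero_fraction(layer_jacobian(W, x), tol) for x in inputs]
    return float(np.mean(scores))

W_dense = rng.normal(size=(6, 6))        # every output depends on every input
W_local = np.diag(rng.normal(size=6))    # each output depends on one input
inputs = rng.normal(size=(32, 6))

dense_score = average_near_zero_fraction(W_dense, inputs)
local_score = average_near_zero_fraction(W_local, inputs)
```

The diagonal layer scores higher because each output only responds to one input, which is the kind of locality the rotation is meant to expose.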

 

This should find us a representation of the activations and interpretations of matrix multiplies that "make sense" in the context of the rest of the network.  

Another way of thinking about this is that our goal is to find the basis our network is thinking in.

Currently I'm getting this method to work on a simple, 3-layer, fully connected MNIST number classifying network. If this seems to give insight into the mechanics of the network after application, the plan is to adapt it to a more complicated network such as a transformer or resnet.

I only have preliminary results right now, but they are looking promising:

This is the normalized jacobian of the middle layer before a rough version of my method:

And here is the normalized jacobian after a rough version of my method (the jacobian's output has been set to a basis which maximizes its sparsity):


Thanks to David Udell for feedback on the post. I did not listen to everything you said, and if I had, the post would have been better.

  1. ^

    This seems important if we'd like to use interpretability work to produce useful conjectures about agency and selection more generally.


This is the jacobian taken at a single data point, right? Assuming so, you might want to try looking for a single rotation which makes the jacobian sparse at many datapoints simultaneously. That would be more directly interpretable as factoring the net into a relatively low number of information channels.

Another useful next step would be to take some part of the net which maps X -> Y -> Z, and compute the rotations which maximize sparsity for X -> Y and Y -> Z separately. Then, try to compose the rotations found. Do the "sparse output channels" of X -> Y approximately match the "sparse input channels" of Y -> Z?

I was planning on doing the first idea, and I do like the second idea! I'm slightly skeptical that the two rotations will be the same, but I did find that when performing the method on the last layer of the model, I get the identity matrix, which is some evidence in favor of the 'rotations are the same' prediction being right in general.

I think the idea is that if the rotated basis fundamentally "means" something important, rather than just making what's happening easier to picture for us humans, we'd kind of expect the basis computed for X->Y to mostly match the basis for Y->Z. 

At least that's the sort of thing I'd expect to see in such a world.

Yup, this is why I'm skeptical there will be a positive result. I did not try to derive a principled, meaningful, basis. I tried the most obvious thing to do which nobody else seems to have done. So I expect this device will be useful and potentially the start of something fundamental, but not fundamental itself.

Would love to see more in this line of work.

We then can optimize the rotation matrix and its inverse so that local changes in the rotated activation matrix have local effects on the outputted activations.

Could you explain how you are formulating/solving this optimization problem in more detail?

Suppose our  model has the following format:

where  are matrix multiplies, and  is our nonlinear layer.

We also define a sparsity measure to minimize, chosen for the fun property that it really really really likes zeros compared to almost all other numbers.

note that lower sparsity according to this measure means more zeros.

There are two reasonable ways of finding the right rotations. I will describe one way in depth, and the other in less depth. Do note that the specifics of all this may change once I run a few experiments to determine whether there are any shortcuts I'm able to take[1].

We know the input is in a preferred basis. In our MNIST case, it is just the pixels on the screen. These likely interact locally because the relevant first-level features are local. If you want to find a line in the bottom right, you don't care about the existence of white pixels in the top left.

We choose our first rotation  so as to minimize 

where

Then the second rotation  so as to minimize

where

And finally choosing  so as to minimize 

where

.

The other way of doing this is to suppose the output is in a preferred basis, instead of the input.

Currently I'm doing this minimization using gradient descent (lr = 0.0001), and parameterizing my rotation matrices using the fact that if A is an antisymmetric matrix[2], then e^A is a rotation matrix, and that you can make an antisymmetric matrix by choosing any old matrix M, then doing A = M - M^T. So we just figure out which M gets us an e^A which has the properties we like.
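One standard construction that fits this description is the matrix exponential of an antisymmetric matrix. The sketch below only builds the rotation from unconstrained parameters and checks that it is one; it does not perform the sparsity optimization. The hand-rolled `matrix_exp` is an assumption made to keep the example NumPy-only (a library routine such as `scipy.linalg.expm` would normally replace it), and `M` and the 4x4 size are hypothetical.

```python
import numpy as np

def antisymmetric(M):
    # A = M - M^T satisfies A^T = -A for any square M.
    return M - M.T

def matrix_exp(A, squarings=10, terms=20):
    # Scaling and squaring: exp(A) = exp(A / 2^s)^(2^s),
    # with a truncated Taylor series for the scaled matrix.
    scaled = A / (2 ** squarings)
    result = np.eye(A.shape[0])
    term = np.eye(A.shape[0])
    for k in range(1, terms + 1):
        term = term @ scaled / k
        result = result + term
    for _ in range(squarings):
        result = result @ result
    return result

def rotation_from_params(M):
    # Any square M maps to a rotation, so M can be updated freely.
    return matrix_exp(antisymmetric(M))

rng = np.random.default_rng(0)
M = rng.normal(size=(4, 4))   # unconstrained parameters
R = rotation_from_params(M)

orthogonality_error = np.max(np.abs(R @ R.T - np.eye(4)))
determinant = np.linalg.det(R)
```

The appeal of this parameterization is that gradient steps on `M` never leave the set of rotations, so no projection or constraint handling is needed.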

There is probably a far, far better way of solving this, other than gradient descent. If you are interested in the specifics, you may know a better way. Please, please tell me a better way!

  1. ^

    An example of a short cut: I really don't want to find a rotation which minimizes average sparsity across every input directly. This sounds very computationally expensive! Does minimizing my sparsity metric on a particular input, or only a few inputs generalize to minimizing the sparsity metric on many inputs?

  2. ^

    Meaning it's a matrix whose top-right half has the opposite sign to its bottom-left half (its transpose equals its negative), with zeros on the diagonal.

I'll put a $100 bounty on a better way that either saves Garrett at least 5 hours of research time, or is qualitatively better such that he settles on it.

What's the motivation for that specific sparsity prior/regularizer? Seems interestingly different than standard Ln.

Empirically, it works better than all the Ln norms for getting me zeros. Theoretically, it really likes zeros, whereas lots of other norms just like low numbers which are different things when talking about sparsity. I want zeros. I don't just want low numbers.

Work I'm doing at Redwood involves doing somewhat similar things.

Some observations which you plausibly are already aware of:

  • You could use geotorch for the parametrization. geotorch has now been 'upstreamed' into PyTorch as well.
  • It's also possible to use the Q from the QR decomposition to accomplish this. This has some advantages for me (specifically, you can orthogonalize arbitrary unfolded tensors which are parameterized in factored form), however, I believe the gradients via SGD will be less nice when using QR.
  • Naively, there probably isn't a better way to learn than via gradient descent (possibly with better initialization etc.). This is 'just some random non-convex optimization problem', so what could you hope for? If you minimize sparsity on a single input as opposed to on average, then it seems plausible to me that you could pick a sparsity criterion such that the problem can be optimized in a nicer way (but I'd also expect that minimizing sparsity on a single input isn't really what you want).
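As a rough illustration of the orthogonalization-by-decomposition idea, the sketch below takes the orthogonal factor of a QR decomposition of an unconstrained matrix. This is an assumed reading of the suggestion, not code from either commenter; `orthogonal_from_qr` and `M` are hypothetical names.

```python
import numpy as np

def orthogonal_from_qr(M):
    # Q from M = QR is orthogonal. Fixing the signs of R's diagonal
    # makes the map M -> Q unique for invertible M.
    q, r = np.linalg.qr(M)
    signs = np.sign(np.diag(r))
    signs[signs == 0] = 1.0
    return q * signs   # scales column j of q by signs[j]

rng = np.random.default_rng(0)
M = rng.normal(size=(5, 5))   # unconstrained parameters
Q = orthogonal_from_qr(M)

qr_orthogonality_error = np.max(np.abs(Q.T @ Q - np.eye(5)))
```

Note that the result is orthogonal but may have determinant -1, so it can be a reflection rather than a pure rotation; whether that matters depends on the application.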

You could hope for more even for a random non-convex optimization problem if you can set up a tight relaxation. E.g. this paper gives you optimality bounds via a semidefinite relaxation, though I am not sure if it would scale to the size of problems relevant here.

Interesting QR decomposition idea. I'm going to try using the Q as the initialization point of the rotation matrix, and see if this has any effect.

Interesting idea! 

What do you think about the Superposition Hypothesis? If that were true, then at a sufficient sparsity of features in the input there is no basis in which the network is thinking, meaning it will be impossible to find a rotation matrix that allows for a bijective mapping between neurons and features.

I would assume that the rotation matrix that enables local changes via the sparse Jacobian coincides with one which maximizes some notion of "neuron-feature-bijectiveness". But as noted above that seems impossible if the SPH holds.

What do you think about the Superposition Hypothesis? If that were true, then at a sufficient sparsity of features in the input there is no basis in which the network is thinking, meaning it will be impossible to find a rotation matrix that allows for a bijective mapping between neurons and features.

I'd say that there is a basis the network is thinking in in this hypothetical, it would just so happen not to match the human abstraction set for thinking about the problem in question.

If due to superposition, it proves advantageous to the AI to have a single feature that kind of does dog-head-detection and kind of does car-front-detection, because dog heads and car fronts don't show up in the training data at the same time, so it can still get perfect loss through a properly constructed dual-purpose feature like this, it'd mean that to the AI, dog heads and car fronts are "the same thing". 

The network hasn't figured out how to distinguish between them. In a more general data set where dog heads and car fronts can co-occur, this network would fail. Its abstractions are optimised for the narrow training data set, where it genuinely proved to be unnecessarily cumbersome to assign different concepts to those two things.

As AIs get more capable and general, I'd expect the concepts/features they use to start more closely matching the ones humans use in many domains. As AI gets superhuman, I would be somewhat worried about it finding new concept/feature sets that work even better and more generally than human ones.

If due to superposition, it proves advantageous to the AI to have a single feature that kind of does dog-head-detection and kind of does car-front-detection, because dog heads and car fronts don't show up in the training data at the same time, so it can still get perfect loss through a properly constructed dual-purpose feature like this, it'd mean that to the AI, dog heads and car fronts are "the same thing".

I don't think that's true. Imagine a toy scenario of two features that run through a 1D non-linear bottleneck before being reconstructed. Assuming that with some weight settings you can get superposition, the model is able to reconstruct the features ≈perfectly as long as they don't appear together. That means the model can still differentiate the two features, they are different in the model's ontology.

As AIs get more capable and general, I'd expect the concepts/features they use to start more closely matching the ones humans use in many domains.

My intuition disagrees here too. Whether we will observe superposition is a function of (number of "useful" features in the data), (sparsity of said features), and something like (bottleneck size). It's possible that bottleneck size will never be enough to compensate for number of features. Also it seems reasonable to me that ≈all of reality is extremely sparse in features, which presumably favors superposition.

Also it seems reasonable to me that ≈all of reality is extremely sparse in features, which presumably favors superposition.

Reality is usually sparse in features, and that's why even very small and simple intelligences can operate within it most of the time, so long as they don't leave their narrow contexts. But the mark of a general intelligence is that it can operate even in highly out-of-distribution situations. Cars are usually driven on roads, so an intelligence could get by using a car even if its concepts of car-ness were all mixed up with its conception of roadness. But a human can plan to take a car to the moon and drive it on the dust there, and then do that. This indicates to me that a general intelligence needs to think in features that can compose to handle almost any data, not just data that usually appeared in the training distribution.

If your architecture has too many bottlenecks to allow this, I expect that it will not be able to become a human-level general intelligence.

(Parts of the human brain definitely seem narrow and specialised too of course, it's only the general reasoning capabilities that seem to have these ultra-factorising, nigh-universally applicable concepts.)

Note also that concepts humans use can totally be written as superpositions of other concepts too, most of these other concepts apparently just aren't very universally useful.

[-]TAG10

Reality is usually sparse in features, and that's why even very small and simple intelligences can operate within it most of the time, so long as they don't leave their narrow contexts.

Reality is rich in features, but sparse in features that matter to a simple organism. That's why context matters.

I don't think that's true. Imagine a toy scenario of two features that run through a 1D non-linear bottleneck before being reconstructed. Assuming that with some weight settings you can get superposition, the model is able to reconstruct the features ≈perfectly as long as they don't appear together. That means the model can still differentiate the two features, they are different in the model's ontology.

I'm not sure I understand this example. If I have a single 1-D feature, a floating point number that goes up with the amount of dog-headedness or car-frontness in a picture, then how can the model in a later layer reconstruct whether there was a dog-head xor a car-front in the image from that floating point number, unless it has other features that effectively contain this information?

Possibly the source of our disagreement here is that you are imagining the neuron ought to be strictly monotonically increasing in activation relative to the dog-headedness of the image?

If we abandon that assumption then it is relatively clear how to encode two numbers in 1D. Let's assume we observe two numbers . With probability , and with probability 

We now want to encode these two events in some third variable , such that we can perfectly reconstruct  with probability .

I put the solution behind a spoiler for anyone wanting to try it on their own.

Choose some veeeery large  (much greater than the variance of the normal distribution of the features). For the first event, set . For the second event, set .

The decoding works as follows:

If  is negative, then with probability  we are in the first scenario and we can set . Vice versa if  is positive.
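The encoding and decoding steps can be sketched in a few lines. The concrete choices below are assumptions made for illustration and are not taken from the comment: exactly one of the two features is nonzero per sample, the nonzero one is drawn from a standard normal, each event has probability 0.5, the offset `A` is 100, and the first event is the one mapped to negative codes.

```python
import numpy as np

A = 100.0  # hypothetical offset, far larger than the features' standard deviation

def encode(x, y):
    # Assumes exactly one of x, y is active (nonzero) per sample.
    # First event (x active): shift far negative. Second event: shift far positive.
    return np.where(y == 0.0, x - A, y + A)

def decode(z):
    # The sign of z says which event occurred; undo the matching shift.
    x_hat = np.where(z < 0, z + A, 0.0)
    y_hat = np.where(z < 0, 0.0, z - A)
    return x_hat, y_hat

rng = np.random.default_rng(0)
n = 1000
values = rng.normal(size=n)
first_event = rng.random(n) < 0.5      # hypothetical event probability
x = np.where(first_event, values, 0.0)
y = np.where(first_event, 0.0, values)

z = encode(x, y)
x_hat, y_hat = decode(z)
max_error = max(np.max(np.abs(x - x_hat)), np.max(np.abs(y - y_hat)))
```

Decoding only fails if a feature value is larger in magnitude than the offset, which is why the offset has to be much greater than the spread of the features. The code `z` is deliberately not monotonic in either feature.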

Ah, I see. Thank you for pointing this out. Do superposition features actually seem to work like this in practice in current networks? I was not aware of this.

In any case, for a network like the one you describe I would change my claim from

it'd mean that to the AI, dog heads and car fronts are "the same thing". 

to the AI having a concept for something humans don't have a neat short description for. So for example, if your algorithm maps X>0 Y>0 to the first case, I'd call it a feature of "presence of dog heads or car fronts, or presence of car fronts".

I don't think this is an inherent problem for the theory. That a single floating point number can contain a lot of information is fine, so long as you have some way to measure how much it is.  

Do superposition features actually seem to work like this in practice in current networks? I was not aware of this.

I'm not aware of any work that identifies superposition in exactly this way in NNs of practical use. 
As Spencer notes, you can verify that it does appear in certain toy settings though. Anthropic notes in their SoLU paper that they view their results as evidence for the SPH in LLMs. Imo the key part of the evidence here is that using a SoLU destroys performance but adding another LayerNorm afterwards solves that issue. The SoLU selects strongly against superposition and LayerNorm makes it possible again, which is some evidence that the way the LLM got to its performance was via superposition.

 

ETA: Ofc there could be some other mediating factor, too.

This example is meant to only illustrate how one could achieve this encoding. It's not how an actual autoencoder would work. An actual NN might not even use superposition for the data I described and it might need some other setup to elicit this behavior.
But to me it sounded like you think superposition is nothing but the network being confused, whereas I think it can be the correct way to still be able to reconstruct the features to a reasonable degree.

Not confused, just optimised to handle data of the kind seen in training, and with limited ability to generalise beyond that, compared to human vision.

Yeah I agree with that. But there is also a sense in which some (many?) features will be inherently sparse.

  • A token is either the first one of multi-token word or it isn't.
  • A word is either a noun, a verb or something else.
  • A word belongs to language LANG and not to any other language/has other meanings in those languages.
  • An image can only contain so many objects which can only contain so many sub-aspects.

I don't know what it would mean to go "out of distribution" in any of these cases.

This means that any network that has an incentive to conserve parameter usage (however we want to define that), might want to use superposition.

I'd say that there is a basis the network is thinking in in this hypothetical, it would just so happen not to match the human abstraction set for thinking about the problem in question.

Well, yes but the number of basis elements that make that basis human interpretable could theoretically be exponential in the number of neurons.

Sure, but that's not a question I'm primarily interested in. I don't want the most interpretable basis, I want the basis that network itself uses for thinking. My goal is to find the elementary unit of neural networks, to build theorems and eventually a whole predictive theory of neural network computation and selection on top of. 

That this may possibly make current networks more human-interpretable even in the short run is just a neat side benefit to me.

Ah, I might have misunderstood your original point then, sorry! 

I'm not sure what you mean by "basis" then. How strictly are you using this term?

I imagine you are basically going down the "features as elementary unit" route proposed in Circuits (although you might not be pre-disposed to assume features are the elementary unit). Finding the set of features used by the network and figuring out how it's using them in its computations does not 1-to-1 translate to "find the basis the network is thinking in" in my mind.

I imagine you are basically going down the "features as elementary unit" route proposed in Circuits (although you might not be pre-disposed to assume features are the elementary unit). Finding the set of features used by the network and figuring out how it's using them in its computations does not 1-to-1 translate to "find the basis the network is thinking in" in my mind.

Fair enough, imprecise use of language. For some definitions of "thinking" I'd guess a small vision CNN isn't thinking anything.

I mostly expect networks at zero loss not to be in a superposition, since we should expect those networks to be in a broad basin, meaning fairly few independent, orthogonal features, so less room to implement two completely different functions. But we don't always find networks in broad basins, so we may see some networks in a superposition.

It would be interesting to study which training regimes and architectures most/least often produce easily-interpretable networks by this metric, and this may give some insight into when you see superposition.

In the cases where there is a nice basis this device finds, we may also expect it to disentangle any superpositions which exist, and for this superposition to be a combination of two fairly simple functions, requiring very few features, or interpreting the same features in different ways.

I disagree with your intuition that we should expect networks at irreducible loss not to be in superposition.

The reason I brought this up is that there are, IMO, strong first-principle reasons for why SPH should be correct. Say there are two features, each with an independent probability of 0.05 of being present in a given data point; then it would be wasteful to allocate a full neuron to each of these features. The probability of both features being present at the same time is a mere 0.0025. If the superposition is implemented well you get basically two features for the price of one with an error rate of 0.25%. So if there is even a slight pressure towards compression, e.g. by having fewer available neurons than features, then superposition should be favored by the network.
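The arithmetic behind this argument, under the stated assumptions (two independent features, each present with probability 0.05), is just a product of probabilities. The variable names are illustrative.

```python
p = 0.05                          # independent probability that each feature is present

p_both = p * p                    # both present at once: the case a shared neuron cannot represent
p_exactly_one = 2 * p * (1 - p)   # the case superposition handles cleanly
p_neither = (1 - p) ** 2          # nothing to represent
```

Here `p_both` is 0.0025, i.e. a collision on 0.25% of data points, while a shared neuron is doing useful work on the 9.5% of data points where exactly one feature is present.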

Now does this toy scenario map to reality? I think it does, and in some sense it is even more favorable to SPH since often the presence of features will be anti-correlated. 

Ah, I think you're right here, though I don't think this means there's no room for improvement on the sparsity front. Do you know of any hand-constructed examples of a layer in superposition, for which we know the features? I'd like to play around with one, and see if there's any robust way to disentangle it.

I agree that all is not lost wrt sparsity and if SPH turns out to be true it might help us disentangle the superimposed features to better understand what is going on. You could think of constructing an "expanded" view of a neural network. The expanded view would allocate one neuron per feature and thus has sparse activations for any given data point and would be easier to reason about. That seems impractical in reality, since the cost of constructing this view might in theory be exponential, as there are exponentially many "almost orthogonal" vectors for a given vector space dimension, as a function of the dimension.

I think my original comment was meant more as a caution against the specific approach of "find an interpretable basis in activation space", since that might be futile, rather than a caution against all attempts at finding a sparse representation of the computations that are happening within the network.

I don't think there is anything on that front other than the paragraphs in the SoLU paper. I alluded to a possible experiment for this on Twitter in response to that paper but haven't had the time to try it out myself: You could take a tiny autoencoder to reconstruct some artificially generated data where you vary attributes such as sparsity, ratio of input dimensions vs. bottleneck dimensions, etc. You could then look at the weight matrices of the autoencoder to figure out how it's embedding the features in the bottleneck and which settings lead to superposition, if any.

I'm not at liberty to share it directly but I am aware that Anthropic have a draft of small toy models with hand-coded synthetic data showing superposition very cleanly. They go as far as saying that searching for an interpretable basis may essentially be mistaken.
 

Interesting idea, and I'm generally very in favour of any efforts to find more understandable and meaningful "elementary units" of neural networks right now. I think this is currently the research question that most bottlenecks any efforts to get a deeper understanding of NN internals and NN selection, and I think those things are currently the biggest bottlenecks to any efforts at generating alignment strategies that might actually work. So we should be experimenting with lots of ideas for different NN "bases" to use and construct our theory of Deep Learning on top of, until we get a strong signal that we've found the right one.

Both bases that keep the layer structure the same, such as the one you propose here, or the one we're planning to investigate next, and bases that assume the layer structure doesn't quite match the way we should be thinking about the time ordering of computations in the network, and allow basis transformations that put activations that used to be in different layers into the same layer. 

If anyone is looking to come up with more promising ideas for basis transformations, some guiding heuristics to generate candidates might be: bases that seem to spontaneously show up when you're investigating the math behind some property of neural networks, bases that seem to make neural networks a lot more understandable to humans without requiring a lot of effort to compute, bases that come out of some theory or hypothesis of what neural networks are "really doing", bases that have fewer degrees of freedom than the neuron basis but still seem to accurately capture the behaviour of the network in many aspects of both training and deployment.

I agree entirely with this bottleneck analysis, and am also very excited about the work you're doing and have just posted.

Why did you decide to only use rotation matrices instead of any invertible matrix?  If you're trying to find a new basis to work in, wouldn't any invertible matrix work just as well?

[-]CRG30

This is a great approach imo. I've tried something similar in transformers using the singular vectors of the embedding matrix (the d_model x d_model matrix) to rotate the matrices connected to the residual stream. This seemed to induce sparsity in the weights close to the first layer with decreasing effect moving deeper into the model. Tried this with CLIP ViT-B and GPT-J, with the effect being a lot weaker in GPT-J. Also, some of the singular vectors of the embeddings were easily interpretable, with the top component being related to raw token frequency and interesting directions in GPT-J, (religion - technology) (positive - negative valence), and the top components of CLIP being color and frequency filters.

This is interesting, as I've (preliminarily) found the opposite with my methods. In my MNIST model, the first and last layers can't really be optimized any more than they are for sparsity, but the middle layer undergoes a drastic change.

If I'm reading this jacobian correctly, it seems to have found a basis where most of the output is determined by < 20 neurons? What's the perf effect?

I don't understand. Doesn't that shift the problem from the weight matrix to the rotation matrix? Yes, you know how inputs correspond to outputs, but now there is this new matrix between outputs and the inputs of the next layer and it creates the same non-locality. 

I'm sorry if this is stupid, my linear algebra course was a very long time ago.

You don't have to compute the rotation every time for the weight matrix.  You can compute it once. It's true that you have to actually rotate the input activations for every input but that's really trivial.

Think of the rotation as an adjustment to our original program's representation, which keeps the program the same, but hopefully makes it clearer what it's doing. It would probably be interesting to see what exactly the adjustment is doing, like how it's interesting to see how someone would go about translating assembly to a Python script, but it's not particularly your highest priority when trying to figure out what the assembly program is doing after you have the corresponding Python script.

Thanks! I think I understood.

I was going to ask the same thing. It may not be possible to create a simple vector representation of the circuit if the circuit must simulate a complex nonlinear system. Doesn't seem possible, if the inputs are embeddings or sensor data from real world causal dynamic systems like images and text; and the outputs are vector space representations of the meaning in the inputs (semantic segmentation, NLP, etc). If it were possible it's like saying there's a simple, consistent vector representation of all the common sense reasoning about a particular image or natural language text. And your rotated explainable embedding would be far more accurate and robust than the original spaghetti circuit/program. Some programs are irreducible.

I think the best you can do is something like Capsule Nets (Hinton) which are many of your rotations (just smaller, 4 d quaternions, I think) distributed throughout the circuit.

Interesting idea.

Obviously doing this instead with a permutation composed with its inverse would do nothing but shuffle the order and not help.

You can easily do the same with any affine transformation, no? Skew, translation (scale doesn't matter for interpretability).

More generally if you were to consider all equivalent networks, tautologically one of them is indeed more input activation => output interpretable by whatever metric you define (input is a pixel in this case?).

It's hard for me to believe that rotations alone are likely to give much improvement.  Yes, you'll find a rotation that's "better".

What would suffice as convincing proof that this is valuable for a task: the transformation increases the effectiveness of the best training methods.

I would try at least fine-tuning on the modified network.

I believe people commonly try to train not a sequence of equivalent power networks (w/ a method to project from weights of the previous architecture to the new one), but rather a series of increasingly detailed ones.

Anyway, good presentation of an easy to visualize "why not try it" idea.

You can easily do the same with any affine transformation, no? Skew, translation (scale doesn't matter for interpretability).

You can do this with any normalized, nonzero, invertible affine transformation. Otherwise, you either get the 0 function, get a function arbitrarily close to zero, or are unable to invert the function. I may end up doing this.
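A small sketch of the same function-preserving trick with a general invertible affine map instead of a rotation. This is an illustration under assumptions, not the author's setup: the map `A`, the shift `b`, the ReLU layer, and the sizes are hypothetical, and no normalization of `A` is applied here.

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_out = 5, 3
W = rng.normal(size=(n_out, n_in))   # hypothetical layer weights
x = rng.normal(size=n_in)            # hypothetical input activations

# Hypothetical invertible linear part (diagonally boosted to stay well conditioned) and shift.
A = rng.normal(size=(n_in, n_in)) + n_in * np.eye(n_in)
b = rng.normal(size=n_in)

x_new = A @ x + b                    # transformed activations
A_inv = np.linalg.inv(A)
W_new = W @ A_inv                    # inverse folded into the weights
bias_new = -W_new @ b                # undoes the translation

original = np.maximum(W @ x, 0.0)
transformed = np.maximum(W_new @ x_new + bias_new, 0.0)
affine_preserved = np.allclose(original, transformed)
```

Unlike a rotation, a general `A` can stretch some directions arbitrarily, which is one way to see why some normalization constraint is needed before optimizing for sparsity.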

What would suffice as convincing proof that this is valuable for a task: the transformation increases the effectiveness of the best training methods.

This will not provide any improvement in training, for various reasons, but mainly because I anticipate there's a reason the network is not in the interpretable basis. Interpretable networks do not actually increase training effectiveness. The real test of this method will be in my attempts to use it to understand what my MNIST network is doing.

Wouldn't your explainable rotated representation create a more robust model? Kind of like Newton's model of gravity was a better model than Kepler and Copernicus computing nested ellipses. Your model might be immune to adversarial examples and might generalize outside of the training set.