Lucius Bushnaq

AI notkilleveryoneism researcher, focused on interpretability. 

Personal account, opinions my own. 

I have signed no contracts or agreements whose existence I cannot mention.

Comments

But through gradient descent, shards act upon the neural networks by leaving imprints of themselves, and these imprints have no reason to be concentrated in any one spot of the network (whether activation-space or weight-space).
 

What does 'one spot' mean here?

If you just mean 'a particular entry or set of entries of the weight vector in the standard basis the network is initialised in', then sure, I agree.

But that just means you have to figure out a different representation of the weights, one that carves the logic flow of the algorithm the network learned at its joints. Such a representation may not have much reason to line up well with any particular neurons, layers, attention heads or any other elements we use to talk about the architecture of the network. That doesn't mean it doesn't exist. 

Nice quick check!

Just to be clear: This is for the actual full models? Or for the 'model embeddings' as in you're doing a comparison right after the embedding layer?  

You could imagine a world where the model handles binding mostly via the token index and grammar rules. I.e. 'red cube, blue sphere' would have a 'red' feature at token $t_1$, a 'cube' feature at token $t_2$, a 'blue' feature at $t_3$, and a 'sphere' feature at $t_4$, with contributions like 'cube' at $t_3$ being comparatively subdominant or even nonexistent.

I don't think I really believe this. But if you want to stick to a picture where features are directions, with no further structure of consequence in the activation space, you can do that, at least on paper.

Is this compatible with the actual evidence about activation structure we have? I don't know. I haven't come across any systematic investigations into this yet. But I'd guess probably not. 

Relevant. Section 3 is the one I found interesting.

If you wanted to check for matrix binding like this in real models, you could maybe do it by training an SAE with a restricted output matrix. Instead of each dictionary element being independent, you demand that $W_{\text{out}}$ for your SAE can be written as $W_{\text{out}} = [D, MD]$, where $D \in \mathbb{R}^{d \times k}$ and $M \in \mathbb{R}^{d \times d}$. So, we demand that the second half of the SAE dictionary is just some linear transform of the first half.

That'd be the setup for pairs. Go $[D, M_1 D, M_2 D]$ for three slots, and so on.
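
A minimal JAX sketch of what such a restricted SAE could look like for the pairs case. All names here (`D`, `M`, `W_enc`, `init_params`, `l1_coeff`) are made up for illustration, the plain ReLU encoder and L1 penalty are assumptions, and nothing about this has been tested on a real model.

```python
import jax
import jax.numpy as jnp


def init_params(key, d_model, n_feats):
    """Hypothetical parameter layout: a free dictionary D and one binding matrix M."""
    k_enc, k_dict, k_bind = jax.random.split(key, 3)
    return {
        # The encoder is unrestricted: one row of activations per dictionary element.
        "W_enc": jax.random.normal(k_enc, (d_model, 2 * n_feats)) / jnp.sqrt(d_model),
        "b_enc": jnp.zeros(2 * n_feats),
        # D holds the first half of the dictionary, one element per row.
        "D": jax.random.normal(k_dict, (n_feats, d_model)) / jnp.sqrt(d_model),
        # M maps each first-slot direction to its second-slot counterpart.
        "M": jnp.eye(d_model) + 0.01 * jax.random.normal(k_bind, (d_model, d_model)),
    }


def decoder(params):
    """Output matrix whose second half is a linear transform of the first half."""
    D = params["D"]
    # With elements stored as rows, applying M to every element is D @ M.T.
    return jnp.concatenate([D, D @ params["M"].T], axis=0)  # (2 * n_feats, d_model)


def sae_forward(params, x):
    acts = jax.nn.relu(x @ params["W_enc"] + params["b_enc"])
    return acts @ decoder(params), acts


def loss_fn(params, x, l1_coeff=1e-3):
    """Reconstruction error plus an L1 sparsity penalty (an assumed, standard choice)."""
    recon, acts = sae_forward(params, x)
    mse = jnp.mean(jnp.sum((recon - x) ** 2, axis=-1))
    sparsity = jnp.mean(jnp.sum(jnp.abs(acts), axis=-1))
    return mse + l1_coeff * sparsity
```

Gradients flow into both `D` and `M` through `jax.grad(loss_fn)`, so the binding matrix is learned jointly with the dictionary. For three slots you would add a second matrix and concatenate a third block.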

(To be clear, I'm also not that optimistic about this sort of sparse coding + matrix binding model for activation space. I've come to think that activations-first mech interp is probably the wrong way to approach things in general. But it'd still be a neat thing for someone to check.)



 

On its own, this'd be another metric that doesn't track the right scale as models become more powerful.

The same KL-div in GPT-2 and GPT-4 probably corresponds to the destruction of far more of the internal structure in the latter than the former.  

Destroy 95% of GPT-2's circuits, and the resulting output distribution may look quite different. Destroy 95% of GPT-4's circuits, and the resulting output distribution may not be all that different, since 5% of the circuits in GPT-4 might still be enough to get a lot of the most common token prediction cases roughly right.

I've seen a little bit of this, but nowhere near as much as I think the topic merits. I agree that systematic studies on where and how the reconstruction errors make their effects known might be quite informative.

Basically, whenever people train SAEs, or use some other approximate model decomposition that degrades performance, I think they should ideally spend some time afterwards just playing with the degraded model and talking to it, to figure out in what ways it is worse.

The metric you mention here is probably 'loss recovered'. For a residual stream insertion, it goes

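$$\text{loss recovered} = 1 - \frac{\text{CE}_{\text{SAE}} - \text{CE}_{\text{orig}}}{\text{CE}_{\text{ablated}} - \text{CE}_{\text{orig}}}$$

where $\text{CE}_{\text{SAE}}$ is the CE loss with the SAE inserted, $\text{CE}_{\text{orig}}$ is the CE loss of the original model, and $\text{CE}_{\text{ablated}}$ is the CE loss if the entire residual stream is ablated.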

See e.g. equation 5 here.

So, it's a linear scale, and they're comparing the CE loss increase from inserting the SAE to the CE loss increase from just destroying the model and outputting an approximately uniform distribution over tokens. The latter is a very large CE loss increase, so the denominator is really big. Thus, scoring over 90% is pretty easy.
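
To make the scale concrete, here is a small sketch of the formula with made-up numbers. The vocabulary size (GPT-2's 50257), the original CE loss of 3.0, and the use of an exactly uniform distribution as the ablation baseline are all assumptions for illustration, not values from any paper.

```python
import math


def loss_recovered(ce_sae, ce_orig, ce_ablated):
    """Fraction of the ablation-induced CE loss increase that the SAE avoids."""
    return 1.0 - (ce_sae - ce_orig) / (ce_ablated - ce_orig)


# Assumption: the ablated model outputs a uniform distribution over a
# GPT-2-sized vocabulary, so its CE loss is log(vocab_size) nats.
vocab_size = 50257
ce_uniform = math.log(vocab_size)

# Hypothetical original model loss, in nats per token.
ce_orig = 3.0

for ce_sae in (3.1, 3.5, 3.78):
    score = loss_recovered(ce_sae, ce_orig, ce_uniform)
    print(f"CE with SAE = {ce_sae:.2f}  ->  loss recovered = {score:.3f}")
```

Under these assumptions the denominator is almost 8 nats, so even the last case, a 0.78 nat increase in CE loss, still comes out at a bit over 90% loss recovered.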
 

All current SAEs I'm aware of seem to score very badly on reconstructing the original model's activations.

If you insert a current SOTA SAE into a language model's residual stream, model performance on next token prediction will usually degrade down to what a model trained with less than a tenth or a hundredth of the original model's compute would get (based on extrapolating with Chinchilla scaling curves at optimal compute). And that's for inserting one SAE at one layer. If you want to study circuits of SAE features, you'll have to insert SAEs in multiple layers at the same time, potentially further degrading performance.

I think many people outside of interp don't realize this. Part of the reason they don’t realize it might be that almost all SAE papers report loss reconstruction scores on a linear scale, rather than on a log scale or an LM scaling curve. Going from 1.5 CE loss to 2.0 CE loss is a lot worse than going from 4.5 CE to 5.0 CE. Under the hypothesis that the SAE is capturing some of the model's 'features' and failing to capture others, capturing only 50% or 10% of the features might still only drop the CE loss by a small fraction of a unit.

So, if someone is just glancing at the graphs without looking up what the metrics actually mean, they can be left with the impression that performance is much better than it actually is. The two most common metrics I see are raw CE scores of the model with the SAE inserted, and 'loss recovered'. I think both of these metrics give a wrong sense of scale. 'Loss recovered' is the worse offender, because it makes it outright impossible to tell how good the reconstruction really is without additional information. You need to know what the original model’s loss was and what zero baseline they used to do the conversion. Papers don't always report this, and the numbers can be cumbersome to find even when they do.

I don't know what an actually good way to measure model performance drop from SAE insertion is. The best I've got is to use scaling curves to guess how much compute you'd need to train a model that gets comparable loss, as suggested here. Or alternatively, how many parameters you'd need to get comparable loss when training with the same number of tokens as the original model. Using this measure, the best reported reconstruction score I'm aware of is 0.1 of the original model's performance, reached by OpenAI's GPT-4 SAE with 16 million dictionary elements in this paper.
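
A rough sketch of how such a conversion could go, assuming the parametric Chinchilla fit $L(N, D) = E + A/N^{\alpha} + B/D^{\beta}$ with the constants reported by Hoffmann et al. (2022) and the usual $C \approx 6ND$ approximation. The function names are made up, and the big caveat is that CE losses are only comparable if the model shares the tokenizer and data distribution the curve was fitted on, so treat the outputs as illustrative rather than as anyone's actual numbers.

```python
import math

# Constants of the parametric Chinchilla fit (Hoffmann et al., 2022).
E, A, B, ALPHA, BETA = 1.69, 406.4, 410.7, 0.34, 0.28


def optimal_loss(compute_flops):
    """Loss at the compute-optimal (N, D) split, assuming C = 6 * N * D."""
    g = (ALPHA * A / (BETA * B)) ** (1.0 / (ALPHA + BETA))
    a = BETA / (ALPHA + BETA)   # exponent for the optimal parameter count
    b = ALPHA / (ALPHA + BETA)  # exponent for the optimal token count
    n_opt = g * (compute_flops / 6.0) ** a
    d_opt = (compute_flops / 6.0) ** b / g
    return E + A / n_opt**ALPHA + B / d_opt**BETA


def compute_for_loss(target_loss, lo=1e10, hi=1e40):
    """Invert optimal_loss by bisection in log-compute."""
    if not optimal_loss(hi) < target_loss < optimal_loss(lo):
        raise ValueError("target loss is outside the range covered by the fit")
    for _ in range(200):
        mid = math.sqrt(lo * hi)  # geometric midpoint
        if optimal_loss(mid) > target_loss:
            lo = mid  # loss still too high, so more compute is needed
        else:
            hi = mid
    return math.sqrt(lo * hi)


def compute_fraction(ce_orig, ce_sae):
    """Compute-equivalent of the degraded model, relative to the original."""
    return compute_for_loss(ce_sae) / compute_for_loss(ce_orig)


# Hypothetical losses: the same 0.5 nat increase at two different starting points.
print(compute_fraction(2.0, 2.5))
print(compute_fraction(4.5, 5.0))
```

Under this fit, the same 0.5 increase in CE loss costs far more equivalent compute when the starting loss is already close to the irreducible term `E` than when it is high, which is the sense in which a linear CE scale is misleading.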

For most papers, I found it hard to convert their SAE reconstruction scores into this format. So I can't completely exclude the possibility that some other SAE scores much better. But at this point, I'd be quite surprised if anyone had managed so much as 0.5 performance recovered on any model that isn't so tiny and bad it barely has any performance to destroy in the first place. I'd guess most SAEs get something in the range 0.01-0.1 performance recovered or worse.

Note also that getting a good reconstruction score still doesn't necessarily mean the SAE is actually showing something real and useful. If you want perfect reconstruction, you can just use the standard basis of the network. The SAE would probably also need to be much sparser than the original model activations to provide meaningful insights.

Instrumentally, yes. The point is that I don’t really care terminally.

Getting the Hessian eigenvalues does not require calculating the full Hessian. You use Jacobian-vector product methods in e.g. JAX to compute Hessian-vector products. The Hessian itself never has to be explicitly represented in memory.
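
A minimal sketch of the idea in JAX, assuming the parameters are a single flat array. `hvp` and `top_hessian_eigenvalue` are names invented here, and plain power iteration is swapped in as the simplest possible eigensolver. It only finds the eigenvalue of largest magnitude; in practice you would more likely use something like Lanczos to get several top or bottom eigenvalues.

```python
import jax
import jax.numpy as jnp


def hvp(loss_fn, params, v):
    """Hessian-vector product via forward-over-reverse differentiation.

    Takes the JVP of the gradient function, so the Hessian is never materialised.
    """
    return jax.jvp(jax.grad(loss_fn), (params,), (v,))[1]


def top_hessian_eigenvalue(loss_fn, params, key, n_iters=100):
    """Estimate the largest-magnitude Hessian eigenvalue by power iteration."""
    v = jax.random.normal(key, params.shape)
    v = v / jnp.linalg.norm(v)
    for _ in range(n_iters):
        hv = hvp(loss_fn, params, v)
        v = hv / jnp.linalg.norm(hv)
    # Rayleigh quotient of the (unit-norm) converged vector.
    return jnp.vdot(v, hvp(loss_fn, params, v))


# Toy check on a quadratic whose Hessian is known to be diag(1, 2, 5).
A_toy = jnp.diag(jnp.array([1.0, 2.0, 5.0]))


def quad(x):
    return 0.5 * x @ A_toy @ x


x0 = jnp.ones(3)
print(top_hessian_eigenvalue(quad, x0, jax.random.PRNGKey(0)))
```

Each iteration costs roughly a small constant multiple of one gradient evaluation, and memory use scales with the number of parameters rather than its square.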

And even assuming the estimator for the Hessian pseudoinverse is cheap and precise, you'd still need to get its rank anyway, which would by default be just as expensive as getting the rank of the Hessian.

Why would we want or need to do this, instead of just calculating the top/bottom Hessian eigenvalues?
