All of Demian Till's Comments + Replies

Regarding some features not being learnt at all, I was anticipating this might happen when some features activate much more rarely than others, potentially incentivising SAEs to learn more common combinations instead of some of the rarer features. In order to see this we'd need to experiment with more variations, as mentioned in my other comment.

Nice work! I was actually planning on doing something along these lines and still have some things I'd like to try.

Interestingly, your SAEs appear to be generally failing to even find optimal solutions w.r.t. the training objective. For example, in your first experiment with perfectly correlated features, I think the optimal solution in terms of reconstruction loss and L1 loss combined (regardless of the choice of the L1 loss weighting) would have the learnt feature directions (decoder weights) pointing perfectly diagonally. It looks like very few of your hype...
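A minimal sketch of the intuition behind that claim, assuming a toy setup of two unit features along (1,0) and (0,1) that always fire together (so every input is (1,1)), unit-norm decoder directions, and the L1 penalty applied to the hidden activations; those assumptions are illustrative rather than taken from the post itself:

```python
import torch

# Toy setup (an assumption, not the post's exact config): two ground-truth
# features along (1, 0) and (0, 1) that always co-occur, so every input is (1, 1).
x = torch.tensor([1.0, 1.0])

# Option A: axis-aligned decoder directions, one activation per true feature.
dec_axis = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
act_axis = torch.tensor([1.0, 1.0])
recon_axis = act_axis @ dec_axis

# Option B: a single decoder direction pointing diagonally.
dec_diag = torch.tensor([[2.0 ** -0.5, 2.0 ** -0.5]])
act_diag = torch.tensor([2.0 ** 0.5])
recon_diag = act_diag @ dec_diag

# Both reconstruct perfectly, but the diagonal solution has lower L1
# (sqrt(2) vs 2), so it wins regardless of the L1 weighting.
print(torch.allclose(recon_axis, x), act_axis.abs().sum())  # True, 2.0
print(torch.allclose(recon_diag, x), act_diag.abs().sum())  # True, ~1.41
```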

Evan Anders
Hi Demian! Sorry for the really slow response. Yes! I agree, and I was surprised that the decoder weights weren't pointing diagonally in the case where feature occurrences were perfectly correlated. I'm not sure I really grok why this is the case. The models do learn a feature basis that can describe any of the (four) data points that can be passed into the model, but it doesn't seem optimal either for L1 or MSE. And -- yeah, I think this is an extremely pathological case. Preliminary results suggest that larger dictionaries finding larger sets of features do a better job of not getting stuck in these weird local minima, and the number of possible interesting experiments here (varying frequency, varying SAE size, varying which things are correlated) makes for a pretty large exploration space.

Nice, that's promising! It would also be interesting to see how those peaks are affected when you retrain the SAE both on the same target model and on different target models.

Testing it with Pythia-70M and few enough features to permit the naive calculation sounds like a great approach to start with.

Closest neighbour rather than average over all sounds sensible. I'm not certain what you mean by unique vs non-unique. If you're referring to situations where there may be several equally close nearest neighbours, then I think we can just take the mean cos-sim of those neighbours, so they all affect the loss but its magnitude stays within the normal range.
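A minimal sketch of what that closest-neighbour term might look like on the decoder weights, with ties within a small tolerance averaged as suggested; the function name and tolerance are assumptions:

```python
import torch

def closest_neighbour_loss(decoder, tol=1e-6):
    """Mean over features of the |cos sim| to each feature's closest
    neighbour(s); equally close neighbours (within tol) are averaged."""
    d = torch.nn.functional.normalize(decoder, dim=1)  # rows = feature directions
    cos = (d @ d.T).abs()
    cos.fill_diagonal_(-1.0)                           # exclude self-similarity
    max_cos, _ = cos.max(dim=1, keepdim=True)
    ties = cos >= max_cos - tol                        # all equally close neighbours
    per_feature = (cos * ties).sum(dim=1) / ties.sum(dim=1)
    return per_feature.mean()
```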

Only on features that activate also sounds sensible, but the dec...

Thanks for clarifying! Indeed the encoder weights here would be orthogonal. But I'm suggesting applying the orthogonality regularisation to the decoder weights, which would not be orthogonal in this case.
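For reference, a minimal sketch of what the plain pairwise version of such a regulariser on the decoder weights might look like (the closest-neighbour variant sketched above is one way to make it cheaper); the naming and the use of abs(cos_sim), which comes up later in the thread, are assumptions:

```python
import torch

def decoder_orthogonality_penalty(decoder):
    """Mean |cos sim| over all off-diagonal pairs of decoder directions.
    Naive version: O(n^2) in the number of dictionary features."""
    d = torch.nn.functional.normalize(decoder, dim=1)
    cos = d @ d.T
    off_diag = ~torch.eye(cos.shape[0], dtype=torch.bool, device=cos.device)
    return cos[off_diag].abs().mean()
```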

Logan Riggs
Ah, you're correct. Thanks!  I'm now very interested in implementing this method.

Thanks, I mentioned this as a potential way forward for tackling quadratic complexity in my edit at the end of the post.

Regarding achieving perfect reconstruction and perfect sparsity in the limit, I was also thinking along those lines, i.e. in the limit you could have a single neuron in the sparse layer for every possible input direction. However, please correct me if I'm wrong, but assuming the SAE has only one hidden layer, I don't think you could prevent neurons from activating for nearby input directions (unless all input directions had equal magnitude), so you'd end up with many neurons activating for any given input and thus imperfect sparsity.
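A minimal sketch of that magnitude caveat, with illustrative numbers chosen for the example: a single ReLU unit tuned to one direction can use its bias to gate out nearby unit-norm inputs, but a larger-magnitude input along a nearby direction still activates it.

```python
import torch

d = torch.tensor([1.0, 0.0])   # the direction this hypothetical neuron is tuned to
bias = -0.9                    # gates out unit-norm inputs with cos sim below 0.9

def activation(x):
    return torch.relu(x @ d + bias)

on_direction = torch.tensor([1.0, 0.0])   # cos sim 1.0, norm 1
nearby_unit  = torch.tensor([0.8, 0.6])   # cos sim 0.8, norm 1
nearby_big   = 2.0 * nearby_unit          # cos sim 0.8, norm 2

print(activation(on_direction))  # 0.1 -> fires, as intended
print(activation(nearby_unit))   # 0.0 -> suppressed by the bias
print(activation(nearby_big))    # 0.7 -> nearby direction leaks through at larger magnitude
```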

Otherwise mostly agree...

My bad! Yes, since that's just one batch it does indeed come out as quadratic overall. I'll have a think about more efficient methods.

This looks interesting. I'm having a difficult time understanding the results though. It would be great to see a more detailed write up!

yeah I was thinking abs(cos_sim(x,x'))

I'm not sure what you're getting at regarding the inhibitory weights, as the image link is broken.

Logan Riggs
Thanks for saying the link is broken! If the True Features are located at A: (0,1) and B: (1,0) [so A^B: (1,1)], then given 3 SAE hidden dimensions, a ReLU & bias, the model could learn 3 sparse features:

1. A^~B: (-1, 1)
2. A^B: (1, 1)
3. ~A^B: (1, -1)

that output 1-hot vectors for each feature. These are also orthogonal to each other. Concretely:

    import torch

    # Rows are the three learnt directions: A^~B, A^B, ~A^B.
    W = torch.tensor([[-1, 1], [1, 1], [1, -1]])
    # The three possible inputs: A alone, A and B together, B alone.
    x = torch.tensor([[0, 1], [1, 1], [1, 0]])
    # The bias keeps the A^B unit silent unless both features are present.
    b = torch.tensor([0, -1, 0])
    # Each input activates exactly one unit, so y is the 3x3 identity.
    y = torch.nn.functional.relu(x @ W.T + b)

If n is the number of features we're trying to discover and m is the number of features in each batch, then I'm thinking the naive approach is O(n^2) while the batch approach would be O(m^2 + mn). Still quadratic in m, but we would have m << n.
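A rough sketch of the difference, with hypothetical sizes; "features in each batch" is read here as the dictionary features that actually fire on the batch, which is one reading rather than something stated explicitly:

```python
import torch

n, m, dim = 50_000, 512, 768   # hypothetical: dictionary size, active features, model width
decoder = torch.nn.functional.normalize(torch.randn(n, dim), dim=1)
active = decoder[torch.randint(0, n, (m,))]   # stand-in for this batch's active features

# Naive: all pairs, n^2 similarities (left commented out; expensive to materialise).
# cos_all = decoder @ decoder.T               # (n, n)

# Batch variant: active features against everything, m*n similarities,
# which already contains the m^2 active-vs-active block.
cos_batch = active @ decoder.T                # (m, n)
```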

Chris_Leong
Isn’t that just one batch?

Even for a fairly small target model we might want to discover e.g. 100K features, and the input vectors might be e.g. 768D. That's a lot of work to compute that matrix!
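As a rough back-of-the-envelope check on that (illustrative numbers, assuming float32 storage):

```python
n, dim = 100_000, 768
pairs = n * n                 # 1e10 pairwise cosine similarities
flops = 2 * n * n * dim       # ~1.5e13 multiply-adds to form the full matrix
gigabytes = pairs * 4 / 1e9   # ~40 GB just to store it in float32
print(f"{pairs:.1e} pairs, {flops:.1e} flops, {gigabytes:.0f} GB")
```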

Charlie Steiner
Hm. Okay, I remembered a better way to improve efficiency: neighbor lists. For each feature, remember a list of who its closest neighbors are, and just compute your "closeness loss" by calculating dot products in that list. The neighbor list itself can either be recomputed once in a while using the naive method, or you can accelerate the neighbor list recomputation by keeping more coarse-grained track of where features are in activation-space.
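A minimal sketch of how that neighbour-list scheme might look on the decoder weights (the function names and choice of k are assumptions); the occasional recomputation here is still the naive O(n^2) pass unless it is replaced by a coarser, approximate search as described:

```python
import torch

def recompute_neighbour_lists(decoder, k=8):
    """Occasional, expensive step: each feature's k nearest neighbours by |cos sim|."""
    d = torch.nn.functional.normalize(decoder, dim=1)
    cos = (d @ d.T).abs()
    cos.fill_diagonal_(-1.0)
    return cos.topk(k, dim=1).indices              # (n, k) neighbour indices

def closeness_loss(decoder, neighbours):
    """Cheap per-step loss: |cos sim| only against the stored neighbours."""
    d = torch.nn.functional.normalize(decoder, dim=1)
    neigh = d[neighbours]                          # (n, k, dim)
    cos = (neigh * d.unsqueeze(1)).sum(-1)         # (n, k) dot products
    return cos.abs().mean()
```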

Thanks! Yeah, I think those steps make sense for the iterative process, but I'm not sure if you're proposing that this would tackle the problem of feature combinations by itself? I'm still imagining it would require orthogonality regularisation with some weighting.