Jacob G-W

I really like learning new things!

https://jacobgw.com/

Comments

So far, I have trouble because it lacks some form of spatial structure, and the algorithm feels too random to build meaningful connections between different cards

Hmm, I think that after doing a lot of Anki, my brain kind of formed its own spatial structure, but I don't think this happens to everyone.

I just use basic card type for math with some latex. Here are some examples: 

I find that doing fancy card types is kind of like premature optimization. Doing the reviews is the most important part. On the other hand, it's really important that the cards themselves are written well. This essay contained my most refined views on card creation. Some other nice ones are the 20 rules of knowledge formulation and How to write good prompts: using spaced repetition to create understanding. Hope this answer helped!

I disagree that this is the same as just stitching together different autoencoders. Presumably the encoder has some shared computation before specializing at the encoding level. I also don't see how you could use 10 different autoencoders to classify an image from the encodings. I guess you could just look at the reconstruction loss and then the autoencoder which got the lowest loss would probably correspond to the label, but that seems different to what I'm doing. However, I agree that this application is not useful. I shared it because I (and others) thought it was cool. It's not really practical at all. Hope this addresses your question :)

I didn't impose any structure in the objective/loss function relating to the label. The loss function is just the regular VAE loss. All I did was detach the gradients in some places. So it is a bit surprising to me that such a simple modification can cause the internals to specialize in this way. After I had seen gradient routing work in other experiments, I predicted that it would work here, but I don't think gradient routing working was a priori obvious (where "a priori obvious" would mean that I get zero new information from running the experiment because I had predicted the result with p=1).

Due to someone's suggestion, I've turned this into a top level post.

Jacob G-W

Over the past few months, I helped develop Gradient Routing, a non loss-based method to shape the internals of neural networks. After my team developed it, I realized that I could use the method to do something that I have long wanted to do: make an autoencoder with an extremely interpretable latent space.

I created an MNIST autoencoder with a 10 dimensional latent space, with each dimension of the latent space corresponding to a different digit. Before I get into how I did it, feel free to play around with my demo here (it loads the model into the browser): https://jacobgw.com/gradient-routed-vae/.

In the demo, you can both see how a random MNIST image encodes but also directly play around with the encoding itself and create different types of digits by just moving the sliders.

The reconstruction is not that good, and I assume this is due to some combination of (1) using the simplest possible architecture of MLP layers and ReLU, (2) only allowing a 10-dimensional latent space, which could constrain the representation a lot, (3) not doing data augmentation, so it might not generalize that well, and (4) gradient routing targeting an unnatural internal representation, causing the autoencoder to not fit the data that well. This was just supposed to be a fun proof-of-concept project, so I’m not too worried about the reconstruction not being that good.

How it works

My implementation of gradient routing is super simple and easy to add onto a variational autoencoder. During training, after I run the encoder, I just detach every dimension of the encoding except for the one corresponding to the label of the image:

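def encode_and_mask(self, images: Tensor, labels: Tensor):
    # Assumes `F` is torch.nn.functional and `Tensor` is torch.Tensor.
    encoded_unmasked, zeta, mean_from_encoded, cov_diag_from_encoded = self.encode(images)
    # One-hot mask selecting the latent dimension that matches each image's label.
    mask_one_hot = F.one_hot(labels, num_classes=self.latent_size).float()
    # The forward value is unchanged; gradients flow only through the labeled dimension.
    encoded = mask_one_hot * encoded_unmasked + (1 - mask_one_hot) * encoded_unmasked.detach()
    return encoded, zeta, mean_from_encoded, cov_diag_from_encoded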

This causes each dimension of the latent space to “specialize” to representing its corresponding image since the error for that image type can only be propagated through the single dimension of the latent space.
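To see the effect of the detach trick in isolation, here is a minimal PyTorch sketch. The helper name `route_gradients`, the toy tensor values, and the stand-in loss are assumptions for illustration and are not taken from the actual model. It only checks two things: the forward value is unchanged by the masking, and the gradient reaches just the dimension matching each label.

```python
import torch
import torch.nn.functional as F


def route_gradients(encoded_unmasked: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper mirroring the masking line in encode_and_mask."""
    latent_size = encoded_unmasked.shape[1]
    mask = F.one_hot(labels, num_classes=latent_size).float()
    return mask * encoded_unmasked + (1 - mask) * encoded_unmasked.detach()


# Toy batch: 2 "images", 10 latent dimensions (made-up values).
encoded_unmasked = torch.arange(20, dtype=torch.float32).reshape(2, 10).requires_grad_()
labels = torch.tensor([3, 7])

encoded = route_gradients(encoded_unmasked, labels)
same_forward = torch.equal(encoded, encoded_unmasked)

# Stand-in for the decoder and loss: any differentiable function of the encoding.
encoded.sum().backward()
grad = encoded_unmasked.grad

print(same_forward)  # whether masking left the forward values untouched
print(grad)          # nonzero only at column 3 (row 0) and column 7 (row 1)
```

With a real decoder the nonzero entries would hold the actual reconstruction gradient rather than 1, but the sparsity pattern would be the same.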

It turns out that if you do this, nothing forces the model to represent “more of a digit” in the positive direction. Sometimes the model represented “5-ness” in the negative direction in the latent space (e.g. as [0, 0, 0, 0, 0, -1.0, 0, 0, 0, 0]). This messed with my demo a bit since I wanted all the sliders to only go in the positive direction. My solution? Just apply ReLU to the encoding so it can only represent non-negative numbers! This is obviously not practical and I only included it so the demo would look nice.[1]

In our Gradient Routing paper, we found that models sometimes needed regularization to split the representations well. However, in this setting, I’m not applying any regularization besides the default regularization of the encoding that comes with a variational autoencoder. I guess it turns out that this regularization is enough to effectively split the digits.

Classification

It turns out that even though there was no loss function causing the encoding to activate most strongly on the dimension corresponding to the digit being encoded, it happened! In fact, we can classify digits to 92.58% accuracy by just taking the argmax over the encoding, which I find pretty amazing.
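The classification rule itself is tiny. Here is a sketch in plain Python; the encodings are made-up toy values rather than outputs of the actual model, and `classify` and `accuracy` are hypothetical helper names.

```python
from typing import List, Sequence


def classify(encoding: Sequence[float]) -> int:
    """Predict the digit as the index of the largest latent dimension."""
    return max(range(len(encoding)), key=lambda i: encoding[i])


def accuracy(encodings: List[Sequence[float]], labels: List[int]) -> float:
    """Fraction of encodings whose argmax matches the label."""
    correct = sum(classify(e) == y for e, y in zip(encodings, labels))
    return correct / len(labels)


# Toy 10-dimensional encodings (illustrative values only).
toy_encodings = [
    [0.0, 0.1, 0.0, 2.3, 0.0, 0.2, 0.0, 0.0, 0.4, 0.0],  # strongest on dim 3
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.7, 0.0, 0.9],  # strongest on dim 7
    [0.5, 0.0, 0.0, 0.0, 1.1, 0.0, 0.0, 0.0, 0.0, 1.0],  # strongest on dim 4
]
toy_labels = [3, 7, 9]  # the last example is deliberately misclassified

predictions = [classify(e) for e in toy_encodings]
print(predictions)
print(accuracy(toy_encodings, toy_labels))
```

No classifier head is trained here: the prediction is read directly off the latent code.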

Code

You can see the code here.

(this was a crosspost of a post from my blog)

  1. I did have to train the model a few times to get something that behaved nicely enough for the demo.

Jacob G-W

Thanks for pointing this out! Our original motivation for doing it that way was that we thought of the fine-tuning on FineWeb-Edu as a "coherence" step designed to restore the model's performance after ablation, which damaged it a lot. We noticed that this "coherence" step helped validation loss on both forget and retain. However, your criticism is valid, so we have updated the paper so that we retrain on the training distribution (which contains some of the WMDP-bio forget set). We still see that while the loss on FineWeb-Edu decreases to almost its value before ablation, the loss on the WMDP-bio forget set is around 0.1 nats above its value before ablation, showing that it is harder to retrain virology after ablation than just FineWeb-Edu data. Since we re-train on the training distribution (N=12 times with different data), we would expect that both losses would be retrainable at roughly the same rate, but this is not the case, showing that localization and then ablation has an effect.

Nice work! A few questions:

I'm curious if you have found any multiplicity in the output directions (what you denote as ), or if the multiplicity is only in the input directions. I would predict that there would be some multiplicity in output directions, but much less than the multiplicity in input directions for the corresponding concept.

Relatedly, how do you think about output directions in general? Do you think they are just upweighting/downweighting tokens? I'd imagine that their level of abstraction depends on how far from the end of the network the output layer is, which will ultimately end up determining how much of their effect is directly on the unembed vs. indirectly through other layers.

Something hashed with shasum -a 512 2d90350444efc7405d3c9b7b19ed5b831602d72b4d34f5e55f9c0cb4df9d022c9ae528e4d30993382818c185f38e1770d17709844f049c1c5d9df53bb64f758c
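For anyone who wants to check a later reveal against this commitment, here is a sketch using Python's `hashlib`. The function name and the idea of a `revealed_text` argument are hypothetical. It is also unknown whether the hashed input ended with a newline (as piping through `echo` into `shasum -a 512` would add), so the sketch tries both variants.

```python
import hashlib

# The digest quoted in the comment above.
COMMITMENT = (
    "2d90350444efc7405d3c9b7b19ed5b831602d72b4d34f5e55f9c0cb4df9d022c"
    "9ae528e4d30993382818c185f38e1770d17709844f049c1c5d9df53bb64f758c"
)


def matches_commitment(revealed_text: str, commitment: str = COMMITMENT) -> bool:
    """Return True if SHA-512 of the text (with or without a trailing newline) equals the commitment."""
    candidates = [revealed_text, revealed_text + "\n"]
    return any(
        hashlib.sha512(candidate.encode("utf-8")).hexdigest() == commitment
        for candidate in candidates
    )
```

Usage would be `matches_commitment("the revealed text")`, which returns `True` only if the text hashes to the committed digest.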

Isn't this a consequence of how the tokens get formed using byte pair encoding? It first constructs ' behavi' and then it constructs ' behavior' and then will always use the latter. But to get to the larger words, it first needs to create smaller tokens to form them out of (which may end up being irrelevant).

 

Edit: some experiments with the GPT-2 tokenizer reveal that this isn't a perfect explanation. For example "  behavio" is not a token. I'm not sure what is going on now. Maybe if a token shows up zero times, it cuts it?
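The mechanism can be sketched with a toy BPE trainer in plain Python. Everything here is an assumption for illustration: the three-word corpus and its counts are made up, ties are broken lexicographically (an arbitrary choice), and none of this reproduces GPT-2's actual merge table. In this toy, ' behavi' is created on the way to ' behavior' and then goes unused, while ' behavio' never exists because 'or' was merged earlier and every new token is the concatenation of two existing tokens. Whether that is what happens in GPT-2 is not something the sketch establishes.

```python
from collections import Counter
from typing import Dict, List, Tuple


def apply_merge(symbols: List[str], pair: Tuple[str, str]) -> List[str]:
    """Replace every adjacent occurrence of `pair` with the concatenated symbol."""
    merged, i = [], 0
    while i < len(symbols):
        if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
            merged.append(symbols[i] + symbols[i + 1])
            i += 2
        else:
            merged.append(symbols[i])
            i += 1
    return merged


def train_bpe(word_counts: Dict[str, int], num_merges: int) -> List[Tuple[str, str]]:
    """Toy BPE: repeatedly merge the most frequent adjacent pair of symbols."""
    words = {word: list(word) for word in word_counts}
    merges: List[Tuple[str, str]] = []
    for _ in range(num_merges):
        pair_counts: Counter = Counter()
        for word, symbols in words.items():
            for pair in zip(symbols, symbols[1:]):
                pair_counts[pair] += word_counts[word]
        if not pair_counts:
            break
        # Ties are broken lexicographically, purely for determinism.
        best = max(pair_counts, key=lambda p: (pair_counts[p], p))
        merges.append(best)
        words = {word: apply_merge(symbols, best) for word, symbols in words.items()}
    return merges


def encode(word: str, merges: List[Tuple[str, str]]) -> List[str]:
    """Tokenize a word by replaying the learned merges in order."""
    symbols = list(word)
    for pair in merges:
        symbols = apply_merge(symbols, pair)
    return symbols


# Made-up corpus: word -> count.
corpus = {" behavior": 10, " behaving": 6, " or": 50}
merges = train_bpe(corpus, num_merges=11)
vocab = {a + b for a, b in merges}
used = {token for word in corpus for token in encode(word, merges)}

print(encode(" behavior", merges))
print(sorted(vocab - used))  # tokens created as stepping stones but never emitted
```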

Maybe you are right, since averaging and scaling does result in pretty good steering (especially for coding). See here.
