Taras Kutsyk

MATS scholar (Winter 2024) in the Neel Nanda & Arthur Conmy cohort; Master's student at the University of L'Aquila.

Comments

Thanks for the insight! I expect the same to hold for the Gemma 2B base (pre-trained) vs. Gemma 2B Instruct models, though. Gemma-2b-Python-codes is just a full finetune on top of the Instruct model (probably produced without a large number of update steps), and previous work studying Instruct models indicated that SAEs don't transfer to Instruct Gemma 2B either.

Thanks! We'll take a closer look at these when we decide to extend our results to more models.

Let me make sure I understand your idea correctly:

  1. We use a separate single-layer model (analogous to the SAE encoder) to predict the SAE feature activations
  2. We train this model on the SAE activations of the finetuned model (assuming that the SAE wasn't finetuned on the finetuned model activations?)
  3. We then use this model to determine "what direction most closely maps to the activation pattern across input sequences, and how well it maps".
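
In code, I imagine step 1 looking roughly like the sketch below (a sketch only, with hypothetical names: `finetuned_acts` and `target_feature_acts` stand for the finetuned model's residual-stream activations and the SAE feature activations we'd regress onto):

```python
import torch
import torch.nn as nn

class FeaturePredictor(nn.Module):
    """Single-layer, SAE-encoder-like model: ReLU(W x + b)."""
    def __init__(self, d_model: int, n_features: int):
        super().__init__()
        self.enc = nn.Linear(d_model, n_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.enc(x))

def train_predictor(finetuned_acts: torch.Tensor,       # (n_tokens, d_model), hypothetical
                    target_feature_acts: torch.Tensor,  # (n_tokens, n_features), hypothetical
                    n_steps: int = 1000, lr: float = 1e-3) -> FeaturePredictor:
    model = FeaturePredictor(finetuned_acts.shape[-1], target_feature_acts.shape[-1])
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(n_steps):
        opt.zero_grad()
        # Regress the predictor's output onto the SAE feature activations
        loss = nn.functional.mse_loss(model(finetuned_acts), target_feature_acts)
        loss.backward()
        opt.step()
    return model
```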

I'm most unsure about the second step: how we train this feature-activation model. If we train it on the base SAE's feature activations in the finetuned model, I'm afraid we'll just be training it on extremely noisy data, because those feature activations no longer mean the same thing unless the SAE has been finetuned to appropriately reconstruct the finetuned model's activations. (And if we finetune it, we might just as well use the SAE and the feature-universality techniques I outlined, without needing a separate model.)

I like the idea of seeing if there are any features from the base model which are dead in the instruction-tuned-and-fine-tuned model as a proxy for "are there any features which fine-tuning causes the fine-tuned model to become unable to recognize"

Agreed, but I think our current setup is too limited to capture this. If we're using the same "base SAE" for both the base and finetuned models, a situation like the one you described only implies "either this feature from the base model now has a different vector (direction) in the activation space, OR this feature is no longer recognizable". Without training another SAE on the finetuned model, we have no way to tell the first case from the second (at least I don't see one).
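
To be concrete about the check that a single base SAE does allow: we can compare each feature's firing frequency on the two models' activations. A minimal sketch, with hypothetical names (`base_sae_encode` stands for the base SAE's encoder, `base_acts`/`finetuned_acts` for activations collected from the same prompts on the two models):

```python
import torch

def firing_frequency(encode, acts: torch.Tensor, threshold: float = 0.0) -> torch.Tensor:
    """Fraction of tokens on which each SAE feature fires above `threshold`."""
    feature_acts = encode(acts)                      # (n_tokens, n_features)
    return (feature_acts > threshold).float().mean(dim=0)

def newly_dead_features(encode, base_acts: torch.Tensor,
                        finetuned_acts: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Indices of features that fire on the base model but (almost) never on the finetuned one.
    As argued above, this can't distinguish "the feature is gone" from
    "the feature moved to a different direction"."""
    freq_base = firing_frequency(encode, base_acts)
    freq_ft = firing_frequency(encode, finetuned_acts)
    return torch.nonzero((freq_base > eps) & (freq_ft <= eps)).squeeze(-1)
```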

Another related question also strikes me as interesting, which is whether an SAE trained on the instruction-tuned model has any features which are dead in the base model…

This is indeed perhaps even more interesting, and I think the answer depends on how you map features of the SAE trained on the instruction-tuned model to features in the base model. If you do it naively, by taking the feature (encoder) vector from the SAE trained on the instruction-tuned model (as is done in the post) and using it in the base model to check the feature's activations, then once again you have a semi-decidable problem: either you get high activation similarity (i.e. the feature has roughly the same activation pattern), indicating that it is present in the base model, OR you get something completely different: zero activation, ultra-low-density activation, etc. And my intuition is that even the "zero activation" case doesn't imply that the feature "wasn't there" in the base model; perhaps it just had a different direction!
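
To illustrate the naive check I'm describing, here's a sketch that takes a single feature's encoder direction from the instruct-model SAE and compares its activation pattern on the two models (hypothetical names: `w_enc`/`b_enc` are that feature's encoder weights and bias, `instruct_acts`/`base_acts` are the two models' activations on the same tokens):

```python
import torch

def feature_activation(acts: torch.Tensor, w_enc: torch.Tensor, b_enc: float) -> torch.Tensor:
    """Activation pattern of a single SAE feature: ReLU(acts @ w_enc + b_enc)."""
    return torch.relu(acts @ w_enc + b_enc)

def activation_similarity(instruct_acts: torch.Tensor, base_acts: torch.Tensor,
                          w_enc: torch.Tensor, b_enc: float) -> float:
    """Pearson correlation of the feature's activation pattern across the two models.
    A high value suggests the feature is "also there" in the base model; a low value is
    ambiguous: the feature may be absent, or may just point in a different direction.
    (Returns NaN if one of the patterns is constant, e.g. all-zero.)"""
    acts_instruct = feature_activation(instruct_acts, w_enc, b_enc)
    acts_base = feature_activation(base_acts, w_enc, b_enc)
    return torch.corrcoef(torch.stack([acts_instruct, acts_base]))[0, 1].item()
```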

So I think it’s very hard to make rigorous statements about which features “become dead” or “are born” through the finetuning process using only a single SAE. I imagine it would be possible to do with two different SAEs trained on the base and finetuned models separately, and then studying which of the features they learned are "the same features" using Anthropic’s feature-universality techniques (like the activation similarity or the logits similarity that we used).

This would avoid the uncertainties I mentioned by using model-independent feature representations such as activation vectors. For example, if you find a feature X in the finetuned model (from the SAE trained on the finetuned model) for which there is no feature with high activation similarity in the base model (from the SAE trained on the base model), then it would indeed mean that finetuning gave rise to a new feature that isn’t recognized by the base model, assuming that the SAE trained on the base model has recovered all the features the base model recognizes... The latter is a pretty strong assumption to make, but perhaps with the development of better SAE training techniques/architectures it will become less bold.
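
As a rough sketch of what this two-SAE comparison could look like (assuming we've already computed the feature activations of both SAEs on a shared token stream; names like `ft_feature_acts` are hypothetical placeholders):

```python
import torch

def max_activation_similarity(ft_feature_acts: torch.Tensor,
                              base_feature_acts: torch.Tensor) -> torch.Tensor:
    """For each feature of the finetuned-model SAE, the highest Pearson correlation it
    achieves with any feature of the base-model SAE over the shared tokens."""
    def standardize(x: torch.Tensor) -> torch.Tensor:
        return (x - x.mean(dim=0)) / (x.std(dim=0) + 1e-8)
    ft = standardize(ft_feature_acts)       # (n_tokens, n_ft_features)
    base = standardize(base_feature_acts)   # (n_tokens, n_base_features)
    corr = ft.T @ base / ft.shape[0]        # (n_ft_features, n_base_features) correlations
    return corr.max(dim=1).values

# Finetuned-SAE features whose best match in the base SAE is weak are candidates for
# features "born" by finetuning (modulo the strong assumption discussed above), e.g.:
# candidate_idx = torch.nonzero(
#     max_activation_similarity(ft_feature_acts, base_feature_acts) < 0.3)
```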

P.S. We've now made the code repo public, thanks for pointing that out!