Thanks so much for this investigation! Our paper focused mostly on the API-fine-tuning threat model (e.g. the OpenAI fine-tuning API), where the adversary can conduct black-box fine-tuning on the base model, but the defender can apply safety interventions like unlearning after that fine-tuning. Through that lens, we only examined probing and GCG in the paper; it's really useful that y'all are evaluating the shallowness of RMU's robustness against a broader set of adversaries. I believe @Fabien Roger similarly demonstrated that fine-tuning on a bit of unrelated text can recover WMDP performance.
I'm confused about whether RMU should still be classified as an unlearning method, or more generally how to classify methods as unlearning vs. robust refusal. Zou et al. recently expanded upon RMU for a more general set of harms and characterized their method as "circuit breaking," and I think this framing may be more appropriate. Thanks again for these insights.
Thanks for the nice reply!
Yes, it makes sense to consider the threat model, and your paper does a good job of making this explicit (as in Figure 2). We just wanted to prod around and see how things are working!
The way I've been thinking about refusal vs unlearning, say with respect to harmful content:
So in some way, I think of refusal as being shallow (a classifier on top, but the capability is still underneath), and unlearning / circuit breaking as being deep (trying to directly remove the capability itself).
[I don't know how this relates to the consensus interpretation of these terms, but it's how I personally have been thinking of things.]
We have been working for months on this issue and have made substantial progress on it: Tamper-Resistant Safeguards for Open-Weight LLMs
General article about it: https://www.wired.com/story/center-for-ai-safety-open-source-llm-safeguards/
This is an informal research note. It is the result of a few-day exploration into RMU through the lens of model internals. Code to reproduce the main result is available here.
This work was produced as part of Ethan Perez's stream in the ML Alignment & Theory Scholars Program - Summer 2024 Cohort. Thanks to Nina Panickssery, Mrinank Sharma, and Fabien Roger for helpful discussion.
Summary
We investigate RMU, a recent unlearning method proposed by Li et al. (2024), through the lens of model internals. We find that RMU mostly works by flooding the residual stream with "junk" in hazardous contexts, resulting in incoherence. We then propose a simple intervention to "clear the junk" from the residual stream. This intervention mostly restores the model's coherence in hazardous contexts, and recovers a significant proportion (but not all) of its original hazardous knowledge. This suggests that the effectiveness of RMU can be understood roughly in two pieces: (1) a shallow mechanism, where the residual stream is flooded with junk; and (2) a deeper mechanism, where even after the junk is cleared, knowledge is still inaccessible.
What is RMU?
Representation Misdirection for Unlearning (RMU) is a state-of-the-art unlearning method presented by Li et al. (2024).
In the unlearning paradigm, we would like the model to unlearn (or "forget") some hazardous knowledge. At the same time, we would also like to make sure the model retains non-hazardous knowledge, so that the model remains useful.
This partition of knowledge is usually specified by constructing a "forget" dataset $D_{\text{forget}}$, consisting of the hazardous knowledge to be unlearned, and a "retain" dataset $D_{\text{retain}}$, consisting of non-hazardous knowledge to be retained.
Let $M$ denote our original model. RMU specifies a method for fine-tuning $M$ on $D_{\text{forget}}$ and $D_{\text{retain}}$ in order to obtain a modified model $M'$ that satisfies the unlearning objective.
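The main idea of RMU is as follows: on hazardous (forget) data, scramble the model's activations at some layer ℓ by pushing them toward a fixed random vector; on non-hazardous (retain) data, keep the model's activations at that layer unchanged.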
These two ideas are concretely operationalized as two distinct terms in the loss during fine-tuning:
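Following Li et al. (2024), the two terms can be written roughly as follows. This is our paraphrase in this post's notation: $M'_\ell(t)$ and $M_\ell(t)$ denote the activations at the start of layer ℓ for token $t$ in the model being fine-tuned and in the frozen original model, respectively, and $\alpha$ is a weighting hyperparameter for the retain term.

```latex
\begin{aligned}
\mathcal{L}_{\text{forget}} &= \mathbb{E}_{x \sim D_{\text{forget}}}\left[\frac{1}{|x|}\sum_{t \in x}\bigl\lVert M'_{\ell}(t) - c\,u \bigr\rVert_2^2\right] \\
\mathcal{L}_{\text{retain}} &= \mathbb{E}_{x \sim D_{\text{retain}}}\left[\frac{1}{|x|}\sum_{t \in x}\bigl\lVert M'_{\ell}(t) - M_{\ell}(t) \bigr\rVert_2^2\right] \\
\mathcal{L} &= \mathcal{L}_{\text{forget}} + \alpha\,\mathcal{L}_{\text{retain}}
\end{aligned}
```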
Note that u is a random unit vector sampled before the fine-tuning procedure, and kept constant throughout (i.e. it is not freshly sampled at each training step). Also note that the layer ℓ at which to target activations, and also the scalar multiplier c are predetermined hyperparameters.
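For intuition, here is a minimal plain-Python sketch of a two-term loss of this kind on toy activation vectors: the forget term pulls layer-ℓ activations on forget tokens toward the scaled random vector c·u, and the retain term keeps layer-ℓ activations on retain tokens close to those of the frozen original model. It is not the reference implementation. The names (`rmu_loss`, `random_unit_vector`), the Gaussian sampling of u, and averaging over a flat list of tokens (rather than per sequence, then over sequences) are simplifying assumptions, and `alpha` (the retain-term weight) is left as a free parameter.

```python
import math
import random


def random_unit_vector(dim, seed=0):
    # Sampled once before fine-tuning and then held fixed. Gaussian sampling is
    # an assumption of this sketch, not necessarily what the RMU code does.
    rng = random.Random(seed)
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


def sq_dist(a, b):
    # Squared Euclidean distance between two activation vectors.
    return sum((x - y) ** 2 for x, y in zip(a, b))


def rmu_loss(forget_acts, retain_acts, retain_acts_frozen, u, c, alpha):
    """Toy RMU-style loss on per-token layer-ℓ activations (plain lists).

    forget_acts:        activations of the model being fine-tuned, on forget tokens
    retain_acts:        activations of the model being fine-tuned, on retain tokens
    retain_acts_frozen: activations of the frozen original model, same retain tokens
    """
    target = [c * x for x in u]  # the scaled random vector c * u
    forget_term = sum(sq_dist(h, target) for h in forget_acts) / len(forget_acts)
    retain_term = sum(
        sq_dist(h, h0) for h, h0 in zip(retain_acts, retain_acts_frozen)
    ) / len(retain_acts)
    return forget_term + alpha * retain_term


# Usage with arbitrary toy values (8-dimensional activations):
u = random_unit_vector(8, seed=1)
print(rmu_loss([[0.0] * 8], [[1.0] * 8], [[0.0] * 8], u, c=6.5, alpha=2.0))
```

Because u is fixed, minimizing the forget term drives every forget-token activation toward the same large vector, which is what the activation analysis below picks up.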
Examining an RMU model
The original paper (Li et al., 2024) performs RMU over multiple open-source models of varying scales. The authors made all code available on GitHub, and all resulting models available on HuggingFace.[1]
For our analysis, we pick a single model pair: zephyr-7B-beta (which we will refer to as "baseline") and Zephyr_RMU (which we will refer to as "RMU").
The RMU model has been fine-tuned to unlearn two domains of knowledge: hazardous biology knowledge, and hazardous cybersecurity knowledge.
Prompting with hazardous instructions
Prompting the RMU model with an instruction in one of these domains causes it to output gibberish, as we would expect from a model with its activations scrambled:
Looking at activations
We can take a handful of hazardous prompts, run them through the baseline and RMU models, and compare their activations. We specifically study the activations at the last token position of the prompt.
Comparing the baseline and RMU activations via cosine similarity reveals that the activations differ dramatically starting at layer 8 (this corresponds to the activations at the start of layer 8):
This makes sense, as the RMU model was trained with ℓ=8[2], i.e. the activations at the start of layer 8 are the ones used in the RMU loss term, and so these are the activations that are directly incentivized to change on hazardous data.
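A minimal sketch of this per-layer comparison, assuming the last-token activations of both models have already been cached at the start of each layer (e.g. with forward hooks, not shown). The nested-list layout `acts[layer][prompt]` and the function names are hypothetical.

```python
import math


def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def per_layer_cosine(baseline_acts, rmu_acts):
    """baseline_acts[layer][prompt] and rmu_acts[layer][prompt] are last-token
    activation vectors; returns the mean cosine similarity for each layer."""
    return [
        sum(cosine_similarity(b, r) for b, r in zip(base_layer, rmu_layer)) / len(base_layer)
        for base_layer, rmu_layer in zip(baseline_acts, rmu_acts)
    ]
```

A sharp drop in the returned list at a given index is the signature described above for layer 8.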
Visualizing the norm of the activations reveals a clear bump at layer 8:
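This suggests the following intuition for how RMU is working: in hazardous contexts, the RMU model writes a large vector into the residual stream by the start of layer 8, flooding it with "junk".
Injecting this much junk into the residual stream makes the model incoherent, so the RMU model ends up incoherent specifically in hazardous contexts, as desired.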
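A toy numerical illustration of this intuition, using made-up vectors rather than real model activations: adding an increasingly large multiple of a fixed random unit vector to a residual-stream vector inflates its norm and rotates it away from its original direction, the same kind of signature seen in the cosine-similarity and norm comparisons. The dimension, seed, and scales are arbitrary choices.

```python
import math
import random


def norm(v):
    return math.sqrt(sum(x * x for x in v))


def cosine(a, b):
    return sum(x * y for x, y in zip(a, b)) / (norm(a) * norm(b))


def flood(h, u, c):
    # Add c times the "junk" direction u to the residual-stream vector h.
    return [x + c * y for x, y in zip(h, u)]


rng = random.Random(0)
dim = 256
h = [rng.gauss(0.0, 1.0) for _ in range(dim)]      # toy "clean" activation
u_raw = [rng.gauss(0.0, 1.0) for _ in range(dim)]
u_norm = norm(u_raw)
u = [x / u_norm for x in u_raw]                     # fixed random unit "junk" direction

for c in (0.0, 10.0, 100.0):
    flooded = flood(h, u, c)
    print(f"c={c:6.1f}  norm={norm(flooded):7.1f}  cos(h, flooded)={cosine(h, flooded):.3f}")
```

Once c is large relative to the norm of h, the flooded vector is dominated by the junk direction, so downstream layers see mostly junk.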
Trying to undo RMU via directional ablation
We understand that, roughly, RMU causes the model to inject a large random vector into the residual stream in hazardous contexts, effectively flooding the residual stream with "junk", and causing the model to be incoherent.
We can try to extract the "junk" direction that is being injected into the residual stream by taking the difference between activations in the baseline model and the RMU model. To do this, we simply run each model on a set of hazardous prompts, and take the difference-in-means of the resulting activations at the last token position. This gives us an approximation of the "junk" direction; let's call it $\hat{u}$.[3]
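A minimal sketch of the difference-in-means extraction, assuming the last-token activations of both models on the same hazardous prompts have already been collected at a single layer (which layer to read from is left open here). The function names are hypothetical. The sign convention (RMU minus baseline) does not matter for ablation, since projecting out a direction and its negation is the same operation.

```python
import math


def mean_vector(vectors):
    # Element-wise mean of a list of equal-length vectors.
    n = len(vectors)
    return [sum(col) / n for col in zip(*vectors)]


def junk_direction(baseline_acts, rmu_acts):
    """Unit-normalized difference-in-means (RMU minus baseline) of last-token
    activations on hazardous prompts."""
    diff = [r - b for r, b in zip(mean_vector(rmu_acts), mean_vector(baseline_acts))]
    n = math.sqrt(sum(x * x for x in diff))
    return [x / n for x in diff]
```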
We can then try running the RMU model while preventing it from writing to this "junk" direction $\hat{u}$. This can be implemented as an inference-time intervention by projecting the direction out of all intermediate activations (as described here), or as a direct weight modification by editing weights to be orthogonal to the direction (as described here).
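Both variants can be sketched in a few lines for a single unit direction. The activation version subtracts each activation's projection onto the direction; the weight version replaces a matrix W that writes into the residual stream with (I - û ûᵀ) W, so nothing it outputs has a component along û. In a full model, the weight edit would need to be applied to every matrix that writes into the residual stream. The toy below uses plain lists and hypothetical function names, and it is a sketch of the general technique rather than the exact code used here.

```python
def ablate_direction(x, u_hat):
    # Inference-time ablation: remove the component of activation x along unit vector u_hat.
    proj = sum(a * b for a, b in zip(x, u_hat))
    return [a - proj * b for a, b in zip(x, u_hat)]


def orthogonalize_weight(W, u_hat):
    """W is a d_model x d_in matrix (list of rows) whose outputs are written to the
    residual stream. Returns (I - u_hat u_hat^T) W."""
    d_model, d_in = len(W), len(W[0])
    # u_hat^T W: how much each column of W points along u_hat.
    coeffs = [sum(u_hat[i] * W[i][j] for i in range(d_model)) for j in range(d_in)]
    return [[W[i][j] - u_hat[i] * coeffs[j] for j in range(d_in)] for i in range(d_model)]


def matvec(W, v):
    return [sum(w * x for w, x in zip(row, v)) for row in W]
```

The two variants are equivalent for that matrix: orthogonalized weights applied to an input give the same result as applying the original weights and then ablating the output, which is why the weight edit can replace a runtime hook.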
Intuitively, preventing the RMU model from writing to the "junk" direction should prevent it from flooding the residual stream, and should therefore restore the model's coherence in hazardous settings.
Directional ablation mostly restores coherence
Indeed, when we prompt the modified RMU model (labeled "RMU ablated") with hazardous instructions, we see that coherence is restored:
It is worth noting that, although the RMU ablated model is coherent, its answers seem to differ from those of the baseline model. This suggests that, even after the intervention, there is still a difference, perhaps a significant one, between the RMU ablated model and the baseline model.
Directional ablation mostly restores activations to baseline
Looking inside the model also reveals that ablating the "junk" direction from the RMU model makes its activations look more similar to those of the baseline model:
Does directional ablation recover unlearned knowledge?
As previously noted, while the RMU ablated model seems to have its coherence restored on hazardous prompts, its behavior still does not exactly match that of the baseline model.
Is this difference significant? Does ablating the "junk" direction from the RMU model suddenly recover its unlearned knowledge? Or is this intervention just restoring coherence, without recovering unlearned knowledge?
If simply ablating a single direction from the RMU model recovers hazardous knowledge, this suggests that RMU is rather shallow: the knowledge is still contained in the model, but RMU just covers up this knowledge by flooding the residual stream with junk, causing the model to be incoherent in hazardous contexts.
On the other hand, if ablating the direction does not recover hazardous knowledge, then it suggests that RMU scrubs knowledge at a deeper level: even when the residual stream is not flooded with junk, the model still cannot regain access to its hazardous knowledge.
Evaluation on WMDP benchmark
To measure how much hazardous knowledge is recovered by directional ablation, we simply evaluate the RMU ablated model on the WMDP Benchmark - the same benchmark used to evaluate the RMU model in the original paper (Li et al., 2024).
We can see that directional ablation recovers a significant fraction of the performance gap between the RMU model and the baseline model on WMDP benchmarks:
This suggests a significant fraction of RMU performance comes from shallow unlearning, and this component of performance can be easily undone by clearing the residual stream of "junk".
However, note that directional ablation does not recover all of the performance: there is still a significant performance gap between the RMU ablated model and the baseline model. This suggests that a significant fraction of RMU's performance also comes from deep unlearning: even after the residual stream is cleared of "junk", the model is still unable to recover its hazardous knowledge.[4]
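To make "fraction of the performance gap recovered" concrete, here is a small sketch of the arithmetic, assuming multiple-choice accuracy as the metric. The function names are hypothetical, and the numbers in the usage line are made up for illustration; they are not results from this post.

```python
def accuracy(predictions, answers):
    # Fraction of multiple-choice questions answered correctly.
    return sum(p == a for p, a in zip(predictions, answers)) / len(answers)


def fraction_of_gap_recovered(baseline_acc, rmu_acc, ablated_acc):
    # 0.0: ablation recovers nothing; 1.0: ablation fully restores baseline accuracy.
    gap = baseline_acc - rmu_acc
    if gap == 0:
        raise ValueError("baseline and RMU accuracies are equal; the gap is undefined")
    return (ablated_acc - rmu_acc) / gap


# Made-up numbers, purely to show the arithmetic:
print(fraction_of_gap_recovered(baseline_acc=0.60, rmu_acc=0.30, ablated_acc=0.45))
```

In these terms, "shallow" unlearning corresponds to the recovered fraction and "deep" unlearning to the remainder.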
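To summarize, RMU seems to operate at both levels: a shallow level, where the residual stream is flooded with junk in hazardous contexts (which directional ablation largely undoes), and a deep level, where hazardous knowledge remains inaccessible even after the junk is cleared.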
It seems useful to understand that RMU's effectiveness as an unlearning method, as quantified by its performance on the WMDP benchmark, comprises both shallow and deep unlearning. We hope that this work disentangles these mechanisms a bit, and sheds some light on how RMU works.
Author contributions statement
Andy led the investigation, performed all experiments, and wrote the post. Bilal suggested experiment ideas and provided feedback.
[1] Kudos to the authors for this! The availability of code and models makes it much easier for curious individuals to study the method.
[2] Well, if you read the code, it's actually ℓ=7, but they use this to mean the output of layer 7, which is of course the input to layer 8.
[3] Note that there are actually two "junk" directions, one per dataset: $\hat{u}_{\text{bio}}$ and $\hat{u}_{\text{cyber}}$. We extract these directions separately, and when we perform model interventions, we intervene on both directions simultaneously.
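Since the bio and cyber directions need not be orthogonal, projecting them out one after the other does not in general remove both. One way to ablate both simultaneously is to orthonormalize them first (Gram-Schmidt) and remove their whole span. The plain-Python sketch below, with hypothetical names, shows that approach; it may differ from the exact intervention code used here.

```python
import math


def orthonormal_basis(directions, eps=1e-10):
    # Gram-Schmidt: an orthonormal basis for the span of possibly non-orthogonal directions.
    basis = []
    for d in directions:
        v = list(d)
        for b in basis:
            proj = sum(x * y for x, y in zip(v, b))
            v = [x - proj * y for x, y in zip(v, b)]
        n = math.sqrt(sum(x * x for x in v))
        if n > eps:  # skip directions already in the span
            basis.append([x / n for x in v])
    return basis


def ablate_subspace(x, directions):
    # Remove from x every component lying in the span of the given directions.
    for b in orthonormal_basis(directions):
        proj = sum(p * q for p, q in zip(x, b))
        x = [p - proj * q for p, q in zip(x, b)]
    return x
```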
[4] We hypothesize that the "deep unlearning" part of RMU comes from a repurposing of model components. The RMU fine-tuning implementation only modifies the output directions of MLP neurons ($W_{\text{out}}$), and notably does not modify their input directions ($W_{\text{in}}$). At a high level, neurons that fire in hazardous contexts (their input direction aligns with activations in hazardous contexts) will be repurposed to write junk to the residual stream (their output directions will be bent towards the junk direction). Presumably these neurons were previously important for model performance in hazardous contexts, and so bending them effectively removes their important contributions.
Note that we ran some preliminary experiments to check this hypothesis more directly, but our results were inconclusive.
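A toy illustration of this hypothesis, with made-up numbers for a single neuron: its input side, and so its pre-activation on a hazardous prompt, is unchanged, but its output direction has been bent mostly onto the junk direction. Ablating the junk direction removes the junk, yet only a small part of the neuron's original write survives. A ReLU stands in for the model's actual (gated) MLP nonlinearity, and all names and values are hypothetical.

```python
def ablate(x, u_hat):
    # Remove the component of x along the unit vector u_hat.
    proj = sum(a * b for a, b in zip(x, u_hat))
    return [a - proj * b for a, b in zip(x, u_hat)]


def neuron_write(pre_activation, w_out):
    act = max(0.0, pre_activation)  # ReLU stand-in for the real nonlinearity
    return [act * w for w in w_out]


u_hat = [0.0, 0.0, 1.0]           # hypothetical "junk" direction
w_out_original = [1.0, 0.0, 0.0]  # hypothetical useful output direction
w_out_rmu = [0.1, 0.0, 5.0]       # hypothetical post-RMU direction, bent toward u_hat

pre = 2.0  # the neuron still fires on a hazardous prompt (W_in untouched)
original = neuron_write(pre, w_out_original)
after_rmu_ablated = ablate(neuron_write(pre, w_out_rmu), u_hat)
print(original, after_rmu_ablated)
```

Clearing the junk leaves the residual stream clean, but the useful contribution along the first coordinate has shrunk by an order of magnitude, which is the kind of loss that directional ablation alone cannot undo.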