This is an interim research report on role embeddings, an approach to make language models more robust to many-shot jailbreaks and prompt injections by adding role information at every token position in the context rather than just at special token delimiters. We credit Cem Anil for originally proposing this idea.

In our initial experiments on Llama 3, we find that role embeddings mitigate many-shot jailbreaks more effectively than fine-tuning alone without degrading general model capabilities, suggesting that this technique may be a viable way to increase LLM robustness. However, more work needs to be done to find the optimal set of hyperparameters and to fully understand any side effects of our proposed approach.

Background on prompt formats

By default, chat LLMs are trained (during instruction fine-tuning and RLHF) using a particular prompt format that distinguishes different message "roles". Almost all chat LLMs accept some version of system, user, and assistant. A separate role may also be used to indicate tool outputs for tool-use enabled models.

The prompt format plays an important role in LLM post-training. The model learns to interpret text from different roles differently. In particular:

  • Content marked as user or tool is usually off-policy, generated by some process that does not adhere to the same limitations or follow the same distribution as the model itself. The model will learn that this content is untrusted and may contain harmful requests, rude words, typos, errors, etc.
  • Content marked as system is usually authoritative. The model will rarely see a system prompt instructing it to do something bad. Supervised learning (SL) data or high-reward conversations during RL will demonstrate the model adhering correctly to instructions given in system prompts.
  • Content marked as assistant is usually on-policy, demonstrating the model following user instructions while simultaneously adhering to certain constraints around harmful outputs.

(There is also the related concept of data-instruction separation—an LLM should be able to tell which part of its context is "data" it should operate on but not necessarily follow, and which part of its context contains the actual "instructions". The concept of roles discussed in this post can apply similarly in this situation, where a "role" could distinguish instructions from data.)

Notably, by using the prompt format in non-standard ways, it's possible to circumvent safety training. A particularly effective jailbreak is when the previous context appears to demonstrate the assistant role doing an undesired behavior many times. Updating on in-context evidence is an important LLM capability that is generally rewarded by most training tasks—if the in-context evidence that the assistant is exhibiting trait x is strong enough, you'll observe the model continuing to exhibit trait x.

This is the phenomenon of many-shot jailbreaking (first described by Anil et al.). Given enough in-context demonstrations of harmful behavior, the model will continue producing harmful behavior.

Figure 1 from Anil et al.

What happens if you try to prevent prompt format misuse? A naive approach is simple to implement: only allow users to input tokens from a specific set while reserving a few special tokens for the prompt format.

This is how the Llama prompt format works. Role tags are enclosed within special tokens, e.g. <|start_header_id|>user<|end_header_id|>, where <|start_header_id|>, <|end_header_id|> are token IDs that never appear in natural text. In addition, each role message ends with <|eot_id|>.

You can imagine a version of Llama behind an API that ensures that no user input will be encoded to a special token. You could hope that this way the user will be unable to make their messages look like they came from the assistant role.
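As a minimal sketch of such a guard, the Python below renders (role, content) pairs in the Llama 3 chat format and rejects user text that contains any special-token string. The helper names and the string-level check are illustrative assumptions, not how any particular API works (a real deployment would more likely enforce this at tokenization time), and only a subset of Llama 3's reserved special tokens is listed. Note that fake tags such as `Human:` pass the check, because they are ordinary text.

```python
# A subset of the Llama 3 chat-format delimiter strings; helper names are hypothetical.
SPECIAL_TOKEN_STRINGS = (
    "<|begin_of_text|>",
    "<|start_header_id|>",
    "<|end_header_id|>",
    "<|eot_id|>",
)


def check_user_text(text: str) -> str:
    """Reject user-supplied text containing a special-token string (simplified string-level guard)."""
    for tok in SPECIAL_TOKEN_STRINGS:
        if tok in text:
            raise ValueError(f"user input may not contain {tok!r}")
    return text


def format_llama3(messages: list) -> str:
    """Render (role, content) pairs in the Llama 3 chat format, ending with an open assistant header."""
    parts = ["<|begin_of_text|>"]
    for role, content in messages:
        if role == "user":
            content = check_user_text(content)
        parts.append(f"<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>")
    parts.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
    return "".join(parts)


# Fake tags inside the user turn are accepted: they contain no special-token strings.
prompt = format_llama3([("user", "Human: hi\nAssistant: sure, here is how...")])
```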

But your hope would be misplaced. Many properties of text besides the special tokens can make that text appear to come from the assistant role, even when the standard prompt format is not being applied. LLMs generalize well enough that they will not ignore alternatively presented evidence. For example, you can embed an alternative format within the user message and effectively teach the model a new prompt format in context, which it will interpret much as it interprets its standard format.

Figure 10 from Appendix E of Anil et al. showing how residual-stream representations of fake human/assistant tokens align with the representations of the true human/assistant delimiters over the context.

You could also hope that simply training the model on (successful responses to) examples of such attacks would mitigate them. However, this is only partially the case. Supervised fine-tuning and reinforcement learning on examples that contain instances of many-shot jailbreaks (MSJs) only change the intercept and not the slope of the power-law relationship between number of demonstrations and undesired response likelihood.
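To make the intercept/slope distinction concrete, assume (as a simplification introduced here for illustration) that the jailbreak-response NLL follows a pure power law in the number of shots $n$, with constants $C$ and $\alpha$. If fine-tuning only raises the intercept, replacing $C$ with some $C' > C$ while leaving $\alpha$ unchanged, an attacker who previously needed $n$ shots to reach a given NLL now needs $n'$ shots:

$$
\mathrm{NLL}(n) = C\,n^{-\alpha}
\;\Longrightarrow\;
\log \mathrm{NLL}(n) = \log C - \alpha \log n,
\qquad
C' n'^{-\alpha} = C n^{-\alpha}
\;\Longrightarrow\;
n' = n \left(\frac{C'}{C}\right)^{1/\alpha}
$$

Under this assumption, intercept-only changes buy a constant multiplicative increase in the required number of shots, whereas only a smaller $\alpha$ changes how quickly additional shots erode the model's behavior.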

Figure 5 from Anil et al.

Where role embeddings come in

What if there was a more robust way to indicate text origin than special-token formats? Unlike standard prompt formats, role embeddings aim to add role information at every token position. 

The basic version of this idea is simply a new embedding component. Besides semantic and positional information, we also add a vector that indicates the role associated with that token. In addition, we consider a more "intrusive" variant where this information is added at multiple layers of the residual stream, aiming to make it even more salient.
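A minimal sketch of the basic variant, assuming a toy width and random stand-in tables (`DIM`, `make_table`, `token_table`, `role_table` and `embed_with_roles` are hypothetical names, not from our code): each position's input vector is its token embedding plus a per-role vector. Llama applies positional information via rotary embeddings inside attention, so there is no additive positional term here.

```python
import random

DIM = 8  # toy model width (assumption)
ROLE_IDS = {"system": 0, "user": 1, "assistant": 2}


def make_table(rows: int, dim: int, seed: int) -> list:
    """Random stand-in for a trained embedding table (rows x dim)."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(rows)]


token_table = make_table(100, DIM, seed=0)  # toy 100-token vocabulary
role_table = make_table(len(ROLE_IDS), DIM, seed=1)  # one vector per role


def embed_with_roles(token_ids: list, roles: list) -> list:
    """Return token embedding + role embedding at every position."""
    if len(token_ids) != len(roles):
        raise ValueError("need exactly one role label per token")
    out = []
    for tok, role in zip(token_ids, roles):
        t = token_table[tok]
        r = role_table[ROLE_IDS[role]]
        out.append([a + b for a, b in zip(t, r)])
    return out
```

The same token therefore gets a different input vector depending on which role produced it, regardless of what surrounding text claims.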

We will refer to this vector addition process as "coloring"[1] in the sense of "coloring in the tokens to indicate what role they come from". This is meant to distinguish this technique from activation steering, where the intervention vector is selected from a rich space of linear semantic representations. For role embeddings, we instead use a simple and small discrete set of (usually orthogonal) "coloring" vectors that the model is trained to interpret as role signal. 

Our experiments

We focus on the many-shot jailbreak attack testbed. We are particularly interested in whether an intervention can flatten the power-law slope, because standard fine-tuning approaches have not been able to achieve this.

Datasets

Our dataset consists of:

Many-shot jailbreaks

  • User turn contains an “embedded conversation” with randomly varying human and assistant tags (that differ from the special tokens used in the “true format”). In the embedded conversation, the assistant is shown doing an undesired behavior.
  • The undesired behavior comes in two variants:
    • Giving answers to harmful questions
    • Insulting the user
  • A “jailbreak” MSJ demonstration consists of the jailbreak attempt in the user turn, followed by an assistant turn that continues the pattern (answers harmfully or insults the user).
  • A “recovery” MSJ demonstration consists of the attempt in the user turn, followed by an assistant turn that does not continue the pattern but instead follows the correct policy (refuses to answer the harmful question or answers without any insults).
  • We only train on “recovery” MSJs, and have a separate set of both recovery and jailbreak MSJs for evaluation.
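The structure of an MSJ example can be sketched as follows. This is an illustration under assumptions, not our data pipeline: `FAKE_TAGS` and `build_msj` are hypothetical, and we assume one fake tag pair is sampled per example. Whether the final assistant turn is a jailbreak or a recovery depends only on the `final_answer` passed in.

```python
import random

# Hypothetical fake tag pairs; they deliberately differ from the true special tokens.
FAKE_TAGS = [("Human:", "Assistant:"), ("User:", "AI:"), ("Q:", "A:"), ("### Person", "### Bot")]


def build_msj(demos, final_question, final_answer, rng):
    """Return (user_turn, assistant_turn) for one MSJ example.

    demos: (question, undesired_answer) pairs shown in the embedded conversation.
    final_answer: a recovery (correct policy) or a jailbreak continuation.
    """
    human_tag, assistant_tag = rng.choice(FAKE_TAGS)  # one tag pair per example (assumption)
    lines = []
    for q, a in demos:
        lines.append(f"{human_tag} {q}")
        lines.append(f"{assistant_tag} {a}")
    lines.append(f"{human_tag} {final_question}")
    return "\n".join(lines), final_answer
```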

Harmless conversations

  • Back-and-forths with regular content that form a coherent conversation
  • Back-and-forths about a scientific topic that form a coherent conversation
  • Numerical sequence prediction tasks
Example of tokenized harmful MSJ with recovery. The “true” assistant turn is shown in red. Note that within the user’s turn, fake tags are used for the embedded user and assistant turns.
Example of regular back-and-forth conversation. The true assistant turns are shown in red.
Example of numerical sequence prediction task.

Intervention

  • Instead of just using special token delimiters around the true human and assistant turns, we also mark the human and assistant tokens by either modifying the embedding or residual stream (we test both) by adding a constant vector (and optionally projecting out a different constant vector).
  • We (LoRA) fine-tune the model with this intervention on both MSJ recoveries and regular conversations.
  • The total fine-tuning dataset size is ~2000 examples. One LoRA fine-tuning run takes ~15 minutes on an 80GB A100 GPU (~$1.5/hr on vast.ai), including 5 validation eval runs.
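Marking tokens requires a per-token role label. A minimal sketch, assuming whitespace splitting as a stand-in tokenizer and assuming that header and end-of-turn delimiter tokens are left uncolored (`role_mask` and `binary_mask` are hypothetical helpers): a fake `Assistant:` tag embedded in the user turn is still labeled `user`, which is the signal the intervention relies on.

```python
def role_mask(messages, tokenize=str.split):
    """Per-token role labels for a chat; delimiter tokens get None (left uncolored).

    messages: list of (role, content). tokenize is a stand-in for a real tokenizer.
    """
    tokens, roles = [], []
    for role, content in messages:
        header = ["<|start_header_id|>", role, "<|end_header_id|>"]
        tokens += header
        roles += [None] * len(header)
        body = tokenize(content)
        tokens += body
        roles += [role] * len(body)  # every content token inherits its turn's true role
        tokens.append("<|eot_id|>")
        roles.append(None)
    return tokens, roles


def binary_mask(roles, target):
    """1.0 at positions whose true role is `target`, else 0.0."""
    return [1.0 if r == target else 0.0 for r in roles]
```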

We assess:

  • NLL of final assistant response vs. number of demonstrations ("shots") in MSJ prompt
    • Both on successful jailbreaks and recoveries
    • We want the slope of jailbreak response NLL vs. n shots in log-log space to become less steep
    • We want absolute NLLs for jailbreak responses to go up (become less likely)
    • We want absolute NLLs for recovery responses to go down (become more likely)
  • NLL of final assistant response vs. n shots in harmless conversations
    • We want this to stay roughly the same
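The log-log slope can be estimated with an ordinary least-squares fit of log(NLL) against log(n shots). The sketch below (`loglog_fit` is a hypothetical helper) shows the computation; "less steep" means a fitted slope closer to zero. The example data are synthetic, not results from our experiments.

```python
import math


def loglog_fit(shots, nlls):
    """Least-squares fit of log(NLL) = intercept + slope * log(n); returns (slope, intercept)."""
    xs = [math.log(n) for n in shots]
    ys = [math.log(v) for v in nlls]
    mx = sum(xs) / len(xs)
    my = sum(ys) / len(ys)
    sxx = sum((x - mx) ** 2 for x in xs)
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    slope = sxy / sxx
    return slope, my - slope * mx


# Synthetic example: NLL = 3 * n^-0.5 should give slope -0.5 and intercept log(3).
shots = [1, 2, 4, 8, 16, 32]
slope, intercept = loglog_fit(shots, [3.0 * n ** -0.5 for n in shots])
```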

Baseline

Like in Anil et al., we see a roughly linear trend in log-log space between number of MSJ shots and NLL of jailbreak response. The NLL of the recovery responses stays roughly constant.

Inference-time only: adding a role vector to token embeddings

We add a “user” vector to the token embeddings at every user token and an “assistant” vector at every assistant token. The magnitude of the added vector is scaled to be proportional to the embedding norm at that token position (this scale factor is a hyperparameter).

As an initial attempt, we try scale factor = 1, user vector = embedding(“user”), and assistant vector = embedding(“assistant”). By embedding() here we mean the literal embedding matrix entry for that token.
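One way to read "proportional to the embedding norm" is that the added vector has norm equal to the scale factor times the token embedding's norm. A minimal sketch under that assumption (`color_embeddings` is a hypothetical helper; in the setup above, `role_vectors` would hold the embedding-matrix rows for "user" and "assistant"):

```python
import math


def norm(v):
    return math.sqrt(sum(x * x for x in v))


def color_embeddings(embeddings, roles, role_vectors, scale=1.0):
    """Add scale * ||e_i|| * unit(role_vectors[role]) at every position that has a role."""
    out = []
    for e, role in zip(embeddings, roles):
        if role is None or role not in role_vectors:
            out.append(list(e))  # uncolored position (e.g. a delimiter token)
            continue
        v = role_vectors[role]
        k = scale * norm(e) / norm(v)  # added vector has norm scale * ||e||
        out.append([a + k * b for a, b in zip(e, v)])
    return out
```

With scale = 1, the role signal is as large as the token embedding itself, which is a sizeable perturbation for a model that was never trained with it.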

These are the harmful and mean MSJ jailbreak slopes before and after the intervention without any fine-tuning:

Regular conversations and MSJ recoveries:

Using fine-tuning to improve performance

As we can see above, the interventions:

  • Reduce the slope for MSJ jailbreak responses (good)
  • Increase NLLs on MSJ recovery responses (bad)
  • Don’t affect NLLs on regular conversations too much (good)

Next, we try fine-tuning (with LoRA) on the training set under the coloring intervention, and then repeat the evals above. As a control, we also try fine-tuning on the same training data without the coloring intervention.

Key fine-tuning results

We find that given fine-tuning, we can preserve the benefits of the pure inference-time intervention without incurring any of the costs. 

Both embedding coloring and residual stream coloring help flatten the MSJ power law more than control fine-tuning. Residual stream coloring is more effective than embedding coloring.

Residual-stream coloring

Intervention:

  • adding embed(“user”) to the residual stream at every layer (besides the final layer, which we skip) and every user-written token position (with a relative multiplier of 0.5) while projecting out embed(“assistant”), and
  • adding embed(“assistant”) to the residual stream at every layer and every assistant-written token position while projecting out embed(“user”).

Mathematically, for each token position $i$:

$$
h_i' = h_i + m_i\left(\alpha\,\lVert h_i\rVert\,\hat{v} - \frac{h_i^\top u}{\lVert u\rVert^2}\,u\right)
$$

Where:

  • $m_i$ is a mask that is one at every target position (assistant messages for assistant-coloring, or user messages for user-coloring) and zero elsewhere
  • $h_i$ is the residual-stream vector at token position $i$
  • $\hat{v}$ is the unit-length coloring vector (in our test case the embedding vector of "assistant" for assistant coloring, or the embedding vector of "user" for user coloring)
  • $u$ is the vector we project out (this is optional and probably not important, but we project out any component in the direction of the user vector at assistant tokens and vice versa)
  • $\alpha$ is the scale factor applied to the coloring vector, relative to the residual-stream norm at that position

This intervention successfully reduces the MSJ slope (and raises the absolute NLL values, as expected). In contrast, control fine-tuning sometimes makes the MSJ performance worse (in the case of the mean MSJs[2]).

By including regular training data, we are able to preserve performance compared to the baseline. In fact, NLLs actually go down on harmless responses (albeit less than with the control fine-tuning), which can be explained by fitting to the idiosyncrasies of the fine-tuning data distribution. However, for the numerical sequence prediction task, we see worse performance compared to the control FT.
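The residual-stream coloring described above can be sketched in pure Python. At each masked position the sketch adds α times the residual norm times the unit coloring vector and removes the component along the projected-out direction. The helper names, and the choice to color the input of every layer except the last, are assumptions made for illustration; in a real model this would be a forward hook on each decoder layer.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def unit(v):
    n = math.sqrt(dot(v, v))
    return [x / n for x in v]


def color_residual(h, mask, v, u, alpha):
    """h'_i = h_i + m_i * (alpha * ||h_i|| * v_hat - (h_i . u_hat) * u_hat), per position i.

    h: residual vectors, mask: m_i in {0, 1}, v: coloring direction,
    u: direction to project out, alpha: scale relative to ||h_i||.
    """
    v_hat, u_hat = unit(v), unit(u)
    out = []
    for h_i, m_i in zip(h, mask):
        # In PyTorch one might .detach() this norm so the model finds it harder
        # to adapt to compensate for the coloring.
        h_norm = math.sqrt(dot(h_i, h_i))
        proj = dot(h_i, u_hat)
        out.append([x + m_i * (alpha * h_norm * a - proj * b)
                    for x, a, b in zip(h_i, v_hat, u_hat)])
    return out


def forward_with_coloring(h, layers, mask, v, u, alpha):
    """Color the input to every layer except the last (hook placement is an assumption)."""
    for idx, layer in enumerate(layers):
        if idx < len(layers) - 1:
            h = color_residual(h, mask, v, u, alpha)
        h = layer(h)
    return h
```

Because user and assistant masks are disjoint, user coloring (α = 0.5, projecting out the assistant vector) and assistant coloring (projecting out the user vector) can be applied as two successive calls.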

Embedding coloring

Intervention:

  • adding embed(“user”) to every user-written token embedding, and
  • adding embed(“assistant”) to every assistant-written token embedding
  • both with a relative multiplier of 0.8

Embedding-only coloring is less effective than the residual-stream intervention, but is also able to reduce the slopes somewhat:

However, it also has less of an effect on the harmless numerical sequence prediction task:

As expected, NLLs on recovery responses go down:

(For both role-embedding interventions, we also qualitatively assess free-text samples from the model and don't find a degradation in general quality.)

Next steps

Although our implementation has some undesired side effects (the NLL slopes for the numerical sequence prediction task are also flatter compared to the control fine-tuning baseline), we think this could be because we're only introducing the intervention after the bulk of post-training is over. With a small dataset of ~2000 samples and rank-8 LoRA fine-tuning, we are using far less compute than Llama's full post-training run, so it's hard to achieve perfect generalization. In production, we would propose using role embeddings from the start of instruction fine-tuning, so that the model learns to process the role vectors from the beginning, plausibly resulting in better generalization across tasks.

We plan to test our interventions on a broader range of jailbreak and general-capability evaluations and perform more comprehensive hyperparameter sweeps to determine what variant of role embeddings has the best cost/benefit trade-off. We hope embedding-only coloring can be improved via some tweaks to get it closer to the effect we're seeing with the residual-stream coloring.

There are a number of subtle implementation details when testing variants of role embeddings, many of which make some difference to the results (although we consistently observe the directional effect that role embeddings increase robustness to MSJs). These include decisions such as:

  • Whether or not to project out the role embeddings that are not present (as we do in the residual-stream coloring case)
  • What layers to intervene on
  • Whether to block gradient propagation from the activations in the computation of projections (projecting out a constant vector is a function of the residual stream norm—one option is to call detach() before using the norm so that the model finds it slightly harder to adapt to compensate for the coloring)
    • Why is this a consideration? One "training story" you could tell about fine-tuning with role embeddings is that there is a local minimum where the model discards the information in the role embeddings (e.g. by immediately subtracting them out) to revert to the pre-intervention performance (at the start of training the losses are higher than baseline due to the unexpected perturbation—you could make progress by just reverting the intervention however possible). However, the global minimum actually does make use of the role embedding information (because this information helps with avoiding sneaky prompt-injection/misleading formatting-style attacks such as the MSJ examples where the user has embedded an alternative format in their message). A successful implementation of role embeddings would bias the training process away from this local minimum.
  • What fine-tuning datamix to use

Acknowledgements

This research was performed as part of the SPAR program. The main ideas were initially proposed by Cem Anil.

  1. Credit to Cem Anil for suggesting this term.

  2. We think this is because a significant proportion of the training data is generated by Claude, causing the model to fit to Claude's writing style. The mean MSJ responses are also generated by Claude, so they probably share some surface-level characteristics.
