This is extremely useful for SAE circuit work. Now the connections between features are at most ReLU(Wx + b) which is quite interpretable! (Excluding attn_in->attn_out)
Thanks for doing this!
Another reason why layernorm is weird (and a shameless plug): the final layernorm also contributes to self-repair in language models
Here's a quick snippet to load the model into TransformerLens!
import torch
from transformers import GPT2LMHeadModel
from transformer_lens import HookedTransformer

model = GPT2LMHeadModel.from_pretrained("apollo-research/gpt2_noLN").to("cpu")
hooked_model = HookedTransformer.from_pretrained("gpt2", hf_model=model, fold_ln=False, center_unembed=False).to("cpu")

# Kill the LayerNorms because TransformerLens overwrites eps
for block in hooked_model.blocks:
    block.ln1.eps = 1e12
    block.ln2.eps = 1e12
hooked_model.ln_final.eps = 1e12

# Make sure the outputs are the same
prompt = torch.tensor([1,2,3,4], device="cpu")
logits = hooked_model(prompt)
logits2 = model(prompt).logits
print(logits.shape, logits2.shape)
print(logits[0, 0, :10])
print(logits2[0, :10])
And here's the code to do it by replacing the LayerNorms with identities completely:
import torch
from transformers import GPT2LMHeadModel
from transformer_lens import HookedTransformer

model = GPT2LMHeadModel.from_pretrained("apollo-research/gpt2_noLN").to("cpu")

# Undo my hacky LayerNorm removal
for block in model.transformer.h:
    block.ln_1.weight.data = block.ln_1.weight.data / 1e6
    block.ln_1.eps = 1e-5
    block.ln_2.weight.data = block.ln_2.weight.data / 1e6
    block.ln_2.eps = 1e-5
model.transformer.ln_f.weight.data = model.transformer.ln_f.weight.data / 1e6
model.transformer.ln_f.eps = 1e-5

# Properly replace LayerNorms by Identities
class HookedTransformerNoLN(HookedTransformer):
    def removeLN(self):
        for i in range(len(self.blocks)):
            self.blocks[i].ln1 = torch.nn.Identity()
            self.blocks[i].ln2 = torch.nn.Identity()
        self.ln_final = torch.nn.Identity()

hooked_model = HookedTransformerNoLN.from_pretrained("gpt2", hf_model=model, fold_ln=True, center_unembed=False).to("cpu")
hooked_model.removeLN()
hooked_model.cfg.normalization_type = None

prompt = torch.tensor([1,2,3,4], device="cpu")
logits = hooked_model(prompt)
print(logits.shape)
print(logits[0, 0, :10])
And here's the code to convert it to NNsight (Thanks Caden for writing this a while ago!)
import torch
from transformers import GPT2LMHeadModel
from transformer_lens import HookedTransformer
from nnsight.models.UnifiedTransformer import UnifiedTransformer

model = GPT2LMHeadModel.from_pretrained("apollo-research/gpt2_noLN").to("cpu")

# Undo my hacky LayerNorm removal
for block in model.transformer.h:
    block.ln_1.weight.data = block.ln_1.weight.data / 1e6
    block.ln_1.eps = 1e-5
    block.ln_2.weight.data = block.ln_2.weight.data / 1e6
    block.ln_2.eps = 1e-5
model.transformer.ln_f.weight.data = model.transformer.ln_f.weight.data / 1e6
model.transformer.ln_f.eps = 1e-5

# Properly replace LayerNorms by Identities
def removeLN(transformer_lens_model):
    for i in range(len(transformer_lens_model.blocks)):
        transformer_lens_model.blocks[i].ln1 = torch.nn.Identity()
        transformer_lens_model.blocks[i].ln2 = torch.nn.Identity()
    transformer_lens_model.ln_final = torch.nn.Identity()

hooked_model = HookedTransformer.from_pretrained("gpt2", hf_model=model, fold_ln=True, center_unembed=False).to("cpu")
removeLN(hooked_model)
model_nnsight = UnifiedTransformer(model="gpt2", hf_model=model, fold_ln=True, center_unembed=False).to("cpu")
removeLN(model_nnsight)

# Both models live on the CPU, so the prompt must be on the CPU too
prompt = torch.tensor([1,2,3,4], device="cpu")
with torch.no_grad(), model_nnsight.trace(prompt) as runner:
    logits2 = model_nnsight.unembed.output.save()
logits, cache = hooked_model.run_with_cache(prompt)
print(torch.allclose(logits, logits2))
Fascinating. I would love to see follow-up work on whether it harms generalisation, because if we were able to train more interpretable models without damaging generalisation, that would be amazing.
I'd love to see other research along these lines. Like what if we could use interpretability to figure out what a circuit does, replace the circuit with something more principled/transparent, then train for a bit longer with the new circuit in place.
This is great! Maybe you'd get better results if you "distill" GPT2-LN into GPT2-noLN by fine-tuning on the entire token probability distribution on OpenWebText.
Interesting stuff! I'm very curious as to whether removing layer norm damages the model in some measurable way.
One thing that comes to mind is that previous work finds that the final LN is responsible for mediating 'confidence' through 'entropy neurons'; if you've trained sufficiently, I would expect none of these neurons to be present anymore, which then raises the question of whether the model still exhibits this kind of self-confidence regulation.
This work was produced at Apollo Research, based on initial research done at MATS.
Edit: arXiv version available at https://arxiv.org/abs/2409.13710
LayerNorm is annoying for mechanistic interpretability research (“[...] reason #78 for why interpretability researchers hate LayerNorm” – Anthropic, 2023).
Here’s a Hugging Face link to a GPT2-small model without any LayerNorm (apollo-research/gpt2_noLN).
The final model is only slightly worse than a GPT2 with LayerNorm[1]:
I fine-tuned GPT2-small on OpenWebText while slowly removing its LayerNorm layers, waiting for the loss to go back down after each removal:
Introduction
LayerNorm (LN) is a component in Transformer models that normalizes embedding vectors to have constant length; specifically it divides the embeddings by their standard deviation taken over the hidden dimension. It was originally introduced to stabilize and speed up training of models (as a replacement for batch normalization). It is active during training and inference.
$$\mathrm{LN}(x) = \frac{x - \mathbb{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta$$
The equation includes the standard deviation (std) $\sqrt{\mathrm{Var}[x] + \epsilon}$, which makes it a non-linear operation. This hinders interpretability in a variety of ways, from annoyances and inaccuracies such as
attributing residual stream directions to logit effects (e.g. SAE features, direct logit attribution),[2]
being annoying to deal with in Attribution Patching, or
In the Docstring circuit analysis we seriously considered whether the model might be using LN in its algorithm. This post even shows that LN can be used as the sole non-linearity to solve non-linear classification problems (see also this related work).
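To make the non-linearity concrete, here is a minimal PyTorch sketch (the helper `ln` is hypothetical and drops γ and β) of the LN equation: rescaling the input leaves the output (almost) unchanged, while the output for a sum of inputs is not the sum of the outputs, so LN cannot be a linear map.

```python
import torch
import torch.nn.functional as F


def ln(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # LN without gamma/beta: (x - E[x]) / sqrt(Var[x] + eps) over the last dim
    return F.layer_norm(x, x.shape[-1:], eps=eps)


torch.manual_seed(0)
x, y = torch.randn(768), torch.randn(768)

# Scale invariance: LN(3x) ≈ LN(x), up to the tiny effect of eps
scale_invariant = torch.allclose(ln(3 * x), ln(x), atol=1e-4)
# No additivity: LN(x + y) differs from LN(x) + LN(y), so LN is not linear
additive = torch.allclose(ln(x + y), ln(x) + ln(y), atol=1e-4)
print(scale_invariant, additive)  # True False
```

The division by the input-dependent std is what breaks additivity; the centering and the γ, β affine part are linear on their own.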
Recently, with progress in Sparse Dictionary Learning, agendas (e.g. this one) imagine decomposing networks into sets of sparsely connected components (SAEs, Transcoders, etc.). A core difficulty in “putting it all together” is that the interactions between different components often route through LayerNorm, whose effect we do not understand.
Motivation
It would be pretty neat to have an LLM that still works (speaks English etc.) with fewer or no LN layers. One option would be to train a model without LN from scratch (done for tiny models, e.g. TinyModel), but this is very hard or impossible for larger models (hearsay is that you need a low learning rate and to be very careful).
Taking an existing model and removing the LN layers however seems doable if LN isn’t implementing some important computation.[3] That is, LN “does its thing” and the model has learned to “deal with it”, but it’s not irreplaceable. A reason to be optimistic is that the spread of standard deviations across different samples isn’t that large, so maybe replacing the LN-computed standard deviation with a fixed number might kinda work.
Method
I take GPT2-small, fine-tune it on OpenWebText, and remove LNs one-by-one while fine-tuning.
The only non-linear operation in an LN layer is the division by the standard deviation (std) of the embedding vectors; the remaining operations can be absorbed into later weight matrices (see the fold_ln option in TransformerLens; also discussed in this appendix). Thus I mainly focus on the std part here.

My general strategy is to “remove” an LN layer (this makes the loss go up), and then to train the model for some time (on the original training data) until the loss is back near the baseline. For this “remove” step I do the following:
EOT tokens: When the input is an end-of-text token I also use the position 0 average std (because norms on EOT tokens tend to be larger)[4]
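The two average stds used by the “remove” step could be calibrated roughly as follows. This is a hypothetical helper, not the post's code: it assumes you have already collected the inputs to one LN layer (e.g. the residual stream entering ln_2) as a `[batch, pos, d_model]` tensor, and it computes each token's std the way LayerNorm does (biased variance over the hidden dimension, plus ε).

```python
from typing import Tuple

import torch


def average_stds(resid: torch.Tensor, eps: float = 1e-5) -> Tuple[float, float]:
    """Return (average_std, bos_std) for LN inputs of shape [batch, pos, d_model].

    Hypothetical helper: per-token std as LayerNorm computes it.
    """
    std = (resid.var(dim=-1, unbiased=False) + eps).sqrt()  # [batch, pos]
    average_std = std[:, 1:].mean()  # all tokens except position 0
    bos_std = std[:, 0].mean()       # position 0 only
    return average_std.item(), bos_std.item()


# Toy check: every token has std 1, except position 0, which has std 5
resid = torch.tensor([1.0, -1.0, 1.0, -1.0]).repeat(2, 3, 1)  # [2, 3, 4]
resid[:, 0] *= 5
avg, bos = average_stds(resid)
print(avg, bos)
```

On real data you would accumulate these averages over many batches of the training set rather than a single one.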
Whenever I do the replacement the loss jumps up, from the baseline of 3.0 up to 3.5, sometimes even around 5.0. After 10-100 iterations (learning rate 6e-4 and batch size approx. 488 as recommended here) the loss typically goes down to between 3.0 and 3.1. However, if I’m not careful and change too much at once, the loss can jump very high (around 8.0), and in those cases it usually never recovers. Thus I want to avoid making too big of a change at once.
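For a sense of scale, here is the token arithmetic behind these numbers. It assumes every training row is a full 1024-token GPT2 context (an assumption; the block size is not stated here). Under that assumption 488 rows per iteration is roughly 0.5M tokens, and the 1000 fine-tuning iterations used for the comparisons come to about 488k rows and 500M tokens.

```python
BLOCK_SIZE = 1024   # assumed tokens per row (GPT2 context length)
BATCH_SIZE = 488    # rows per iteration, from the text
ITERATIONS = 1000   # total fine-tuning iterations

tokens_per_iter = BATCH_SIZE * BLOCK_SIZE
total_rows = BATCH_SIZE * ITERATIONS
total_tokens = tokens_per_iter * ITERATIONS
print(tokens_per_iter, total_rows, total_tokens)  # 499712 488000 499712000
```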
Here’s the recipe I empirically found to work. After every step, train for 50-200 iterations or until the loss is close to baseline.
1. Remove ln_2 (the LN before the MLP block), replacing the std with an average value. I use the “Position 0” special treatment but not the “EOT tokens” special treatment for ln_2.
2. Remove ln_1 for the q and k vectors. As for ln_2, I only use the “Position 0” special treatment.
3. Remove ln_1 for the v vector. Use both the “Position 0” and “EOT tokens” special treatments.
4. Remove ln_f, the final LN. Again I only use the “Position 0” special treatment.

I considered scaling individual LNs down slowly (e.g. interpolating between the actual calculated std and the average std) but I never ended up needing this, and did not really explore it.
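One way to write the recipe down is as an ordered list of removal steps. The sketch below is hypothetical: the step names, the `Removal` record, and removing one layer at a time within each stage are assumptions for illustration. Only the stage order and which special treatments each stage uses (“Position 0” always, “EOT tokens” only for the v vector) come from the recipe.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class Removal:
    target: str             # hypothetical names: "ln_2", "ln_1_qk", "ln_1_v", "ln_f"
    layer: Optional[int]    # None for the final LN
    position0: bool = True  # "Position 0" treatment is used at every stage
    eot: bool = False       # "EOT tokens" treatment, only for the v vector


def removal_plan(n_layer: int = 12) -> List[Removal]:
    plan = [Removal("ln_2", i) for i in range(n_layer)]
    plan += [Removal("ln_1_qk", i) for i in range(n_layer)]
    plan += [Removal("ln_1_v", i, eot=True) for i in range(n_layer)]
    plan.append(Removal("ln_f", None))
    return plan


plan = removal_plan()
print(len(plan), plan[0].target, plan[-1].target)  # 37 ln_2 ln_f
```

A training loop would walk through this list and, after each removal, keep training until the loss is close to baseline again.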
In general I observed that
If I remove ln_f at the end, it causes only a small increase in loss. If I remove it at the beginning it causes a much larger loss bump.
Removing ln_2 was usually easier than removing ln_1. This, combined with the point above, is why I remove ln_2 before ln_1.
Removing ln_1 for the q and k vectors was not too bad, but the v vector was much harder. I only managed to remove its LN once I implemented the “EOT tokens” special treatment.
[…] ln_1, while in other experiments (like the screenshot at the top) the spike happens when I remove the “EOT tokens” special treatment.

Implementation
I implement everything based on the NanoGPT repository. I replace the standard deviation calculation in the LN with a fixed number (set to the average standard deviation). This number is fixed, but it is degenerate with the LN scale (self.weight), which is learnable.
I have two different average stds, self.average_std (average of std over all tokens except for position 0) and self.bos_std (average of stds at position 0). Initially, after replacing the real LN with this “dummy LN”, I use the following policy for choosing which std to divide by:
self.average_std if token position is > 0 [and token is not end-of-text (EOT)]
self.bos_std if token position is 0 [or current token is EOT]
The EOT rule in [brackets] is used only for the LN before the attention v vector. Pseudocode for a simplified version of my implementation:
I will upload the full code when I have more time next week (email me if you'd like it earlier).
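A minimal PyTorch sketch of such a “dummy LN”, written from the policy described above rather than taken from the author's code. The class name, constructor arguments, and passing token ids into `forward` are assumptions; 50256 is GPT2's `<|endoftext|>` token id.

```python
from typing import Optional

import torch
import torch.nn as nn


class DummyLayerNorm(nn.Module):
    """LayerNorm with the per-token std replaced by fixed averages (sketch)."""

    def __init__(self, n_embd: int, average_std: float, bos_std: float,
                 eot_token_id: Optional[int] = None):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(n_embd))  # learnable gamma
        self.bias = nn.Parameter(torch.zeros(n_embd))   # learnable beta
        self.register_buffer("average_std", torch.tensor(float(average_std)))
        self.register_buffer("bos_std", torch.tensor(float(bos_std)))
        # Only set for the LN in front of the attention v vector
        self.eot_token_id = eot_token_id

    def forward(self, x: torch.Tensor, tokens: Optional[torch.Tensor] = None) -> torch.Tensor:
        # x: [batch, pos, n_embd]; tokens: [batch, pos] (needed for the EOT rule)
        batch, seq, _ = x.shape
        std = self.average_std.expand(batch, seq).clone()
        std[:, 0] = self.bos_std  # "Position 0" rule
        if self.eot_token_id is not None and tokens is not None:
            # "EOT tokens" rule: EOT tokens also get the position 0 average
            std = torch.where(tokens == self.eot_token_id, self.bos_std, std)
        x = x - x.mean(dim=-1, keepdim=True)  # centering is kept; it is linear
        return x / std.unsqueeze(-1) * self.weight + self.bias
```

Because the std is now a constant per position class, it could equally be absorbed into `self.weight`; under these assumptions, dropping a special treatment later amounts to passing `eot_token_id=None` or setting `bos_std` equal to `average_std`.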
Results
Well, the models train, the loss is low, and the models still speak English.
v1 model: Manually interrupt & resume training whenever loss went down enough:
v2 model (+ vanilla GPT2 fine-tuning comparison in green) with scheduled LN removals:
I'm currently training a third version employing the gradual removal technique to more of the components, and will update the Hugging Face repository (use the v1 and v2 revisions for models presented here).
Here I compare the cross entropy losses on OpenWebText and ThePile, as well as the score on the HellaSwag benchmark (using Andrej Karpathy's implementation). I choose two baselines: the original GPT2 model straight from Hugging Face, and a second version that I fine-tuned on OpenWebText for 1000 iterations to be more comparable to the noLN ones (which were fine-tuned on that data).
Some sample generations with model.generate:
GPT2:
GPT2_noLN:
Residual stream norms
Previously I noticed that the norm of the residual stream looks very different at position 0 compared to other positions. Does this behavior still exist in the noLN model? No. The noLN model (solid lines) does not show the difference in norms we saw in the original models (dashed lines):
Discussion
Faithfulness to the original model
I expect that the new no-LayerNorm (noLN) model will not have the exact same internal mechanisms as the original model. To some extent I expect lots of similarities (as the new model only had an hour to train), but since the norms changed (see above) and the loss changed, I expect differences.
My goal with this model is more like “have a toy model almost as good as GPT2 but easier to interpret” to replace vanilla GPT2 in interpretability research. Models like GPT2 (and Pythia etc.) are useful not because they are the models we ultimately care about (GPT4, GPT5, …) but because they let us generally explore how LLM internals work. To this extent, I don’t mind if GPT2-noLN differs from GPT2.
A question for the future is whether we want to apply this LN removal method before interpreting GPT4 and other “production models”. This depends on how similar the internals are, and I am currently uncertain about this. I am primarily concerned with the former use case.
Does the noLN model generalize worse?
I noticed that the noLN performance hit was worse on The Pile than OpenWebText. This might be a coincidence, but it could also suggest that removing LN hurts generalization. While LN was originally introduced for training stability purposes, it may have a side effect on generalization. I have not evaluated the models on more datasets and leave this question for future research. Edit 30th Aug 2024: On the other hand this makes the OpenWebText and ThePile losses of the no-LN model more similar than those of the original model.
Appendix
Representing the no-LayerNorm model in GPT2LMHeadModel
I replace LayerNorms with DummyLayerNorms that use a fixed std, rather than computing the actual variance of each sample. This is equivalent to removing the LN because the remaining LN operations can be folded into the following layers (e.g. TransformerLens mostly does this; TL does not fold in the centering operations but I do). I perform this folding for all ln_1 and ln_2 layers. Thus I obtain weights with which GPT2 can run without any LayerNorm layers (all LNs replaced by nn.Identity).
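The folding step can be checked numerically. Below is a minimal sketch with a hypothetical helper name; it assumes the HF GPT2 `Conv1D` convention `y = x @ W + b` with `W` of shape `[d_in, d_out]`, and a single fixed std for every token (i.e. once no position-0 or EOT special treatment is left). It folds centering, the fixed std, γ and β into the following linear layer, so the LN itself becomes an identity.

```python
import torch


def fold_dummy_ln(gamma, beta, W, b, fixed_std):
    """Fold (x - mean(x)) / fixed_std * gamma + beta into y = x @ W + b."""
    d = W.shape[0]
    # x @ center == x - x.mean(-1, keepdim=True)
    center = torch.eye(d, dtype=W.dtype) - torch.ones(d, d, dtype=W.dtype) / d
    W_folded = center @ torch.diag(gamma) @ W / fixed_std
    b_folded = beta @ W + b
    return W_folded, b_folded


torch.manual_seed(0)
d_in, d_out, s = 16, 8, 3.0
gamma = torch.randn(d_in, dtype=torch.float64)
beta = torch.randn(d_in, dtype=torch.float64)
W = torch.randn(d_in, d_out, dtype=torch.float64)
b = torch.randn(d_out, dtype=torch.float64)
x = torch.randn(4, d_in, dtype=torch.float64)

ln_out = (x - x.mean(-1, keepdim=True)) / s * gamma + beta  # dummy LN output
W_f, b_f = fold_dummy_ln(gamma, beta, W, b, s)
print(torch.allclose(ln_out @ W + b, x @ W_f + b_f))  # True
```

For ln_1 the same fold applies to the combined q, k, v projection, since all three read from the same LN output.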
To make the model available on Hugging Face without trust_remote_code I want to package it into the GPT2LMHeadModel class. Thus I want to “neuter” the LNs in GPT2LMHeadModel such that the model just works with my “LNs are identities”-weights. I do this by setting ln_eps (epsilon) to a very high value (1e12), and setting the ln weights (gamma) to a corresponding value (1e6). I set the biases to 0. This leaves the centering operation but this doesn’t matter as I also fold a centering operation into the following layers, thus the LNs can be removed without further changes.
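A quick numerical check of why this “neutering” works, sketched with PyTorch's functional `layer_norm` (the same formula as the `nn.LayerNorm` modules in GPT2LMHeadModel): with ε = 1e12 the denominator is ≈ 1e6 whenever Var[x] ≪ 1e12, so multiplying by γ = 1e6 leaves just the centered input. The activation scale of 50 below is an arbitrary assumption.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d = 768
x = 50 * torch.randn(4, d, dtype=torch.float64)      # arbitrary, large-ish activations
weight = torch.full((d,), 1e6, dtype=torch.float64)  # gamma = 1e6
bias = torch.zeros(d, dtype=torch.float64)           # beta = 0

out = F.layer_norm(x, (d,), weight, bias, eps=1e12)
centered = x - x.mean(dim=-1, keepdim=True)
# sqrt(Var[x] + 1e12) ≈ 1e6, so the "LN" reduces to centering
print(torch.allclose(out, centered))  # True
```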
There’s one exception to this: the final layer norm. GPT2LMHeadModel uses (a) tied embedding and unembedding weights, and (b) no unembedding bias. Thus it is impossible to fold the final LayerNorm, which includes a (diagonal) weight matrix and a bias, into the other weights here. I still “neuter” the normalizing function of the LayerNorm as above, so ln_final just represents a simple linear layer before the unembedding.
Which order to remove LayerNorms in
There are two sources of reasons that inform which LayerNorms I want to remove first
So I expect the order in which I remove the LNs to matter.
A meta-choice is whether
I went with the first option for ease of implementation, but have not tried the second option. However, the second option would seem more principled to me once we understand which order is optimal.
Which kinds of LayerNorms to remove first
I tried out a few combinations, such as first removing ln_f and then removing ln_2 and ln_1, or vice versa. I haven’t done a systematic sweep of all options, and the current method is just what felt right after a couple of tries. It seems to work well enough though.
Here’s an example of removing ln_f first. The loss reaches a very high level, and even after 400 iterations only goes down to 3.138. So it seems this is a worse choice. Note, however, that in this run I didn’t do “warm up” iterations (training for a couple of iterations with LN to reach a good loss on OpenWebText; gpt2 directly loaded from Hugging Face does badly for the first ~10 iterations).
Which layer to remove LayerNorms in first
Here I remove ln_2 in the different layers in different orders:
I start removing LNs at iteration 300, and remove another LN every 10 iterations.
The loss differs during the process (this is expected, some LNs are possibly more important than others) but evens out at the end.
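The schedule in this experiment is easy to write down. A minimal sketch, assuming 12 layers and hypothetical “forward”, “backward” and “shuffled” orders (the exact orders behind the plot are not listed in the text); it maps each iteration at which an ln_2 is removed to the layer it is removed in.

```python
import random
from typing import Dict, List


def ln2_removal_schedule(order: List[int], start_iter: int = 300, every: int = 10) -> Dict[int, int]:
    """Map iteration -> layer whose ln_2 is removed at that iteration."""
    return {start_iter + k * every: layer for k, layer in enumerate(order)}


n_layer = 12
orders = {
    "forward": list(range(n_layer)),
    "backward": list(reversed(range(n_layer))),
    "shuffled": random.Random(0).sample(range(n_layer), n_layer),
}
for name, order in orders.items():
    schedule = ln2_removal_schedule(order)
    print(name, min(schedule), max(schedule), schedule[300])
```

With 12 layers, every order finishes its last removal at iteration 410; only the intermediate losses differ.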
Data-reuse and seeds
In my initial tests I used lots of snapshots, and accidentally retrained the model on the same first couple of batches of OpenWebText (fixed seed) every time. I have the impression that this worked slightly better than my later full-pipeline runs, which never re-used data. I might investigate this in the future.
Infohazards
I am not worried that publishing this work accelerates capabilities progress over alignment progress, because (a) this is a pretty obvious idea, (b) it applies only to inference (not training), and (c) it only speeds up inference by a very small amount (that likely is not even worth the loss increase).
Acknowledgements
Thanks to Alice Rigg, Bilal Chughtai, Leo Gao, Neel Nanda, and Rudolf Laine for comments and feedback on the draft. The nanoGPT repository and accompanying video by Andrej Karpathy were very helpful, allowing me to get a working prototype in a day!
The GPT2 paper claims a loss of log(16)=2.77 on their training dataset (non-public webtext). I guess that must be an easier dataset. In any case, I fine-tune both models on OpenWebText for a total of 1000 iterations (~500k rows, ~500M tokens) to give a fairer comparison.
For this case, only the final layer norm matters
This paper discusses confidence regularization as one possibly important aspect
I wonder how much of this effect is “just divide by a larger number” vs. actually dividing by the correct average. After all, the position 0 average shouldn’t be a great match for the EOT token average. [In this dataset the first position is not an EOT token. Neel Nanda / TransformerLens recommends this for short prompts (see here for a discussion) but we don’t do it for the full dataset.]