I think this might suggest there is some fundamentally better way to do sampling from GPT models? I'm having trouble writing out the intuition clearly, so I'll leave it for later posts.
Unroll the sampling process: hook up all the individual GPT instances into a single long model, bypass the discretizing/embedding layers to make it differentiable end-to-end, and do gradient ascent to find the sequence which maximizes likelihood conditional on the fixed input.
Interesting, but not (I think?) the direction I was headed in.
I was thinking more about the way the model seems to be managing a tradeoff between preserving the representation of token i and producing the representation of token i+1.
The depth-wise continuity imposed by weight decay means late layers are representing something close to the final output -- in late layers the model is roughly looking at its own guesses, even if they were wrong, which seems suboptimal.
Consider this scenario:
My sampling idea was something like "let's replace (or interpolate) late activations with embeddings of the actual next token, so the model can see what really happened, even when its probability was low." (This is for sampling specifically because it'd be too slow in training, where you want to process a whole window at once with matrix operations; sampling has to be a loop anyway, so there's no cost to adding stuff that only works as a loop.)
But, thinking about it more, the model clearly can perform well in scenarios like the above, e.g. my plasma example and also many other cases naturally arising in language which GPT handles well.
I have no idea how it does it -- indeed the connection structure feels weirdly adverse to such operations -- but apparently it does. So it's probably premature to assume it can't do this well, and attempt to "help it out" with extra tricks.
It doesn't sound hard at all. The things Gwern is describing are the same sort of thing that people do for interpretability where they, eg, find an image that maximizes the probability of the network predicting a target class.
Of course, you need access to the model, so only OpenAI could do it for GPT-3 right now.
Doing it with GPT-3 would be quite challenging just for compute requirements like RAM. You'd want to test this out on GPT-2-117M first, definitely. If the approach works at all, it should work well for the smallest models too.
This is very neat. I definitely agree that I find the discontinuity from the first transformer block surprising. One thing which occurred to me that might be interesting to do is to try and train a linear model to reconstitute the input from the activations at different layers to get an idea of how the model is encoding the input. You could either train one linear model on data randomly sampled from different layers, or a separate linear model for each layer, and then see if there are any interesting patterns like whether the accuracy increases or decreases as you get further into the model. You could also see if the resulting matrix has any relationship to the embedding matrix (e.g. are the two matrices farther apart or closer together than would be expected by chance?). One possible hypothesis that this might let you test is whether the information about the input is being stored indirectly via what the model's guess is given that input or whether it's just being stored in parts of the embedding space that aren't very relevant to the output (if it's the latter, the linear model should put a lot of weight on basis elements that have very little weight in the embedding matrix).
One thing which occurred to me that might be interesting to do is to try and train a linear model to reconstitute the input from the activations at different layers to get an idea of how the model is encoding the input. You could either train one linear model on data randomly sampled from different layers, or a separate linear model for each layer, and then see if there are any interesting patterns like whether the accuracy increases or decreases as you get further into the model.
That's a great idea!
One possible hypothesis that this might let you test is whether the information about the input is being stored indirectly via what the model's guess is given that input or whether it's just being stored in parts of the embedding space that aren't very relevant to the output (if it's the latter, the linear model should put a lot of weight on basis elements that have very little weight in the embedding matrix).
Hmm... I guess there is some reason to think the basis elements have special meaning (as opposed to the elements of any other basis for the same space), since the layer norm step operates in this basis.
But I doubt there are actually individual components the embedding cares little about, as that seems wasteful (you want to compress 50K into 1600 as well as you possibly can), and if the embedding cares about them even a little bit then the model needs to slot in the appropriate predictive information, eventually.
Thinking out loud, I imagine there might be a pattern where embeddings of unlikely tokens (given the context) are repurposed in the middle for computation (you know they're near-impossible so you don't need to track them closely), and then smoothly subtracted out at the end. There's probably a way to check if that's happening.
That's a great idea!
Thanks! I'd be quite excited to know what you find if you end up trying it.
Hmm... I guess there is some reason to think the basis elements have special meaning (as opposed to the elements of any other basis for the same space), since the layer norm step operates in this basis.
But I doubt there are actually individual components the embedding cares little about, as that seems wasteful (you want to compress 50K into 1600 as well as you possibly can), and if the embedding cares about them even a little bit then the model needs to slot in the appropriate predictive information, eventually.
Thinking out loud, I imagine there might be a pattern where embeddings of unlikely tokens (given the context) are repurposed in the middle for computation (you know they're near-impossible so you don't need to track them closely), and then smoothly subtracted out at the end. There's probably a way to check if that's happening.
I wasn't thinking you would do this with the natural component basis—though it's probably worth trying that also—but rather doing some sort of matrix decomposition on the embedding matrix to get a basis ordered by importance (e.g. using PCA or NMF—PCA is simpler though I know NMF is what OpenAI Clarity usually uses when they're trying to extract interpretable basis elements from neural network activations) and then seeing what the linear model looks like in that basis. You could even just do something like what you're saying and find some sort of basis ordered by the frequency of the tokens that each basis element corresponds to (though I'm not sure exactly what the right way would be to generate such a basis).
I also thought of PCA/SVD, but I imagine matrix decompositions like these would be misleading here.
What matters here (I think) is not some basis of N_emb orthogonal vectors in embedding space, but some much larger set of ~exp(N_emb) almost orthogonal vectors. We only have 1600 degrees of freedom to tune, but they're continuous degrees of freedom, and this lets us express >>1600 distinct vectors in vocab space as long as we accept some small amount of reconstruction error.
I expect GPT and many other neural models are effectively working in such a space of nearly orthogonal vectors, and picking/combining elements of it. A decomposition into orthogonal vectors won't really illuminate this. I wish I knew more about this topic -- are there standard techniques?
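To make the "more nearly orthogonal vectors than dimensions" point concrete, here is a minimal sketch in plain Python. It draws random unit vectors (an assumption chosen for illustration; nothing here is taken from GPT's actual embeddings) and checks how far from orthogonal the pairs are. The sizes 128 and 200 are arbitrary and kept small so it runs quickly; the effect gets stronger as the dimension grows toward 1600.

```python
import math
import random

def unit_vector(dim, rng):
    """Draw a random direction: Gaussian coordinates, scaled to length 1."""
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def abs_cosines(n_vectors, dim, seed=0):
    """Absolute cosine similarity for every pair of random unit vectors."""
    rng = random.Random(seed)
    vecs = [unit_vector(dim, rng) for _ in range(n_vectors)]
    return [
        abs(sum(a * b for a, b in zip(vecs[i], vecs[j])))
        for i in range(n_vectors)
        for j in range(i + 1, n_vectors)
    ]

# Illustrative sizes: more vectors (200) than dimensions (128).
dim, n_vectors = 128, 200
cosines = abs_cosines(n_vectors, dim)
mean_abs_cosine = sum(cosines) / len(cosines)
max_abs_cosine = max(cosines)
print(f"mean |cos| = {mean_abs_cosine:.3f}, max |cos| = {max_abs_cosine:.3f}")
```

An exactly orthogonal set could hold at most 128 vectors here; allowing a small typical overlap lets the set be larger, which is the tradeoff described above.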
You might want to look into NMF, which, unlike PCA/SVD, doesn't aim to create an orthogonal projection. It works well for interpretability because its components cannot cancel each other out, which makes its features more intuitive to reason about. I think it is essentially what you want, although I don't think it will allow you to find directly the 'larger set of almost orthogonal vectors' you're looking for.
Maybe I am misunderstanding something, but to me it is very intuitive that there is a big jump from the embedding output to the first transformer block output. The embedding is backpropagated into so it makes sense to see all representations as representations of the prediction we are trying to make, i.e. of the next word.
But the embedding is a prediction of the next word based on only a single word, the word that is being embedded. So the prediction of the next word is by necessity very bad (the BPE ensures that, IIUC, because tokens that would always follow one another are merged).
The first transformer block integrates hundreds of words of context into the prediction, that’s where the big jump comes from.
Is it really trained to output the input offset by one, or just to have the last slot contain the next word? Because I would expect it to be better at copying the input over by one...
If each layer were trained to give its best guess at the next token, this myopia would prevent all sorts of hiding data for later. This would be a good experiment for your last story, yes? I expect this would perform very poorly, though if it doesn't, hooray, for I really don't expect that version to develop inner optimizers.
I think I understand your question and was also confused by this for a bit, so I wanted to add some points of clarification. First I want to point out that I really couldn't find a satisfactory explanation of this particular detail (at least one that I could understand), so I pieced this together myself from looking at the huggingface code for GPT2. I may get some details wrong.
During training, at each step GPT2 takes in N tokens and outputs N tokens. But the i-th output token is computed in such a way that it only relies on the information from tokens 1, ..., i and is meant to predict the (i+1)-th token from these. I think it's best to think of each output being computed independently of the others (though this isn't strictly true since the separate outputs are computed by shared matrices). So for each i, we train the network so that the i-th output produces the correct result given the _input_ tokens 1, ..., i. There is a term in the loss function for each output token and the total loss is the sum of all the losses of the output tokens. The outputs at other positions do not play a role in the i-th output token, only the first 1, ..., i input tokens do.
During inference, given an input of k tokens, we are only concerned with the k-th output token (which should predict the token following the first k). GPT-3 also produces predictions for the outputs before position k but these are just ignored since we already know what these values should be.
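A minimal sketch of the per-position loss described above, in plain Python. It takes the logits at each position as given (the causal masking that makes output i depend only on tokens 1, ..., i is not modeled here), and the names next_token_loss and logits_per_position are hypothetical. It sums the terms, following the description above; real implementations may average instead.

```python
import math

def softmax(logits):
    """Convert logits to probabilities, subtracting the max for stability."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def next_token_loss(logits_per_position, tokens):
    """Sum of cross-entropy terms: the output at position i is scored against token i+1."""
    loss = 0.0
    # The last position's target lies outside the window, so it contributes no term.
    for i in range(len(tokens) - 1):
        probs = softmax(logits_per_position[i])
        loss += -math.log(probs[tokens[i + 1]])
    return loss

# Toy example: vocab of 3 tokens, window of 3 positions, uniform predictions.
tokens = [2, 0, 1]
logits = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
loss = next_token_loss(logits, tokens)
print(loss)
```

At inference time only the last position's logits would be used, since the earlier targets are already known.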
Is it really trained to output the input offset by one, or just to have the last slot contain the next word? Because I would expect it to be better at copying the input over by one...
Not sure I understand the distinction, could you rephrase?
If by "last slot" you mean last layer (as opposed to earlier layers), that seems like the same thing as outputting the input offset by one.
If by "last slot" you mean the token N+1 given tokens (1, 2, ... N), then no, that's not how GPT works. If you put in tokens (1, 2, ... N), you always get guesses for tokens (2, 3, ..., N+1) in response. This is true even if all you care about is the guess for N+1.
I meant your latter interpretation.
Can you measure the KL-divergence at each layer from the input, rather than the output? KL does not satisfy the triangle inequality, so maybe most of the layers are KL-close to both input and output?
GPT uses ReLU, yes? Then the regularization would make it calculate using small values, which would be possible because ReLU is nonlinear on small values. If we used an activation function that's linear on small values, I would therefore expect more of the calculation to be visible.
Can you measure the KL-divergence at each layer from the input, rather than the output? KL does not satisfy the triangle inequality, so maybe most of the layers are KL-close to both input and output?
One can do this in the Colab notebook by calling show_token_progress with comparisons_vs="first" rather than the default "final". IIRC, this also shows a discontinuous flip at the bottom followed by slower change.
(This is similar to asking the question "do the activations assign high or low probability to the input token?" One can answer the same question by plotting logits or ranks with the input layer included.)
GPT uses ReLU, yes? Then the regularization would make it calculate using small values, which would be possible because ReLU is nonlinear on small values.
It uses gelu, but gelu has the same property. However, note that I am extracting activations right after the application of a layer norm operation, which shifts/scales the activations to mean 0 and L2 norm 1 before passing them to the next layer.
gelu has the same property
Actually, gelu is differentiable at 0, so it is linear on close-to-zero values.
Ah, I think we miscommunicated.
I meant "gelu(x) achieves its maximum curvature somewhere near x=0."
People often interpret relu as a piecewise linear version of functions like elu and gelu, which are curved near x=0 and linear for large |x|. In this sense gelu is like relu.
It sounds like you were, instead, talking about the property of relu that you can get nonlinear behavior for arbitrarily small inputs.
This is indeed unique to relu -- I remember some DeepMind (?) paper that used floating point underflow to simulate relu, and then made NNs out of just linear floating point ops. Obviously you can't simulate a differentiable function with that trick.
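The two readings of "the same property" distinguished above can be checked numerically. This is a small sketch in plain Python, assuming the exact erf-based definition of gelu (GPT-2 itself uses a tanh approximation, which is not modeled here). It locates where gelu's curvature peaks on a grid, and then compares f(x) + f(-x) for tiny x, a quantity that is zero for any linear function through the origin.

```python
import math

def gelu(x):
    """Exact gelu: x times the standard normal CDF at x."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def relu(x):
    return max(0.0, x)

def curvature(f, x, h=1e-3):
    """Central finite-difference estimate of the second derivative."""
    return (f(x + h) - 2.0 * f(x) + f(x - h)) / (h * h)

# Reading 1: gelu is most curved near zero and close to linear far from it.
grid = [i / 100.0 for i in range(-400, 401)]
x_peak = max(grid, key=lambda x: abs(curvature(gelu, x)))

# Reading 2: for tiny inputs gelu behaves almost linearly, relu does not.
tiny = 1e-6
gelu_defect = gelu(tiny) + gelu(-tiny)
relu_defect = relu(tiny) + relu(-tiny)
print(x_peak, gelu_defect, relu_defect)
```

The relu defect stays equal to the input size no matter how small the input is, while the gelu defect shrinks quadratically, which matches the distinction drawn in this exchange.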
floating point underflow to simulate relu
Oh that's not good. Looks like we'd need a version of float that keeps track of an interval of possible floats (by the two floats at the end of the interval). Then we could simulate the behavior of infinite-precision floats so long as the network keeps the bounds tight, and we could train the network to keep the simulation in working order. Then we could see whether, in a network thus linear at small numbers, every visibly large effect has a visibly large cause.
By the way - have you seen what happens when you finetune GPT to reinforce this pattern that you're observing, that every entry of the table, not just the top right one, predicts an input token?
IIRC, this also shows a discontinuous flip at the bottom followed by slower change.
Maybe edit the post so you include this? I know I was wondering about this too.
Good idea, I'll do that.
I know I'd run those plots before, but running them again after writing the post felt like it resolved some of the mystery. If our comparison point is the input, rather than the output, the jump in KL/rank is still there but it's smaller.
Moreover, the rarer the input token is, the more it seems to be preserved in later layers (in the sense of low KL / low vocab rank). This may be how tokens like "plasma" are "kept around" for later use.
Apparently it is keeping around a representation of the token "plasma" with enough resolution to copy it . . . but it only retrieves this representation at the end! (In the rank view, the rank of plasma is quite low until the very end.)
This is surprising to me. The repetition is directly visible in the input: "when people say" is copied verbatim. If you just applied the rule "if input seems to be repeating, keep repeating it," you'd be good. Instead, the model scrambles away the pattern, then recovers it later through some other computational route.
One more reason why this is surprising is that other experiments found that this behaviour (forgetting then recalling) is common in MLMs (masked language models) but not in simple language models like GPT-2 (see this blog post and more specifically this graph). The interpretation is that "for MLMs, representations initially acquire information about the context around the token, partially forgetting the token identity and producing a more generalized token representation; the token identity then gets recreated at the top layer" (citing from the blog post).
However, the logit lens here seems to indicate that this may happen in GPT-2 (large) too. Could this be a virtue of scale, where the same behaviour that one obtains with an MLM is reached by an LM as well with sufficient scale?
Cool project. There were some changes in HuggingFace's transformers package which are affecting your Colab implementation. See here:
https://github.com/huggingface/transformers/issues/29576
Could you try a prompt that tells it to end a sentence with a particular word, and see how that word casts its influence back over the sentence? I know that this works with GPT-3, but I didn't really understand how it could.
Interesting topic! I'm not confident this lens would reveal much about it (vs. attention maps or something), but it's worth a try.
I'd encourage you to try this yourself with the Colab notebook, since you presumably have more experience writing this kind of prompt than I do.
Hey I'm not finished reading this yet but I noticed something off about what you said.
At the end, the final 1600-dimensional vector is multiplied by W's transpose to project back into vocab space.
This isn't quite right. They don't multiply by W's transpose at the end. Rather there is a completely new matrix at the end, whose shape is the same as the transpose of W.
You can see this in huggingface's code for GPT2. In the class GPT2LMHeadModel the final matrix multiplication is performed by the matrix called "lm_head", whereas the matrix you call W, which is used to map 50,257 dimensional vectors into 1600 dimensional space, is called "wte" (found in the GPT2Model class). You can see from the code that wte has shape "Vocab size x Embed Size" while lm_head has shape "Embed Size x Vocab size", so lm_head does have the same shape as W transpose but doesn't have the same numbers.
Edit: I could be wrong here, though. Maybe lm_head was set to be equal to wte transpose? I'm looking through the GPT-2 paper but don't see anything like that mentioned.
Maybe lm_head was set to be equal to wte transpose?
Yes, this is the case in GPT-2. Perhaps the huggingface implementation supports making these two matrices different, but they are the same in the official GPT-2.
Edit: I think the reason this is obscured in the huggingface implementation is that they always distinguish the internal layers of a transformer from the "head" used to convert the final layer outputs into predictions. The intent is easy swapping between different "heads" with the same "body" beneath.
This forces their code to allow for heads that differ from the input embedding matrix, even when they implement models like GPT-2 where the official specification says they are the same.
Edit2: might as well say explicitly that I find the OpenAI tensorflow code much more readable than the huggingface code. This isn't a critique of the latter; it's trying to support every transformer out there in a unified framework. But if you only care about GPT, this introduces a lot of distracting abstraction.
This post relates an observation I've made in my work with GPT-2, which I have not seen made elsewhere.
IMO, this observation sheds a good deal of light on how the GPT-2/3/etc models (hereafter just "GPT") work internally.
There is an accompanying Colab notebook which will let you interactively explore the phenomenon I describe here.
[Edit: updated with another section on comparing to the inputs, rather than the outputs. This arguably resolves some of my confusion at the end. Thanks to algon33 and Gurkenglas for relevant suggestions here.]
[Edit 5/17/21: I've recently written a new Colab notebook which extends this post in various ways:
]
overview
background on GPT's structure
You can skip or skim this if you already know it.
the logit lens
As described above, GPT schematically looks like
We have a "dictionary," W, that lets us convert between vocab space and embedding space at any point. We know that some vectors in embedding space make sense when converted into vocab space:
What about the 1600-dim vectors produced in the middle of the network, say the output of the 12th layer or the 33rd? If we convert them to vocab space, do the results make sense? The answer is yes.
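As a rough illustration of the idea, here is a toy sketch in plain Python. The vocabulary, the 2-dimensional embedding table, and the per-layer hidden vectors are all made up for the example, and the layer norm that the real model applies before the final projection is left out. The only operation shown is the projection itself: one dot product between the hidden vector and each token's embedding.

```python
def logit_lens(hidden, embedding):
    """Project a hidden vector into vocab space: one dot product per token embedding."""
    return [sum(h * e for h, e in zip(hidden, row)) for row in embedding]

def top_token(logits, vocab):
    """Return the vocab entry with the largest logit."""
    best = max(range(len(logits)), key=lambda t: logits[t])
    return vocab[best]

# Toy setup (hypothetical values): 3 tokens embedded in 2 dimensions.
vocab = ["the", "cat", "sat"]
embedding = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]

# Made-up hidden states for one position at three successive layers.
hidden_by_layer = [[0.9, 0.1], [0.4, 0.6], [0.1, 0.9]]

guesses = [top_token(logit_lens(h, embedding), vocab) for h in hidden_by_layer]
print(guesses)
```

Reading the guesses layer by layer is what the plots in this section do for every position at once.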
logits
For example: the plots below show the logit lens on GPT-2 as it predicts a segment of the abstract of the GPT-3 paper. (This is a segment in the middle of the abstract; it can see all the preceding text, but I'm not visualizing the activations for it.)
For readability, I've made two plots showing two consecutive stretches of 10 tokens. Notes on how to read them:
There are various amusing and interesting things one can glimpse in these plots. The "early guesses" are generally wrong but often sensible enough in some way:
ranks
The view above focuses only on the top-1 guess at each layer, which is a reductive window on the full distributions.
Another way to look at things: we still reduce the final output to the top-1 guess, but we compare other distributions to the final one by looking at the rank of the final top-1 guess.
Even if the middle of the model hasn't yet converged to the final answer, maybe it's got that answer somewhere in its top 3, top 10, etc. That's a lot better than "top 50257."
Here are the same activations as ranks. (Remember: these are ranks of the model's final top-1 prediction, not the true token.)
In most cases, the network's uncertainty has drastically reduced by the middle layers. The order of the top candidates may not be right, and the probabilities may not be perfectly calibrated, but it's got the gist already.
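The rank view can be sketched in a few lines of plain Python. The logits below are made-up numbers over a 4-token vocabulary, and rank_of is a hypothetical helper name; the point is only to show what "rank of the final top-1 guess at each layer" means.

```python
def rank_of(token_id, logits):
    """1-based rank of token_id among the logits: 1 means it is the top guess."""
    return 1 + sum(1 for x in logits if x > logits[token_id])

# Made-up logits for one position at three successive layers (last = final output).
layer_logits = [
    [2.0, 0.5, 1.0, -1.0],
    [1.0, 1.2, 1.5, 0.0],
    [0.0, 3.0, 1.0, 0.5],
]

# The reference token is the final layer's top-1 guess, not the true next token.
final = layer_logits[-1]
final_top = max(range(len(final)), key=lambda t: final[t])

ranks = [rank_of(final_top, logits) for logits in layer_logits]
print(final_top, ranks)
```

In this toy case the eventual answer climbs from rank 3 to rank 1 across the layers, which is the kind of progression the rank plots show.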
KL divergence and input discarding
Another way of comparing the similarity of two probability distributions is the KL divergence. Taking the KL divergence of the intermediate probabilities w/r/t the final probabilities, we get a more continuous view of how the distributions smoothly converge to the model's output.
Because KL divergence is a more holistic measure of the similarity between two distributions than the ones I've used above, it's also my preferred metric for making the point that nothing looks like the input.
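For concreteness, here is a minimal sketch of that comparison in plain Python, using made-up logits over a 3-token vocabulary. It computes the divergence with the final distribution in the first argument and each layer's distribution in the second; that argument order is an assumption about the convention used in the plots.

```python
import math

def softmax(logits):
    """Convert logits to probabilities, subtracting the max for stability."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kl(p, q):
    """KL(p || q) = sum_i p_i * log(p_i / q_i), skipping terms where p_i is 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

# Made-up logits for one position at three successive layers (last = final output).
layer_logits = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [3.0, 0.0, 0.0]]
final = softmax(layer_logits[-1])

kls = [kl(final, softmax(logits)) for logits in layer_logits]
print(kls)
```

Unlike the top-1 and rank views, this number moves whenever any part of the distribution moves, which is why it gives a smoother picture.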
In the plots above, I've skipped the input layer (i.e. the input tokens in embedding space). Why? Because they're so different from everything else, they distract the eye!
In the plots below, where color is KL divergence, I include the input as well. If we trust that KL divergence is a decent holistic way to compare two distributions (I've seen the same pattern with other metrics), then:
other examples
I show several other examples in the Colab notebook. I'll breeze through a few of them here.
copying a rare token
Sometimes it's clear that the next token should be a "copy" of an earlier token: whatever arbitrary thing was in that slot, spit it out again.
If this is a token with relatively low prior probability, one would think it would be useful to "keep it around" from the input so later positions can look at it and copy it. But as we saw, the input is never "kept around"!
What happens instead? I tried this text:
As shown below (truncated to the last few tokens for visibility), the model correctly predicts "plasma" at the last position, but only figures it out in the very last layers.
Apparently it is keeping around a representation of the token "plasma" with enough resolution to copy it . . . but it only retrieves this representation at the end! (In the rank view, the rank of plasma is quite low until the very end.)
This is surprising to me. The repetition is directly visible in the input: "when people say" is copied verbatim. If you just applied the rule "if input seems to be repeating, keep repeating it," you'd be good. Instead, the model scrambles away the pattern, then recovers it later through some other computational route.
extreme repetition
We've all seen GPT sampling get into a loop where text repeats itself exactly, over and over. When text is repeating like this, where is the pattern "noticed"?
At least in the following example, it's noticed in the upper half of the network, while the lower half can't see it even after several rounds of repetition.
why? / is this surprising?
First, some words about why this trick can even work at all.
One can imagine models that perform the exact same computation as GPT-2, for which this trick would not work. For instance, each layer could perform some arbitrary vector rotation of the previous one before doing anything else to it. This would preserve all the information, but the change of basis would prevent the vectors from making sense when multiplied by W^T.
Why doesn't the model do this? Two relevant facts:
1. Transformers are residual networks. Every connection in them looks like x + f(x) where f is the learned part. So the identity is very easy to learn.
This tends to keep things in the same basis across different layers, unless there's some reason to switch.
2. Transformers are usually trained with weight decay, which is almost the same thing as L2 regularization. This encourages learned weights to have small L2 norm.
That means the model will try to "spread out" a computation across as many layers as possible (since the sum-of-squares is less than the square-of-sums). Given the task of turning an input into an output, the model will generally prefer changing the input a little, then a little more, then a little more, bit by bit.
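The sum-of-squares remark in point 2 can be illustrated with a scalar toy example in plain Python. It assumes, purely for illustration, that each layer's contribution to the total change is proportional to the size of a single weight, which is a large simplification of a real transformer layer.

```python
def l2_penalty(step_sizes):
    """Sum of squared step sizes, standing in for the L2 penalty on the weights."""
    return sum(s * s for s in step_sizes)

total_change = 1.0

# Option A: one layer does all the work.
one_step = l2_penalty([total_change])

# Option B: four layers each do a quarter of the work.
four_steps = l2_penalty([total_change / 4] * 4)

print(one_step, four_steps)
```

Both options add up to the same total change, but splitting it across four layers costs a quarter of the penalty, so under this toy model the regularizer favors many small steps.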
1+2 are a good story if you want to explain why the same vector basis is used across the network, and why things change smoothly. This story would render the whole thing unsurprising . . . except that the input is discarded in such a discontinuous way!
I would have expected a U-shaped pattern, where the early layers mostly look like the input, the late layers mostly look like the output, and there's a gradual "flip" in the middle between the two perspectives. Instead, the input space immediately vanishes, and we're in output space the whole way.
Maybe there is some math fact I'm missing here.
Or, maybe there's some sort of "hidden" invertible relationship between
so that a token like "plasma" is kept around from the input -- but not in the form "the output is plasma," instead in the form "the output is [the kind of word that comes after plasma]."
However, I'm not convinced by that story as stated. For one thing, GPT layers don't share their weights, so the mapping between these two spaces would have to be separately memorized by each layer, which seems costly. Additionally, if this were true, we'd expect the very early activations to look like naive context-less guesses for the next token. Often they are, but just as often they're weird nonsense like "Garland."
addendum: more on "input discarding"
In comments, Gurkenglas noted that the plots showing KL(final || layer) don't tell the whole story.
The KL divergence is not a metric: it is not symmetric and does not obey the triangle inequality. Hence my intuitive picture of the distribution "jumping" from the input to the first layer, then smoothly converging to the final layer, is misleading: it implies we are measuring distances along a path through some space, but KL divergence does not measure distance in any space.
Gurkenglas and algon33 suggested plotting the KL divergences of everything w/r/t the input rather than the output: KL(input || layer).
Note that the input is close to a distribution that just assigns probability 1 to the input token ("close" because W * W^T is not invertible), so this is similar to asking "how probable is the input token, according to each layer?" That's a question which is also natural to answer by plotting ranks: what rank is assigned to the input token by each layer?
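The link between the two questions is easy to verify in the idealized case. This plain-Python sketch assumes the input distribution is exactly one-hot (in the real model it is only close to one-hot, as noted above) and uses a made-up layer distribution; in that case KL(input || layer) reduces to minus the log probability the layer assigns to the input token.

```python
import math

def kl(p, q):
    """KL(p || q) = sum_i p_i * log(p_i / q_i), skipping terms where p_i is 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

# Idealized input: all probability on token 1 (an assumption, see above).
input_token = 1
one_hot = [0.0, 1.0, 0.0]

# Made-up distribution for some later layer.
layer_probs = [0.7, 0.2, 0.1]

kl_value = kl(one_hot, layer_probs)
neg_log_prob = -math.log(layer_probs[input_token])
print(kl_value, neg_log_prob)
```

So under this idealization the KL plot and a plot of the input token's log probability carry the same information, and the rank plot is a coarser view of the same thing.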
Below, I show both: KL(input || layer), and the rank of the input token according to later layers.
It's possible that the relatively high ranks -- in the 100s or 1000s, but not the 10000s -- of input tokens in many cases is (related to) the mechanism by which the model "keeps around" rarer tokens in order to copy them later.
As some evidence for this, I will show plots like the above for the plasma example. Here, I show a segment including the first instance of "plasma," rather than the second which copies it.
The preservation of "plasma" here is striking.
My intuitive guess is that the rarity, or (in some sense) "surprisingness," of the token causes early layers to preserve it: this would provide a mechanism for providing raw access to rare tokens in the later layers, which would otherwise only be looking at more plausible tokens that GPT had guessed for the corresponding positions.
On the other hand, this story has trouble explaining why "G" and "PT" are not better preserved in the GPT3 abstract plots just above. This is the first instance of "GPT" in the full passage, so the model can't rely on copies of these at earlier positions. That said, my sense of scale for "well-preservedness" is a wild guess, and these particular metrics may not be ideal for capturing it anyway.