By Daniil Yurshevich, Nikita Khomich
TLDR: we train an LSTM model on the algorithmic task of modular addition and observe grokking. We fully reverse engineer the learned algorithm and propose a much simpler equivalent version of the model that groks as well.
Reproducibility statement: all the code is available at the repo.
Introduction
This post is related to Neel Nanda's post, and a detailed description of what grokking is can be found there. The short summary: grokking is the phenomenon where a model trained on an algorithmic task of relatively small size initially memorizes the training set and then suddenly generalizes to data it has not seen before. In our work we train a version of an LSTM on modular addition and observe grokking.
Model architecture, experiment settings and naming conventions
We train a version of an LSTM with a ReLU activation instead of tanh when computing the hidden state from the cell state $c_t$. We use a linear layer without bias to get the logits from $h_2$. We also use a slight trick: instead of a single final linear layer we use two linear layers with no activation in between, which makes grokking easier. We use a hidden dimension of $N = p$ and one-hot encoded embeddings. The exact model architecture and naming conventions are shown below (very similar to Chris Olah's amazing post).
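For concreteness, here is a minimal PyTorch sketch of the architecture as described above; the class and variable names are ours and details may differ from the code in the repo.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ReLULSTMCell(nn.Module):
    """LSTM-like cell that uses h_t = ReLU(c_t) instead of h_t = o_t * tanh(c_t)."""
    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.W_ih = nn.Linear(vocab_size, 4 * hidden_size)   # W_ih, b_ih
        self.W_hh = nn.Linear(hidden_size, 4 * hidden_size)  # W_hh, b_hh

    def forward(self, x_t, h, c):
        gates = self.W_ih(x_t) + self.W_hh(h)
        i, f, g, o = gates.chunk(4, dim=-1)       # input, forget, cell, output gates
        c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
        h = torch.relu(c)                         # non-standard update; o is unused
        return h, c

class ModAddLSTM(nn.Module):
    def __init__(self, p=113):
        super().__init__()
        self.p = p
        self.cell = ReLULSTMCell(p, p)            # hidden_size = N = p, one-hot inputs
        self.fc1 = nn.Linear(p, p, bias=False)    # two linear layers with no
        self.fc2 = nn.Linear(p, p, bias=False)    # activation in between (the "trick")

    def forward(self, ab):                        # ab: (batch, 2) tensor of integer pairs
        x = F.one_hot(ab, self.p).float()
        h = c = torch.zeros(ab.shape[0], self.p, device=ab.device)
        for t in range(2):
            h, c = self.cell(x[:, t], h, c)
        return self.fc2(self.fc1(h))              # logits over residues mod p
```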
We use hyperparameters similar to Neel's paper, meaning an extremely large weight decay and a particular set of betas for Adam found with a grid search. We train the model on 30% of all pairs $(a, b)$ modulo 113 and observe grokking:
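A rough sketch of the corresponding training setup, building on the model sketch above; the learning rate, weight decay value, betas and epoch count below are illustrative placeholders, since only the "extremely large weight decay", grid-searched betas and the 30% train fraction are stated here.

```python
import itertools
import torch
import torch.nn.functional as F

p = 113
model = ModAddLSTM(p)   # from the sketch above

# All pairs (a, b) mod 113; 30% go to the training set, the rest is held out.
pairs = torch.tensor(list(itertools.product(range(p), repeat=2)))
perm = torch.randperm(len(pairs))
n_train = int(0.3 * len(pairs))
train, test = pairs[perm[:n_train]], pairs[perm[n_train:]]

# Very large weight decay is the key ingredient; lr/betas/epochs are placeholders.
opt = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.98), weight_decay=1.0)

for epoch in range(20000):                         # full-batch training
    opt.zero_grad()
    loss = F.cross_entropy(model(train), (train[:, 0] + train[:, 1]) % p)
    loss.backward()
    opt.step()
```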
This is a typical and well studied behaviour for transformers; however, we are not aware of any prior example of a reverse-engineered LSTM with grokking observed during training.
Exact formulas for what LSTM does
This section is only relevant to the methodology we used to reverse engineer the model and to clarifying the architectural details used, not to understanding the algorithm itself, so feel free to skip it.
Below are the formulas describing all intermediate activations and operations performed by the given model architecture. The model predicts $(a+b) \bmod 113$ from inputs $a, b \in \{0,\dots,112\}$.
Model Parameters and Dimensions
N = 113 (the modulus and also the number of classes)
vocab_size=N
hidden_size=113
The input is a pair (a,b) with a,b∈{0,…,N−1}.
Input Encoding
We have an input sequence of length 2:
$$x_{\text{seq}} \in \mathbb{R}^{\text{batch}\times 2},\qquad x_{\text{seq}}[i, 0] = a,\quad x_{\text{seq}}[i, 1] = b.$$
We one-hot encode each integer:
$$x_{\text{onehot}}[i, 0, :] = \text{one\_hot}(a),\qquad x_{\text{onehot}}[i, 1, :] = \text{one\_hot}(b).$$
Thus:
$$x_{\text{onehot}} \in \mathbb{R}^{\text{batch}\times 2\times \text{vocab\_size}}.$$
LSTM-like Cell Parameters
We have a single-layer LSTM-like cell with parameters:
$$W_{ih} \in \mathbb{R}^{4\cdot\text{hidden\_size} \times \text{vocab\_size}},\qquad b_{ih} \in \mathbb{R}^{4\cdot\text{hidden\_size}}$$
$$W_{hh} \in \mathbb{R}^{4\cdot\text{hidden\_size} \times \text{hidden\_size}},\qquad b_{hh} \in \mathbb{R}^{4\cdot\text{hidden\_size}}$$
These define the input, forget, cell (g), and output gates at each timestep.
$$h_0 = 0 \in \mathbb{R}^{\text{hidden\_size}},\qquad c_0 = 0 \in \mathbb{R}^{\text{hidden\_size}}.$$
Recurrent Step (t=0,1)
For $t \in \{0, 1\}$: $\;x_t = x_{\text{onehot}}[:, t, :] \in \mathbb{R}^{\text{batch}\times\text{vocab\_size}}$
Compute gates:
$$\text{gates} = W_{ih}\, x_t^{\top} + b_{ih} + W_{hh}\, h_t^{\top} + b_{hh},$$
which results in a batch × (4 · hidden_size) matrix (one 4 · hidden_size vector per sample).
Slice into four parts:
$$i_t = \sigma(\text{gates}[:, 0:\text{hidden\_size}])$$
$$f_t = \sigma(\text{gates}[:, \text{hidden\_size}:2\cdot\text{hidden\_size}])$$
$$g_t = \tanh(\text{gates}[:, 2\cdot\text{hidden\_size}:3\cdot\text{hidden\_size}])$$
$$o_t = \sigma(\text{gates}[:, 3\cdot\text{hidden\_size}:4\cdot\text{hidden\_size}])$$
Update cell and hidden states:
$$c_{t+1} = f_t \odot c_t + i_t \odot g_t$$
However, instead of the standard LSTM update $h_{t+1} = o_t \odot \tanh(c_{t+1})$, the model uses:
$$h_{t+1} = \mathrm{ReLU}(c_{t+1})$$
Step-by-step for the two timesteps
At t = 0:
$$c_1 = f_0 \odot c_0 + i_0 \odot g_0 = i_0 \odot g_0 \quad\text{since } c_0 = 0$$
$$h_1 = \mathrm{ReLU}(c_1)$$
At t = 1:
$$c_2 = f_1 \odot c_1 + i_1 \odot g_1$$
$$h_2 = \mathrm{ReLU}(c_2)$$
After these two steps, h2 is the final hidden state for the sequence.
Final Two Linear Layers
The final output is computed by two linear layers without bias or nonlinearities:
$$\text{intermediate} = \text{fc}_1(h_2) = W_1 h_2$$
$$\text{logits} = \text{fc}_2(\text{intermediate}) = W_2\,\text{intermediate}$$
Where:
$$W_1 \in \mathbb{R}^{\text{hidden\_size}\times\text{hidden\_size}},\qquad W_2 \in \mathbb{R}^{\text{vocab\_size}\times\text{hidden\_size}}$$
No bias and no additional activations are applied here.
Final Output
$$\text{logits} \in \mathbb{R}^{\text{batch}\times\text{vocab\_size}}$$
The model’s output is the logits vector for each sample, which can be turned into a probability distribution via softmax.
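To tie the formulas together, here is a minimal NumPy transcription of the full forward pass for a single input pair; the weight names follow the notation above, and this is a sketch rather than the repo code.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def forward(a, b, W_ih, b_ih, W_hh, b_hh, W1, W2, N=113):
    """Forward pass for a single pair (a, b), following the formulas above."""
    H = N                                    # hidden_size = vocab_size = 113
    x0, x1 = np.eye(N)[a], np.eye(N)[b]      # one-hot encodings of a and b
    h, c = np.zeros(H), np.zeros(H)
    for x in (x0, x1):
        gates = W_ih @ x + b_ih + W_hh @ h + b_hh        # shape (4*H,)
        i, f = sigmoid(gates[0:H]), sigmoid(gates[H:2 * H])
        g = np.tanh(gates[2 * H:3 * H])                  # o = gates[3H:4H] is unused
        c = f * c + i * g
        h = np.maximum(c, 0.0)               # h = ReLU(c) instead of o * tanh(c)
    return W2 @ (W1 @ h)                     # logits, no bias, no extra activation
```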
Reverse engineering the model
Reducing the problem to a simplified model
It turns out the model prefers to learn a much simpler structure and does not use the full capacity of the LSTM. We find the following:
The ReLU applied to $c_1$ when computing $h_1$ in the first step can be removed with no drop in accuracy.
Given the above, the model architecture simplifies to:
This allows us to simplify the formula for $c_2$ significantly:
$$c_2 = F\odot(W_{ihg}E(a)+b_g) + W_{hhg}\,\mathrm{ReLU}(W_{ihg}E(a)+b_g) + b_g + W_{ihg}E(b),$$
which can be written as
$$c_2 = \big(F\odot(W_{ihg}E(a)+b_g) + W_{hhg}\,\mathrm{ReLU}(W_{ihg}E(a)+b_g) + b_g\big) + \big(W_{ihg}E(b)\big),$$
where $E(a)$ and $E(b)$ are the one-hot encoded numbers, $W_{ihg}$, $W_{hhg}$, $b_g$ are the $g$-gate slices of the LSTM weights, and $F$ is the (constant) forget gate $f_t$.
Now we can see that the two summands are just functions of $a$ and $b$ respectively, and given $W_{hhg}$ they have the capacity to learn almost arbitrary representations independently of each other. Regardless of the interpretation, the expression is just a sum of two vectors of dimension 113, one of which is a function of $a$ only and the other a function of $b$ only. This suggests treating them simply as lookup tables of embeddings for $a$ and $b$, so we do exactly that, reducing the problem to the following:
We have two different lookup tables (different, or at least there is seemingly no reason for them to be the same based on our formula). We then take the two embeddings, sum them, apply ReLU, and multiply by some matrix to get the logits.
The first lookup table maps
$$a \mapsto F\odot(W_{ihg}E(a)+b_g) + W_{hhg}\,\mathrm{ReLU}(W_{ihg}E(a)+b_g) + b_g$$
The second maps
$$b \mapsto W_{ihg}E(b)$$
We extract all the weights from the original model exactly as described above to get a simplified model. The measured accuracy of the simplified model is only about 5 percent lower than that of the original model, hence we only need to interpret the simplified model.
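A sketch of how the two lookup tables and the simplified forward pass can be built from the extracted weights. Here $W_{ihg}$, $W_{hhg}$, $b_g$ denote the $g$-gate slices from the formulas above, $F$ the constant forget gate, and $W_1$, $W_2$ the final linear layers; the function and variable names are ours.

```python
import numpy as np

def build_simplified_model(W_ihg, W_hhg, b_g, F, W1, W2, p=113):
    """Build the two lookup tables and the readout matrix from extracted weights."""
    pre = W_ihg + b_g[:, None]               # column a is W_ihg E(a) + b_g (E(a) is one-hot)
    u_a = (F[:, None] * pre + W_hhg @ np.maximum(pre, 0.0) + b_g[:, None]).T  # row a
    u_b = W_ihg.T                            # row b is W_ihg E(b)
    W = W2 @ W1                              # logits = W ReLU(u_a[a] + u_b[b])
    return u_a, u_b, W

def simplified_logits(a, b, u_a, u_b, W):
    return W @ np.maximum(u_a[a] + u_b[b], 0.0)
```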
An interesting observation: if we train just the simplified model from scratch with the same trick of breaking the linear layer after the ReLU into a product of two matrices with no activation in between ($W = W_2 W_1$ in the notation above), the simplified model also groks and learns the same algorithm, but it does not grok with just a single matrix. Why this is the case will become clearer after understanding the algorithm.
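As a minimal sketch, the simplified model trained from scratch with the factorized readout looks like this (module and attribute names are ours):

```python
import torch
import torch.nn as nn

class SimplifiedModAdd(nn.Module):
    """Two lookup tables -> sum -> ReLU -> factorized readout W2 @ W1."""
    def __init__(self, p=113):
        super().__init__()
        self.emb_a = nn.Embedding(p, p)           # lookup table for a
        self.emb_b = nn.Embedding(p, p)           # lookup table for b
        self.W1 = nn.Linear(p, p, bias=False)     # the factorization trick:
        self.W2 = nn.Linear(p, p, bias=False)     # a single matrix did not grok for us

    def forward(self, ab):                        # ab: (batch, 2) integer pairs
        h = torch.relu(self.emb_a(ab[:, 0]) + self.emb_b(ab[:, 1]))
        return self.W2(self.W1(h))                # logits over residues mod p
```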
Reverse engineering simplified model
The first thing we notice is that both embedding tables are highly periodic; not only that, they are almost identical, and substituting one for the other does not drop the accuracy.
Periodicity shows up in FFT plots as well:
Interestingly, the average peak Fourier magnitude over the training run looks like:
This is explained by the sparsity of the frequencies used in $W$. Indeed, the embedding tables become periodic during the first few epochs, and during the rest of training the model learns to spread the activations that fire the logits across many frequencies:
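The periodicity claims can be checked with a short FFT script along these lines; here u stands for either embedding table as a (113, 113) array and W for the readout matrix, and a peak fraction close to 1 means a single dominant frequency.

```python
import numpy as np

def dominant_frequencies(M):
    """Dominant non-constant frequency of each column of M, and the fraction
    of the non-DC Fourier mass that it carries."""
    spec = np.abs(np.fft.rfft(M, axis=0))[1:]     # drop the DC (constant) component
    peak = spec.argmax(axis=0) + 1
    frac = spec.max(axis=0) / (spec.sum(axis=0) + 1e-12)
    return peak, frac

# Example usage on the extracted tables:
# peak_u, frac_u = dominant_frequencies(u)   # columns of an embedding table
# peak_W, frac_W = dominant_frequencies(W)   # columns of the readout matrix
```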
The algorithm itself
We first provide some facts about the embeddings we observed.
Three important points
1. Each column of $u$ is a periodic function of the token index, with non-zero Fourier magnitude only at frequency 0 and at $\pm f$ for some frequency $f$. (Reminder: the embeddings themselves are rows of $u$.)
2. For each integer $s$ modulo 113, the matrix whose $i$-th row is $u[i] + u[s-i]$ (the sums of embeddings corresponding to pairs of integers with a particular sum, stacked vertically) has highly periodic columns. Not only that, but the dominant frequency of the $j$-th column of $W$ is the same as the frequency of the matrix constructed above for $s = j$.
3. The columns of $W$ have just one dominant frequency, so they are highly periodic as well.
Below are some visualisations of the above points.
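As a rough numerical sanity check of point 2 (our own heuristic, not code from the repo), one can compare the dominant frequency of column $s$ of $W$ with the dominant frequency of the columns of the sum matrix built for that $s$:

```python
import numpy as np

def dom_freq(M):
    """Dominant non-constant frequency of each column of M."""
    return np.abs(np.fft.rfft(M, axis=0))[1:].argmax(axis=0) + 1

def check_point_2(u, W, p=113):
    """Fraction of sums s whose sum-matrix shares the dominant frequency of
    column s of W (aggregating the sum-matrix columns by median is a heuristic)."""
    peak_W = dom_freq(W)
    matches = 0
    for s in range(p):
        M = u + u[(s - np.arange(p)) % p]     # i-th row is u[i] + u[s - i]
        matches += int(np.median(dom_freq(M)) == peak_W[s])
    return matches / p
```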
So the idea is: since each column appears to be essentially a single-frequency wave, let's just write down the formula for the logits assuming one dominant frequency per column; the resulting expression is not that complicated.
Let's look at the logits:
Assume the observed periodicity and ignore the ReLU (this is reasonable: we care about scalar products of periodic functions, and since the matrices $W$ and $u$ share similar frequencies, a natural first step is to examine what happens without the ReLU).
Observe that the constant term of the columns of $W$ is 0 and that they have the same frequencies as the corresponding columns of $u$. Writing the $k$-th columns in their (approximate) single-frequency Fourier form,
$$W[s'][k] \approx (x_k + i y_k)\,e^{2\pi i f_k s'/113} + (x_k - i y_k)\,e^{-2\pi i f_k s'/113},$$
$$u[j][k] \approx c_0^{(k)} + (a_k + i b_k)\,e^{2\pi i f_k j/113} + (a_k - i b_k)\,e^{-2\pi i f_k j/113},$$
the $s'$-th logit for an input pair $(j, s-j)$ with sum $s$ is
$$\sum_k W[s'][k]\,\big(u[j][k] + u[s-j][k]\big)$$
$$= \sum_k \Big((x_k + i y_k)\,e^{2\pi i f_k s'/113} + (x_k - i y_k)\,e^{-2\pi i f_k s'/113}\Big)\Big(2 c_0^{(k)} + (a_k + i b_k)\big(e^{2\pi i f_k j/113} + e^{2\pi i f_k (s-j)/113}\big) + (a_k - i b_k)\big(e^{-2\pi i f_k j/113} + e^{-2\pi i f_k (s-j)/113}\big)\Big).$$
So after some algebra, we find that the $s'$-th logit is approximately
$$c'_{s'} + c''_{s'} + \sum_k c'_k\, e^{2\pi i f_k (s'-s)/113} + \sum_k c''_k\, e^{2\pi i f_k (s'+s)/113}.$$
Now $c'_k$ and $c''_k$ are of similar magnitude (being roughly $a_k x_k + b_k y_k$ and $a_k x_k - b_k y_k$ respectively), hence each of the two sums above adds up to a small, approximately constant value when $s'$ is not equal to $s$ or $-s$: since $f_k$ spans all integers from 0 to 112, each sum is just a slightly noisy sum of roots of unity. The way the model distinguishes between them is by learning the correct $c'_{s'}$ for each particular sum.
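The cancellation argument can be illustrated numerically: with frequencies $f_k$ spanning 0..112, a weighted sum of $e^{2\pi i f_k \Delta/113}$ is large only when $\Delta \equiv 0 \pmod{113}$. A minimal sketch with made-up positive weights standing in for the $c'_k$:

```python
import numpy as np

p = 113
rng = np.random.default_rng(0)
c = rng.uniform(0.5, 1.5, size=p)          # stand-in for the coefficients c'_k
f = np.arange(p)                           # frequencies span all integers 0..112

for delta in (0, 1, 5, 56):                # delta plays the role of s' - s (mod 113)
    total = abs(np.sum(c * np.exp(2j * np.pi * f * delta / p)))
    print(delta, round(total, 2))
# delta = 0 gives roughly sum(c) ~ 113, while any other delta gives a small,
# "slightly noisy sum of roots of unity".
```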
Summary
We have observed grokking and learned that the algorithm learned is still to a large extent empirical: the model has to learn those 113 values of $c'$ separately, and when it does, the model groks. So the model chose to learn 113 parameters rather than $113^2$, and to use the fact that a (weighted) sum of roots of unity is small compared to a sum of 1's. This is confirmed by the fact that if we reshuffle the training data so that it does not contain a particular sum, the model is more likely to get that sum wrong.