Previously: Teaser: Hard-coding Transformer Models

Introduction

Transformer models are incredibly powerful for natural language tasks (and they are starting to find uses in many other fields of machine learning). Unfortunately, it is nigh-impossible to interpret what goes on inside them. OR IS IT???

I have found that I can, with a fair amount of effort, hard-code the weights of a transformer model in order to perform some very crude versions of linguistic tasks. So far I have achieved English-to-French translation (on a toy corpus of about 150 sentences), text classification (deciding whether a sentence is grammatical, on a toy corpus of a couple hundred sentences), and sentiment analysis (again on a limited corpus). These results are obviously not impressive compared to the state of the art in machine learning, but I am pretty sure that they can all be drastically scaled up with the investment of some time and energy. Unfortunately, I have a fairly demanding day job, and haven't found the time and energy yet.

All of this is done by inspection (no gradient descent!). The process is a lot like programming, although it is more difficult than programming, at least right now for me. I am fairly certain that better tools and better notation can be developed to make the process easier. It is also almost certainly possible to combine hard-coding with gradient descent approaches to be able to scale these methods up in a slightly less labor-intensive way.

I think that these ideas could prove useful in alignment research - if we understand how a language model works in excruciating detail, it seems drastically more likely that we will be able to reason about and predict various misunderstandings rooted in the ambiguity of language. Given that language is (arguably) a fully general means of interacting with an artificial intelligence, it seems plausible to me that this work is on the critical path to alignment.

Doneness Status

This post is a work-in-progress. I will be editing it as I go, mostly appending more content to the end, but I will also try to fix any errors or unclear parts as I notice them or commenters point them out.

So let's hard-code some neural computation! I have a very, very messy github repository where I've done my initial experiments, if you prefer to just jump into semi-working code. Otherwise, I will do my very best to explain the ideas from scratch in this post. I'm aiming this post at anyone who's willing to put in the work to understand it, so I'll try to at least give pointers to necessary background material, of which there is a fair amount.

What Can We Do Already?

Some primitive sentiment analysis using a Vanilla RNN:

Some very simple translation with a Transformer model:

 

How Should We Measure Success?

Before we explain how we get results, it seems worthwhile to talk about how to measure the performance of such a system. As in traditional machine learning, this can only be measured with respect to some dataset of inputs labeled with desired outputs. For a crude metric, we can look at the fraction of inputs that receive the desired output, versus the fraction that receive some other output. We can also look at more complex metrics, such as BLEU or ROUGE for sequence-to-sequence tasks.

In traditional machine learning, performance can only be (meaningfully) measured on a holdout set that was not used to train the algorithm. This is because performance tends to be much, much higher (if you're using the right architecture and hyperparameters for your task) on the training set (the set of data that was used to train the algorithm) than it will be for the test set (the set of data that has been held out). The whole purpose and challenge of machine learning, of course, is to build models that generalize to unseen data.

A similar phenomenon occurs in this work: the resulting network will typically do disproportionately well on data that the programmer has examined, run through the network, and used to inform updates to the rules. After all, if you miss some edge case, but then see it in your testing, you have the opportunity to fix it.

On the other hand, the programmer will presumably have a native command of at least one language, so it is at least possible for the programmer to anticipate some phenomena before seeing them in the "training data". Thus, it seems unfair to gradient descent and deep learning to compare accuracies in the low-data regime where I have been stuck so far by my lack of free time.

The ultimate ambition of this work would be to go toe-to-toe with a comparably-sized Transformer model trained in the traditional way on a modern-sized data set. This might require several people-years of focused effort though.

Some Useful Notation

The first thing we are going to do is introduce some very unconventional notation for vectors and matrices. (We won't need any more information about linear algebra than is contained in the Wikipedia article, but we will assume that you are either familiar with them or have paused and read that article.)

We will pick a set of "axes" that we will call "semes". (This word comes from semiotics, as do a few other terms we will use. I believe I'm using them in a way that is compatible with their technical meaning in semiotics, but feel free to think of this as a nonsense word that we are coining.) Each seme will be identified with a short string, often a word. So, we might have semes "wombat", "peregrine", and "pig". These play a role very similar to variable names in traditional programming, so we will generally choose them to be meaningful. Common semes that I actually use are "noun", "verb", etc.

We then will write vectors using these semes, for example $\text{wombat} - \text{peregrine}$ for the vector that is 1 in the direction $\text{wombat}$ and -1 in the direction $\text{peregrine}$. We can also use coefficients, so that $2.1\,\text{pig} - 3.2\,\text{wombat}$ denotes the vector that is 2.1 in the direction $\text{pig}$ and -3.2 in the direction $\text{wombat}$. There are two ways of thinking about this: you can either think of the various semes as being completely orthogonal to each other, forming an orthonormal basis of whatever vector space we are in, or you can think of them as arbitrary vectors that we are using as a (possibly overcomplete) basis. In general, both will be useful; I generally think of syntactic information as being best represented in a fully orthonormal basis, while semantic information makes much more sense as being drawn from a very overcomplete basis.

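Matrices will be written in the form $1.1\,\text{pig}\to\text{wombat} + 2.3\,\text{wombat}\to\text{pig} - 4.5\,\text{pig}\to\text{peregrine} + 0.9\,\text{peregrine}\to\text{peregrine}$ for a matrix that would be conventionally represented as

$$\begin{pmatrix} 0 & 0 & 2.3 \\ 0 & 0.9 & 0 \\ 1.1 & -4.5 & 0 \end{pmatrix},$$

with rows and columns both ordered (wombat, peregrine, pig), so that the coefficient of $a\to b$ sits in row $a$, column $b$.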

In code, we will write them like this:

vec1: 2.1 pig -3.2 wombat

mat1: 1.1 pig>wombat +2.3 wombat>pig -4.5 pig>peregrine + 0.9 peregrine>peregrine

for the above vector and matrix.
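To make the notation concrete, here is a minimal Python sketch (assuming NumPy and an orthonormal seme basis) that turns the vector and matrix strings above into ordinary arrays. The helper names `parse_terms`, `to_vector`, and `to_matrix` are hypothetical, not part of the author's code. It only handles coefficients written as separate tokens, like `2.1 pig`; forms that glue the coefficient onto the seme, like `0.5intensifier` later in this post, are ambiguous next to seme names such as `3rd` and are not supported here.

```python
import numpy as np

def parse_terms(text):
    """Yield (coefficient, name) pairs from strings like '2.1 pig -3.2 wombat'.

    Hypothetical helper: coefficients must be separate tokens; a bare name,
    optionally prefixed with '+' or '-', gets coefficient +1 or -1.
    """
    coef, sign = None, 1.0
    for tok in text.split():
        if tok in ("+", "-"):              # free-standing sign, e.g. '+ 0.9'
            sign = -1.0 if tok == "-" else 1.0
            continue
        try:
            coef = float(tok)              # a number applies to the next name
            continue
        except ValueError:
            pass
        if tok[0] in "+-":                 # sign glued to the name, e.g. '-wombat'
            sign *= -1.0 if tok[0] == "-" else 1.0
            tok = tok[1:]
        yield sign * (1.0 if coef is None else coef), tok
        coef, sign = None, 1.0

def to_vector(text, semes):
    """Dense vector with one coordinate per seme."""
    v = np.zeros(len(semes))
    for c, name in parse_terms(text):
        v[semes.index(name)] += c
    return v

def to_matrix(text, semes):
    """Dense matrix where 'a>b' puts its coefficient in row a, column b."""
    M = np.zeros((len(semes), len(semes)))
    for c, name in parse_terms(text):
        src, dst = name.split(">")
        M[semes.index(src), semes.index(dst)] += c
    return M

semes = ["wombat", "peregrine", "pig"]
vec1 = to_vector("2.1 pig -3.2 wombat", semes)
mat1 = to_matrix("1.1 pig>wombat +2.3 wombat>pig -4.5 pig>peregrine + 0.9 peregrine>peregrine", semes)
print(vec1)
print(mat1)
```

With the basis ordered (wombat, peregrine, pig), `mat1` comes out as the 3×3 matrix shown above, and `vec1 @ mat1` is the row-vector product used throughout this post.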

As with vectors, there are two ways to think of the matrix notation. In the first way, the semes form an orthonormal basis, and we are just using them to identify which pairs of coordinates get which coefficient. But we can also think of $1.1\,\text{pig}\to\text{wombat}$ as being 1.1 times the outer product of $\text{pig}$ and $\text{wombat}$. This second view will not be necessary for the contents of this post, but it is necessary to understand some of the ways I envision being able to combine this work with gradient descent-based learning.

It is also worth pointing out that, if we multiply matrices and vectors with the vector on the left, $\text{pig}\to\text{wombat}$ actually maps the vector $\text{pig}$ to the vector $\text{wombat}$. (Although, potentially confusingly, with the vector on the right, $\text{pig}\to\text{wombat}$ maps $\text{wombat}$ to $\text{pig}$.) For this reason, we will prefer left-multiplication in our neural networks later, because it makes this particular notation way easier to think with.
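A quick NumPy check of both views, treating the semes as an orthonormal basis. (In the overcomplete view the same outer-product construction applies, but a vector $v$ is then sent to $(v \cdot \text{pig})\,\text{wombat}$ rather than cleanly to $\text{wombat}$.)

```python
import numpy as np

semes = ["wombat", "peregrine", "pig"]
basis = np.eye(len(semes))
e = {s: basis[i] for i, s in enumerate(semes)}   # one unit vector per seme

# Outer-product view: 1.1 pig>wombat is 1.1 times outer(pig, wombat).
M = 1.1 * np.outer(e["pig"], e["wombat"])

left = e["pig"] @ M       # vector on the left: pig is sent to 1.1 wombat
right = M @ e["wombat"]   # vector on the right: wombat is sent to 1.1 pig
print(left, right)
```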

Tokenization and Word Embeddings

In deep NLP, the first couple steps are about getting rid of words and replacing them with inputs that can actually be understood by a deep network. The first step is to take a string and break it into some number of discrete chunks called "tokens". In principle, we could feed things in letter-by-letter, and people have gotten semi-decent results doing that in the past, but it's a lot less labor-intensive in this context to use full words as the unit of tokenization. This is actually a mild break from most Transformer models used today, which generally make use of a "subword vocabulary" which contains a mixture of whole words and parts of words like "ing" or "particul". 

Let's take an example sentence and tokenize it, just to be sure that we understand this process. Consider

The rain in Spain is mainly on the plain, while treefuls of weevils are gleefully evil.
["The", "rain", "in", "Spain", "is", "mainly", "on", "the", "plain", ",", "while", "treefuls", "of", "weevils", "are", "gleefully", "evil", "."]

Some things worth emphasizing:

  • We don't use tokens for whitespace (spaces, tabs, etc.)
  • Punctuation such as commas and periods will get a token of its own

Additionally, we will case-normalize our inputs by making everything lower-case. This cuts down on some repetitive work and is relatively common with deep models. We also include a special SOS (start of sentence) token and an EOS (end of sentence) token. So the above example should really look like this:

The rain in Spain is mainly on the plain, while treefuls of weevils are gleefully evil.
["SOS", "the", "rain", "in", "spain", "is", "mainly", "on", "the", "plain", ",", "while", "treefuls", "of", "weevils", "are", "gleefully", "evil", ".", "EOS"]
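The tokenization rules above can be sketched in a few lines of Python. The regular expression is an assumption of mine: a "word" is a run of letters, digits, or underscores, and every other non-space character becomes its own token.

```python
import re

def tokenize(text):
    """Lower-case, split into words and single punctuation marks, add SOS/EOS.

    Assumption: words are runs of word characters; each remaining
    non-whitespace character (comma, period, ...) is a separate token.
    """
    words = re.findall(r"\w+|[^\w\s]", text.lower())
    return ["SOS"] + words + ["EOS"]

print(tokenize("The rain in Spain is mainly on the plain, while treefuls of weevils are gleefully evil."))
```

This reproduces the token list above. It would split a contraction like "I'll" into three tokens; the sentiment-analysis lexicon later in this post keeps "I'll" whole and is case-sensitive, so that section apparently tokenizes differently.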

The second step in transforming a string of text into the sort of inputs that a deep network prefers is to do a "word embedding lookup". Here, each token is replaced by a fixed vector, so that we get a matrix of shape [num_tokens, word_embedding_dim]. Because the first axis (the "sequence dimension") is not semantically the same as the second axis (the "embedding dimension"), we will not use our special matrix notation, but will instead think of this as a list of vectors, one for each token.

So let's look at some word embeddings! Here are some pronouns (note that we're describing here a fragment of a flavor of English that includes the gender-neutral singular "they" in addition to the plural "they"):

    i: +nom +sg +1st +pro
    you: +nom +sg +2nd +pro
    he: +masc +nom +sg +3rd +pro
    she: +fem +nom +sg +3rd +pro
    it: +neut +sg +3rd +pro +expletive
    me: +acc +sg +1st +pro
    you: +sg +pl +2nd +pro
    him: +masc +acc +sg +3rd +pro
    we: +nom +pl +1st +pro
    they: +enby +nom +sg +pl +3rd +pro
    them: +enby +acc +sg +pl +3rd +pro
    us: +acc +pl +1st +pro
    them: +acc +pl +3rd +pro
    my: +gen +sg +1st +pro
    our: +gen +pl +1st +pro
    his: +masc +gen +sg +3rd +pro
    her: +fem +gen +acc +sg +3rd +pro
    its: +neut +gen +sg +3rd +pro
    their: +enby +gen +sg +pl +3rd +pro
    myself: +1st +reflexive +sg +pro
    ourselves: +1st +reflexive +pl +pro
    yourself: +2nd +reflexive +sg +pro
    yourselves: +2nd +reflexive +pl +pro
    himself: +3rd +reflexive +sg +masc +pro
    herself: +3rd +reflexive +sg +fem +pro
    itself: +3rd +reflexive +sg +neut +pro
    themselves: +3rd +reflexive +pl +enby +pro
    oneself: +3rd +reflexive +sg +pro

Here are some verbs:

    is: +be +verb +3rdsg +copula
    be: +be +verb +plain +copula
    was: +be +verb +preterite +copula +helper

    did: +do +helper +verb +preterite +agentlack +themeposs
    do: +do +helper +verb +plain +agentlack +themeposs
    does: +do +helper +verb +3rdsg +agentlack +themeposs
    have: +have +plain +helper +verb +agentposs +themeposs
    has: +have +3rdsg +helper +verb +agentposs +themeposs

    can: +can +plain +helper +modal
    could: +can +preterite +helper +modal
    may: +may +plain +3rdsg +helper +modal
    might: +may +helper +modal
    must: +must  +plain +helper +modal
    shall: +shall +plain +helper +modal
    should: +shall +preterite +helper +modal
    will: +will +plain +3rdsg +helper +modal
    would: +will +preterite +helper +modal
    ought: +ought +modal +helper +modal
    dare: +dare +modal +helper +modal

    accuse: +accuse +verb +plain +agentlack +themelack
    accused: +accuse +verb +preterite +agentlack +themelack
    accuses: +accuse +verb +3rdsg +agentlack +themelack
    appear: +appear +verb +plain +agentlack +complementposs
    appeared: +appear +verb +preterite +agentlack +complementposs
    appears: +appear +verb +3rdsg +agentlack +complementposs
    ate: +eat +verb +preterite +agentlack +patientposs
    beam: +beam +verb +plain +agentlack
    beamed: +beam +verb +preterite +agentlack
    beams: +beam +verb +3rdsg +agentlack
    bend: +bend +verb +plain +agentlack +patientposs
    bent: +bend +verb +preterite +agentlack +patientposs
    bends: +bend +verb +3rdsg +agentlack +patientposs
    bled: +bleed +verb +preterite +agentlack +patientposs
    bleed: +bleed +verb +plain +agentlack +patientposs
    bleeds: +bleed +verb +3rdsg +agentlack +patientposs
    blew: +blow +verb +preterite +agentlack +patientposs
    blow: +blow +verb +plain +agentlack +patientposs
    blows: +blow +verb +3rdsg +agentlack +patientposs
    braid: +braid +verb +plain +agentlack +patientlack
    braided: +braid +verb +preterite +agentlack +patientlack
    braids: +braid +verb +3rdsg +agentlack +patientlack
    breathe: +breathe +verb +plain +agentlack
    breathed: +breathe +verb +preterite +agentlack
    breathes: +breathe +verb +3rdsg +agentlack
    break: +break +verb +plain +agentlack
    breaks: +break +verb +3rdsg +agentlack
    broke: +break +verb +preterite +agentlack
    brush: +brush +verb +plain +agentlack +patientlack
    brushed: +brush +verb +preterite +agentlack +patientlack
    brushes: +brush +verb +3rdsg +agentlack +patientlack
    carve: +carve +verb +plain +agentlack +patientposs
    carved: +carve +verb +preterite +agentlack +patientposs
    carves: +carve +verb +3rdsg +agentlack +patientposs
    chase: +chase +verb +plain +agentlack +patientlack
    chased: +chase +verb +preterite +agentlack +patientlack
    chases: +chase +verb +3rdsg +agentlack +patientlack
    chuckle: +chuckle +verb +plain +agentlack
    chuckled: +chuckle +verb +preterite +agentlack
    chuckles: +chuckle +verb +3rdsg +agentlack
    came: +come +verb +preterite +agentlack
    come: +come +verb +plain +agentlack
    comes: +come +verb +3rdsg +agentlack
    cook: +cook +verb +plain +agentlack +patientposs
    cooked: +cook +verb +preterite +agentlack +patientposs
    cooks: +cook +verb +3rdsg +agentlack +patientposs
    cough: +cough +verb +plain +agentlack +patientposs
    coughed: +cough +verb +preterite +agentlack +patientposs
    coughs: +cough +verb +3rdsg +agentlack +patientposs
    cried: +cry +verb +preterite +agentlack
    cries: +cry +verb +3rdsg +agentlack
    cry: +cry +verb +plain +agentlack
    cut: +cut +verb +plain +preterite +agentlack +patientlack
    cuts: +cut +verb +3rdsg +agentlack +patientlack

You'll note something strange about verbs (and nouns and other "content" words): they almost all have one seme that's just themselves again! What gives? These are semantic semes, which are much harder to reason about than the other semes, which are syntactic. As we said earlier, syntactic semes should be thought of as an orthonormal basis of whatever size, but semantic semes are more usefully thought of as living in a low-dimensional space, where they aren't mutually orthogonal. (But all semantic semes should be thought of as perfectly orthogonal to all syntactic semes, and vice versa.) For the time being, I restrict myself to a relatively limited vocabulary and just use the semantic semes as if they were orthogonal. For grammaticality classification, which is the domain I have worked hardest on, the semantic semes are not particularly relevant. For translation, the only really important thing is that they are able to pick out the corresponding word in the target language. (This assumes there is a straightforward single-word translation in the target language, which has mostly been the case in the toy examples I have considered thus far, but is not the case in general.)

Let's embed a sentence! Consider:

The cat sat on the mat.

This tokenizes to 

["SOS", "the", "cat", "sat", "on", "the", "mat", ".", "EOS"]

Which then embeds to

SOS: +sos
the: +det
cat: +cat +sg +noun
sat: +sit +verb +preterite +agentlack
on: +on +prep
the: +det
mat: +mat +sg +noun
.: +punct +period
EOS: +eos

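In mathematical notation, we would write this as the list of vectors

$$[\,\text{sos},\ \text{det},\ \text{cat} + \text{sg} + \text{noun},\ \text{sit} + \text{verb} + \text{preterite} + \text{agentlack},\ \text{on} + \text{prep},\ \text{det},\ \text{mat} + \text{sg} + \text{noun},\ \text{punct} + \text{period},\ \text{eos}\,]$$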

For further clarity, let's give a gloss for each seme we're using:

semes: sos # start of sentence
       eos # end of sentence
       det # determiner, a linguistic class that contains articles, demonstratives, and some other stuff
       cat # meowing animal (semantic seme)
       sg # singular in number
       noun # nouns, the class of object/concept words
       sit # sitting down (semantic seme)
       verb # verbs, the class of action words
       preterite # one of the past tenses in English
       agentlack # to be grammatical, this verb needs an agent
       on # the preposition (semantic seme)
       prep # prepositions, the class of words denoting relationships
       mat # something to sit on (semantic seme)
       punct # punctuation
       period # specifically this guy: .
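Here is the same lookup as a NumPy sketch, assuming every seme (semantic ones included, as the text does for now) is its own orthonormal axis. It produces the [num_tokens, word_embedding_dim] matrix described earlier; the names `SEMES`, `LEXICON`, and `embed` are illustrative.

```python
import numpy as np

# Seme order follows the gloss above; each seme is one column.
SEMES = ["sos", "eos", "det", "cat", "sg", "noun", "sit", "verb", "preterite",
         "agentlack", "on", "prep", "mat", "punct", "period"]

LEXICON = {
    "SOS": "+sos", "the": "+det", "cat": "+cat +sg +noun",
    "sat": "+sit +verb +preterite +agentlack", "on": "+on +prep",
    "mat": "+mat +sg +noun", ".": "+punct +period", "EOS": "+eos",
}

def embed(tokens):
    """Word-embedding lookup: one row per token, one column per seme."""
    X = np.zeros((len(tokens), len(SEMES)))
    for row, tok in enumerate(tokens):
        for term in LEXICON[tok].split():
            X[row, SEMES.index(term.lstrip("+"))] += 1.0
    return X

tokens = ["SOS", "the", "cat", "sat", "on", "the", "mat", ".", "EOS"]
X = embed(tokens)   # shape (9, 15) = [num_tokens, word_embedding_dim]
print(X.shape)
```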

Sentiment Analysis

Sentiment analysis refers to the task of extracting from a piece of natural language the overall sentiment that the speaker has towards whatever thing they're talking about. For instance, in a movie or product review, does the author recommend the movie or product? This is generally considered a pretty straightforward task for machine learning algorithms.

An extremely interpretable algorithm for sentiment analysis is given in "VADER: A Parsimonious Rule-based Model for Sentiment Analysis of Social Media Text" by C. Hutto and E. Gilbert (2014). We have implemented a similar algorithm inside of a two-layer vanilla RNN, which we will describe below. However, we want to note first that this algorithm is only a crude sketch of VADER, and its shortcomings should not be held against Hutto and Gilbert.

Why a vanilla RNN rather than the OMG SO MUCH BETTER LSTM? Well, vanilla RNNs are significantly simpler to understand, and their disadvantages (vanishing and exploding gradients, primarily) are only really relevant when you're actually using gradient descent! So let's do this the easy way and stick to a vanilla RNN.

First, a few preliminaries about the architecture we will be using:

Let $h^l_t$ denote the output of layer $l$ at time-step $t$. Note that the superscript isn't an exponent, it's just a convenient place to put another index. There will be no exponents anywhere in this network; they are all superscripts. (Here time-step just means the index of the token. So time-step 1 will be the first token, time-step 2 will be the second token, and so on.) $h^l_0$ will be a special initial state before we read any tokens, and we will set it to be the zero vector. $h^0_t$ will be the output of the word-embedding layer, so it will just be the embedding of the $t$-th token after we look it up.

We then define the recurrence (for $l \ge 1$ and $t \ge 1$):

Here $\sigma$ is the good old logistic sigmoid function, and $I$ is the identity matrix. (Practitioners will note that the use of the identity matrix here is some sort of residual-like connection.)

The logistic sigmoid function
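The text names the ingredients of the recurrence: per-layer matrices $A$ and $B$ and a bias (as in the `rnn_layer1`/`rnn_layer2` programs below), the logistic sigmoid, and an identity-matrix residual connection. The following NumPy sketch shows one plausible way they fit together. The exact placement of the residual term outside the sigmoid is my assumption, not a confirmed detail of the author's network.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def rnn_layer(xs, A, B, b):
    """One vanilla-RNN layer over a sequence of row vectors xs (one row per token).

    ASSUMED recurrence (a guess, not the post's confirmed formula):
        h^l_t = h^{l-1}_t I + sigmoid(h^{l-1}_t A + h^l_{t-1} B + b),  with h^l_0 = 0
    """
    h = np.zeros(A.shape[1])              # h^l_0: zero initial state
    I = np.eye(xs.shape[1], A.shape[1])   # identity (residual-like) connection
    outputs = []
    for x in xs:                          # x plays the role of h^{l-1}_t
        h = x @ I + sigmoid(x @ A + h @ B + b)
        outputs.append(h)
    return np.stack(outputs)

xs = np.eye(3)                            # three one-hot "tokens"
A, B = np.zeros((3, 3)), np.zeros((3, 3))
print(rnn_layer(xs, A, B, np.full(3, -50.0)).round(6))
```

With a strongly negative bias the sigmoid branch is essentially shut off, so the layer just passes its input through the identity connection.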

We further define a pooling layer, and then a fully-connected or dense layer.

# our set of semes
semes: stop # "stop words", in this case anything not needed for sentiment analysis
       positive negative negation contrastive intensifier
       lessener intensepunctuation
       xa xb xc xd xe ya yb yc yd ye # a bunch of anonymous variables to represent intermediate computations

lexicon:
    ",": stop
    ".": stop
    "I'll": stop
    At: stop
    It: stop
    The: stop
    Today: stop
    VADER: stop
    a: stop
    all: stop
    and: stop
    are: stop
    at: stop
    book: stop
    by: stop
    characters: stop
    dialog: stop
    get: stop
    is: stop
    it: stop
    of: stop
    plot: stop
    the: stop
    was: stop
    FUNNY: positive
    GOOD: positive
    GREAT: positive
    HANDSOME: positive
    LOL: positive
    SMART: positive
    funny: positive
    good: positive
    great: positive
    handsome: positive
    lol: positive
    smart: positive
    SUX: negative
    bad: negative
    horrible: negative
    sux: negative
    uncompelling: negative
    "!": intensepunctuation
    "!!!": intensepunctuation
    very: intensifier
    VERY: intensifier
    uber: intensifier
    FRIGGIN: intensifier
    only: lessener
    kinda: lessener
    not: negation
    nor: stop
    isnt: negation
    Not: negation
    But: contrastive
    but: contrastive

rnn_layer1:
    A: positive>xa negative>ya positive>xb negative>yb
       negation>negation intensifier>intensifier lessener>lessener

    B: intensifier>xa intensifier>ya lessener>xb lessener>yb
       negation>negation 0.5intensifier>intensifier 0.5lessener>lessener

    bias: -xa -xb -ya -yb

rnn_layer2:
    A: xa>xc ya>yc xb>xd yb>yd negation>xc negation>xd negation>yc
       negation>yd positive>xe negation>xe negative>ye negation>ye

    B: '' # in yaml, which is the formatting language I use to type these programs, you need to do this to specify an empty string, which corresponds to the zero matrix

    bias: -xc -xd -yc -yd -xe -ye

dense1:
    C: positive>positive negative>negative 2xa>positive 0.25xb>positive
       ya>negative 0.25yb>negative xc>negative xd>negative yc>positive
       yd>positive -2xc>positive -2xd>positive -2yc>negative
       -2yd>negative xe>negative ye>positive -xe>positive -ye>negative
    c: ''


examples:  # examples modified from https://github.com/cjhutto/vaderSentiment
    - VADER is smart , handsome , and funny .
    - VADER is smart , handsome , and funny !
    - VADER is very smart , handsome , and funny .
    - VADER is VERY SMART , handsome , and FUNNY .
    - VADER is VERY SMART , handsome , and FUNNY !!!
    - VADER is VERY SMART , uber handsome , and FRIGGIN FUNNY !!!
    - VADER is not smart , handsome , nor funny .
    - The book was good .
    - It isnt a horrible book .
    - The book was only kinda good .
    - The plot was good , but the characters are uncompelling and the dialog is not great .
    - Today SUX !
    - Today only kinda sux ! But I'll get by , lol
    - Not bad at all

This is sufficient to generate the scores at the beginning of this post. The scores on the given examples are not all that inaccurate. Lots more work could obviously be done on this network, and I'd love it if people feel like working on this network or other later-discussed networks in the comments.

Transformer Overview

We give here a brief overview of the transformer architecture, for those unfamiliar with it. This will essentially be an accelerated recap of Jay Alammar's Illustrated Transformer, which I consider to be the best friendly introduction to the architecture.

Transformer in its full sequence-to-sequence glory has an encoder stack and a decoder stack. For text classification purposes, one generally just uses the encoder stack with a few simple layers at the end. The encoder stack is made up of a bunch of Transformer encoder blocks, each of which is the same architecturally, but each of which has its own learnable/settable weights that allow it to specialize and do its own particular task in the grand scheme of the network.

Stolen from Jay Alammar's Illustrated Transformer

The decoder stack is also made up of a series of architecturally identical Transformer layers, again each with their own learnable/settable weights that allow them to specialize into their own unique role. The decoder layers are similar to the encoder layers, but a little bit more complex.

So now let's dive inside the Transformer layer and see how it ticks!

Stolen from Jay Alammar's Illustrated Transformer

So we have self-attention layers, encoder-decoder attention layers, and the feed-forward layers. The two types of attention layers are generally considered to be the innovative, "important" part of Transformer, but I have found in trying to hard-code weights for transformer (and years of research by many people has also found) that the feed-forward layers are crucial to being able to learn complex functions. (Well, that's not entirely true. I think someone in some paper managed to sort of smuggle the feed-forward layer into a computation that looks like self-attention over a learnable set of parameters, without losing any accuracy. But for our purposes the feed-forward layer is important.)

We'll next dive into the various layers and see how to hard-code their parameters. We'll start with the easiest layer to understand: the feed-forward layer.

Transformer Feed-Forward Layers

The standard transformer feed-forward layer can be described pretty simply as:

$$\mathrm{FFN}(x) = \mathrm{ReLU}(x W_1 + b_1)\, W_2 + b_2$$

Here ReLU is a new kind of non-linearity, the "rectified linear unit". $W_1$ and $W_2$ are matrices called "weights", and $b_1$ and $b_2$ are vectors called "biases". "Parameters" refers to either weights or biases, although people will often refer to biases as "weights" also.

ReLU and its cousin GELU, stolen from Wikipedia

Like most non-linearities in deep learning, ReLU is an element-wise function, i.e., you apply it to each coordinate of the vector independently. $\mathrm{ReLU}(x) = \max(0, x)$, so if $x$ is negative, $\mathrm{ReLU}(x) = 0$, and otherwise $\mathrm{ReLU}(x) = x$.

So let's hard-code a feed-forward layer! This just requires picking values for $W_1$, $b_1$, $W_2$, and $b_2$.

semes: apple banana cherry durian yum yuck
mat1: apple>apple apple>yum banana>banana banana>yum cherry>yuck durian>yuck
bias1: -yum -yuck
mat2: apple>yum banana>yum -yum>yum yuck>yuck
bias2: '' # a zero vector

The semantics we are trying to encode here is that $\text{apple}$ OR $\text{banana}$ should be mapped to $\text{yum}$, while $\text{cherry}$ AND $\text{durian}$ should be mapped to $\text{yuck}$. This results in the following output:

You could argue that $\text{yuck}$ should displace $\text{yum}$. This can be done, but it seems to require a third layer (I haven't given it much thought just now, so maybe it's doable in two layers), so it would have to be split across two different Transformer layers.
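To check the behaviour described above, here is a NumPy sketch that transcribes the program by hand (`mat1`/`bias1`/`mat2`/`bias2` as $W_1$, $b_1$, $W_2$, $b_2$, semes treated as orthonormal) and runs it on a few combinations of fruit. The helpers `vec`, `mat`, and `show` are illustrative names.

```python
import numpy as np

SEMES = ["apple", "banana", "cherry", "durian", "yum", "yuck"]
IDX = {s: i for i, s in enumerate(SEMES)}

def vec(*semes):
    """Sum of the named unit vectors (semes treated as orthonormal)."""
    v = np.zeros(len(SEMES))
    for s in semes:
        v[IDX[s]] += 1.0
    return v

def mat(*triples):
    """Matrix from (coefficient, src, dst) triples; entry [src, dst] maps src to dst."""
    M = np.zeros((len(SEMES), len(SEMES)))
    for c, src, dst in triples:
        M[IDX[src], IDX[dst]] += c
    return M

# The program above, transcribed by hand.
W1 = mat((1, "apple", "apple"), (1, "apple", "yum"), (1, "banana", "banana"),
         (1, "banana", "yum"), (1, "cherry", "yuck"), (1, "durian", "yuck"))
b1 = -vec("yum", "yuck")
W2 = mat((1, "apple", "yum"), (1, "banana", "yum"), (-1, "yum", "yum"), (1, "yuck", "yuck"))
b2 = np.zeros(len(SEMES))

def ffn(x):
    """ReLU(x W1 + b1) W2 + b2 for a row vector x."""
    return np.maximum(0.0, x @ W1 + b1) @ W2 + b2

def show(v):
    """Non-zero coordinates of v, keyed by seme."""
    return {s: round(float(v[i]), 6) for s, i in IDX.items() if abs(v[i]) > 1e-9}

for fruits in [("apple",), ("banana",), ("apple", "banana"), ("cherry",),
               ("durian",), ("cherry", "durian"), ("apple", "cherry", "durian")]:
    print(" + ".join(fruits), "->", show(ffn(vec(*fruits))))
```

Apple, banana, or both give yum; cherry or durian alone give nothing; cherry and durian together give yuck; and all three together give both yum and yuck, which is the mixed case discussed above.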

Preparing for Self-Attention: Positional Encodings

Before we can really dive into self-attention layers, it is useful to talk about positional encodings. This is a fairly technical aspect of the Transformer architecture, but I've found that it can be made fairly interpretable by thinking about it in the right way.

The traditional way to think about positional embeddings is just to read the following code, and then add some handwaving around "relative positional offsets can be encoded as linear combinations of sines and cosines". This is all correct, but it doesn't really yield enough understanding (for me at least) to hard-code things around the positional embeddings.

# Code from https://www.tensorflow.org/tutorials/text/transformer
import numpy as np

def get_angles(pos, i, d_model):
  # dimensions 2i and 2i+1 share one angular speed (one "clock")
  angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model))
  return pos * angle_rates

def positional_encoding(position, d_model):
  angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                          np.arange(d_model)[np.newaxis, :],
                          d_model)

  # apply sin to even indices in the array; 2i
  angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])

  # apply cos to odd indices in the array; 2i+1
  angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])

  # add a leading batch axis: shape (1, position, d_model)
  pos_encoding = angle_rads[np.newaxis, ...]

  return pos_encoding

I prefer to think of the various sines and cosines as the hands of a bunch of clocks running at different speeds. If we have 512 dimensions in our positional embeddings, then there will be 256 clocks. The sine is just the y-coordinate of the clock hand, and the cosine is the x-coordinate of the clock hand.
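Here is a small NumPy sketch of the clock picture, assuming the same base of 10000 and the same pairing of dimensions as the tutorial code above. It computes each clock's speed (radians per time-step), how many time-steps that clock needs for one full rotation, and the (x, y) position of every clock hand at a given time. Note that in the tutorial's layout, clock 0 is the fastest.

```python
import numpy as np

d_model = 512
j = np.arange(d_model // 2)                         # one clock per (sin, cos) pair
omega = 1.0 / np.power(10000.0, 2 * j / d_model)    # radians per time-step
period = 2 * np.pi / omega                          # time-steps per full rotation

def clock_hands(t):
    """(x, y) = (cos, sin) position of every clock hand at time-step t."""
    return np.stack([np.cos(t * omega), np.sin(t * omega)], axis=1)

print(len(omega), period[0], period[-1])
```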

At time t=0, we look at the first word in the sequence. All clocks point in the same direction.

t=0

At time t=1, we look at the second word in the sequence. The slowest clock (all the way on the left) has advanced a tiny bit, while the fastest clock (all the way on the right) has advanced more.

t=1

At time t=4, the slowest clock has advanced a decent amount, while the fastest clock has advanced about a quarter rotation from where it started.

t=4

Now that we have a more grounded understanding, we can ask questions like, if I have the positional encoding of "very", how do I get the positional encoding of the next word? Well, the next word is one time-step further, so each clock should advance by the amount that that particular clock advances over one time step. We can "point" to that positional embedding by using a particular angle offset for each clock, which then translates into the specific linear combination of sines and cosines referenced earlier. Thus, I have an intentionally quite redundant way to refer to a specific number of time steps in the future.
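The "advance every clock by $k$ steps" operation is a fixed linear map: for each clock it is a 2×2 rotation acting on that clock's (sin, cos) pair. Here is a NumPy sketch assuming the tutorial's layout (sin on even dimensions, cos on odd ones, base 10000) and our row-vector convention; `shift_matrix` is a hypothetical helper name.

```python
import numpy as np

def positional_encoding(num_positions, d_model):
    """Same values as the tutorial code, without the batch axis."""
    pos = np.arange(num_positions)[:, np.newaxis]
    i = np.arange(d_model)[np.newaxis, :]
    angles = pos / np.power(10000, (2 * (i // 2)) / d_model)
    pe = np.zeros_like(angles)
    pe[:, 0::2] = np.sin(angles[:, 0::2])   # y-coordinate of each clock hand
    pe[:, 1::2] = np.cos(angles[:, 1::2])   # x-coordinate of each clock hand
    return pe

def shift_matrix(k, d_model):
    """Matrix M with pe[t] @ M == pe[t + k] for every t (row-vector convention)."""
    M = np.zeros((d_model, d_model))
    for j in range(d_model // 2):
        w = 1 / np.power(10000, 2 * j / d_model)    # speed of clock j
        c, s = np.cos(k * w), np.sin(k * w)
        a, b = 2 * j, 2 * j + 1                     # (sin, cos) slots of clock j
        # sin(x + kw) = sin x cos kw + cos x sin kw
        # cos(x + kw) = cos x cos kw - sin x sin kw
        M[a, a], M[a, b] = c, -s
        M[b, a], M[b, b] = s, c
    return M

pe = positional_encoding(50, 64)
M1 = shift_matrix(1, 64)
print(np.allclose(pe[:-1] @ M1, pe[1:]))
```

Because the same matrix works for every starting position, "the word $k$ steps ahead" can be written as one fixed set of sine/cosine coefficients per clock.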

The different speeds also provide me the ability to point to ranges of time-steps relatively easily. If I want to point to a very narrow range, I can use the fastest clock, which will pick out a very specific time-step, with little room for error. If I want to refer to a broader time-range, I can use a slower clock, which will have the effect that my pointing will be somewhat evenly spread over a wide range of time-steps. Since I have so many clocks to choose from, I can be quite precise in what sort of time interval I point at.
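A rough way to see the narrow-versus-broad effect: if a query has been rotated to point exactly at position $p$, each clock contributes $\cos(d\,\omega)$ to its dot product with the key at position $p + d$. The NumPy sketch below averages that contribution over a chosen set of clocks. The slow speed 0.05 and the mixed speeds are illustrative choices of mine, not taken from the post; 1.0 is the fastest clock of the standard encoding.

```python
import numpy as np

def pointing_profile(distances, omegas):
    """Average of cos(d * omega) over the chosen clocks, for each distance d.

    d is how far the key sits from the position the query points at;
    1.0 means a perfect match on every clock used.
    """
    d = np.asarray(distances, dtype=float)[:, np.newaxis]
    w = np.asarray(omegas, dtype=float)[np.newaxis, :]
    return np.cos(d * w).mean(axis=1)

distances = np.arange(9)
fast = pointing_profile(distances, [1.0])                     # fastest standard clock
slow = pointing_profile(distances, [0.05])                    # an illustrative slow clock
mixed = pointing_profile(distances, [1.0, 0.5, 0.25, 0.125])  # several speeds together
for name, prof in [("fast", fast), ("slow", slow), ("mixed", mixed)]:
    print(name, np.round(prof, 2))
```

The single fast clock drops off sharply but comes back around after about 6 steps; the slow clock barely distinguishes nearby positions; the mix keeps its maximum at $d = 0$ with no comparable second peak in this range.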

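Finally, as a technical note: in the standard encoding, the clock speeds range from one radian per time-step for the fastest clock (so it completes a full loop roughly every 6.3 steps) down to a speed so slow that, with 512 dimensions, the slowest clock needs roughly 60,000 steps to go all the way around. So the fast clocks do loop, but the slow clocks never do over any realistic sequence length, which means the full set of clock readings never repeats. That is what keeps our pointing from accidentally landing on something that has done a full loop without our realizing it.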

The notion of "pointing" will prove to be a very apt way of thinking about the mechanism of self-attention. Basically, words that interact with each other in the self-attention layer can be thought of as pointing at the word they interact with. 

Self-Attention

Finally, we come to the heart of the Transformer model, the self-attention layer. This can be mathematically expressed as

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d}}\right) V$$

where $Q$, $K$, $V$ are matrices called the queries, keys, and values, respectively, and $d$ is the dimension of the queries and keys. The $\sqrt{d}$ factor, often called a "temperature", is important for gradient descent training, where it improves stability. (I would guess that it improves stability by making the attention matrix relatively balanced early on in training.) For us, we will have the opposite problem; it's much easier to think about approximately sparse matrices, so we will actually use something like $\mathrm{softmax}(\lambda Q K^\top) V$, with $\lambda$ being a fairly large scalar.
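A minimal NumPy sketch of the formula, with the scale factor as a parameter so the standard $1/\sqrt{d}$ can be compared against a large sharpening factor. The queries and keys are made-up toy values, and 20 for the large scalar is an arbitrary illustrative choice.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V, scale):
    """softmax(scale * Q K^T) V, with one query, key, or value per row."""
    weights = softmax(scale * (Q @ K.T), axis=-1)   # the attention matrix
    return weights @ V, weights

Q = np.array([[1.0, 0.0], [0.0, 1.0]])              # two queries in a 2-d key-space
K = np.array([[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]])  # three keys
V = np.eye(3)                                       # values just identify the key
d = Q.shape[1]

_, soft = attention(Q, K, V, scale=1 / np.sqrt(d))  # standard scaling: spread out
_, sharp = attention(Q, K, V, scale=20.0)           # large scalar: nearly one-hot rows
print(np.round(soft, 3))
print(np.round(sharp, 3))
```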

There's a lot to unpack here, so let's come up with a very simple example and work through it together.

Suppose we want to distinguish between these two sentences:

She saw a red apple. # grammatical
She saw a red. # not grammatical (ish, you can provide contexts in which it is natural)

We will try not to go too far down the rabbit hole of what makes something grammatical or not - here we're just trying to encode the simple rule (which does have limited exceptions) that adjectives can't just hang out without modifying anything, or being used with a linking verb, or in some other way being "licensed" by the other words in the sentence.

So we would like to be able to detect when an adjective seems to be modifying a noun in the usual way, versus when it is not. In my special programming language for self-attention layers, this can be done like this:

    H1a: # name of the head

        docstring: Modification layer. Specifically pairs of the form
                   Q K, where Q comes before K and Q modifies K. Q
                   must be an adjective or an adverb to use this rule.

        pos: # special notation for interacting with the positional encodings
            Q: 0
            K: +1

        x1: # adjective modifies noun: red apple
            Q: adjective
            K: noun

        x2: # adverb modifies verb: quickly write
            Q: adverb
            K: verb

        x3: # adverb modifies adjective: very slow
            Q: adverb
            K: adjective

        x4: # adverb modifies adverb: very slightly
            Q: adverb
            K: adverb

        x5: # everything else hits filler
            Q: verb noun det filler pro verb noun det filler pro
            K: filler

        int: noun>licensed verb>licensed adverb>licensed
             adjective>licensed

The queries and keys live in a low-dimensional embedding space that I call key-space. (In a standard Transformer this is typically 64 dimensions per head, much smaller than the hidden size of 512, 768, or 1024.) There is one query vector for each token, and one key vector for each token. We take the dot product of every query vector with every key vector, and that gives us what are called "attention logits". Applying the softmax function to the attention logits gives us attention probabilities, which make up what is generally called the "attention matrix". So let's compute all of these values for the above program and our example sentences to get a sense of how all of this works.
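As a concrete picture of that computation, the sketch below stacks query and key vectors into plain Python lists, takes every pairwise dot product to get the logits, and applies a row-wise softmax to get the attention matrix. The function names and the toy 2-dimensional key-space are my own illustrative choices; `lam` is the optional large scalar mentioned earlier.

```python
import math

def attention_logits(queries, keys):
    """Entry [i][j] is the dot product of query vector i with key vector j."""
    return [[sum(q * k for q, k in zip(qi, kj)) for kj in keys] for qi in queries]

def attention_matrix(logits, lam=1.0):
    """Row-wise softmax of lam * logits; row i is token i's attention distribution."""
    rows = []
    for row in logits:
        top = max(lam * z for z in row)
        exps = [math.exp(lam * z - top) for z in row]
        total = sum(exps)
        rows.append([e / total for e in exps])
    return rows

# Toy example: 3 tokens in a 2-dimensional key-space (a real head would use ~64).
Q = [[1.0, 0.0], [0.0, 2.0], [0.0, 0.0]]
K = [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
logits = attention_logits(Q, K)
A = attention_matrix(logits)
print(logits)
print([[round(p, 3) for p in row] for row in A])
```

Note the last query is the zero vector: all of its logits are 0, so it attends uniformly to every token.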

Let's suppose that we have just run the word embedding lookup layer, so that we have the default embedding for each word, but no information yet about how the words are interacting with each other. That might look like this (omitting the positional embeddings for the time being):

SOS: +filler +sos
she: +pro +fem +sg +nom
saw: +saw +verb +agentlack +perceptlack
a: +det +sg
red: +red +adjective
apple: +apple +noun +sg
EOS: +filler +eos
------------
SOS: +filler +sos
she: +pro +fem +sg +nom
saw: +saw +verb +agentlack +perceptlack
a: +det +sg
red: +red +adjective
EOS: +filler +eos
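One way to turn these seme lists into actual vectors, assuming (as the examples seem to) that each seme gets its own orthonormal direction in the hidden space, is to give every seme one coordinate and sum the coordinates for a word. The seme inventory and the `LEXICON` dictionary below are hypothetical and cover only the two example sentences; a real model would have many more semes and a hidden size fixed in advance.

```python
# Hypothetical seme inventory: one hidden-space coordinate per seme.
SEMES = ["filler", "sos", "eos", "pro", "fem", "sg", "nom", "saw", "verb",
         "agentlack", "perceptlack", "det", "red", "adjective", "apple", "noun", "licensed"]
INDEX = {s: i for i, s in enumerate(SEMES)}

LEXICON = {  # hypothetical lexicon covering only the example sentences
    "SOS": ["filler", "sos"], "EOS": ["filler", "eos"],
    "she": ["pro", "fem", "sg", "nom"],
    "saw": ["saw", "verb", "agentlack", "perceptlack"],
    "a": ["det", "sg"], "red": ["red", "adjective"], "apple": ["apple", "noun", "sg"],
}

def embed(word):
    """Sum of one-hot seme vectors for the word's semes."""
    vec = [0.0] * len(SEMES)
    for seme in LEXICON[word]:
        vec[INDEX[seme]] += 1.0
    return vec

def describe(vec):
    """Render a vector back into the +seme notation (semes listed in SEMES order)."""
    return " ".join(f"+{s}" if c == 1 else f"{c:+g} {s}" for s, c in zip(SEMES, vec) if c)

for sentence in (["SOS", "she", "saw", "a", "red", "apple", "EOS"],
                 ["SOS", "she", "saw", "a", "red", "EOS"]):
    for word in sentence:
        print(f"{word}: {describe(embed(word))}")
    print("------------")
```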

The queries will then be computed as follows (again, omitting the positional components):

SOS: +2 x5
she: +2 x5
saw: +2 x5
a: +2 x5
red: +x1
apple: +2 x5
EOS: +2 x5
------------
SOS: +2 x5
she: +2 x5
saw: +2 x5
a: +2 x5
red: +x1
EOS: +2 x5
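These coefficients can be reproduced by reading each `Q:` line of the program as a linear map from semes to key-space directions. Listing `verb noun det filler pro` twice on the `x5` line appears to be what gives those semes a weight of 2 on `x5`. Below is a minimal sketch with sparse dictionaries, assuming orthonormal directions x1 through x5; `Q_RULES` and `query` are my own names, and the positional rules are left out.

```python
from collections import Counter

# Q-side rules of head H1a, transcribed from the program (positional rules omitted).
Q_RULES = {
    "x1": ["adjective"],
    "x2": ["adverb"],
    "x3": ["adverb"],
    "x4": ["adverb"],
    "x5": ["verb", "noun", "det", "filler", "pro"] * 2,  # each seme listed twice
}

def query(semes):
    """Sparse query vector: each occurrence of a seme on a Q: line adds 1 to that direction."""
    q = Counter()
    for direction, triggers in Q_RULES.items():
        for seme in triggers:
            if seme in semes:
                q[direction] += 1
    return dict(q)

print(query({"filler", "sos"}))            # SOS -> {'x5': 2}
print(query({"red", "adjective"}))         # red -> {'x1': 1}
print(query({"pro", "fem", "sg", "nom"}))  # she -> {'x5': 2}
```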

The keys will look like this (again, omitting the positional components):

SOS: +x5
she: 0 # 0-vector
saw: +x2
a: 0
red: +x3
apple: +x1
EOS: +x5
------------
SOS: +x5
she: 0 # 0-vector
saw: +x2
a: 0
red: +x3
EOS: +x5
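The keys come from the `K:` lines in the same way. "she" and "a" carry no seme that appears on any `K:` line, which is why their keys are the zero vector. A matching sketch, again assuming orthonormal directions and omitting positional rules (`K_RULES` and `key` are my own names):

```python
# K-side rules of head H1a, transcribed from the program (positional rules omitted).
K_RULES = {"x1": "noun", "x2": "verb", "x3": "adjective", "x4": "adverb", "x5": "filler"}

def key(semes):
    """Sparse key vector; an empty dict stands for the zero vector."""
    return {direction: 1 for direction, seme in K_RULES.items() if seme in semes}

words = {
    "SOS": {"filler", "sos"}, "she": {"pro", "fem", "sg", "nom"},
    "saw": {"saw", "verb", "agentlack", "perceptlack"}, "a": {"det", "sg"},
    "red": {"red", "adjective"}, "apple": {"apple", "noun", "sg"}, "EOS": {"filler", "eos"},
}
for word, semes in words.items():
    print(f"{word}: {key(semes) or 0}")
```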

Supposing that x1, x2, x3, x4, and x5 are orthonormal, this creates the following attention logits, where we omit zeros for brevity (and we're still omitting positional encodings):

# written in the form Q>K
SOS>SOS: 2
SOS>EOS: 2
she>SOS: 2
she>EOS: 2
saw>SOS: 2
saw>EOS: 2
a>SOS: 2
a>EOS: 2
red>apple: 1
apple>SOS: 2
apple>EOS: 2
EOS>SOS: 2
EOS>EOS: 2
------------
SOS>SOS: 2
SOS>EOS: 2
she>SOS: 2
she>EOS: 2
saw>SOS: 2
saw>EOS: 2
a>SOS: 2
a>EOS: 2
EOS>SOS: 2
EOS>EOS: 2
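The logit table can be checked mechanically: with orthonormal directions, each logit is just the sum of matching coefficients. This sketch copies the first sentence's query and key tables from above and prints the nonzero logits in the same Q>K format (the helper name `dot` is my own).

```python
def dot(u, v):
    """Dot product of two sparse vectors over orthonormal directions x1..x5."""
    return sum(c * v.get(d, 0) for d, c in u.items())

# Queries and keys for the first sentence, copied from the tables above.
queries = {"SOS": {"x5": 2}, "she": {"x5": 2}, "saw": {"x5": 2}, "a": {"x5": 2},
           "red": {"x1": 1}, "apple": {"x5": 2}, "EOS": {"x5": 2}}
keys = {"SOS": {"x5": 1}, "she": {}, "saw": {"x2": 1}, "a": {},
        "red": {"x3": 1}, "apple": {"x1": 1}, "EOS": {"x5": 1}}

logits = {(q, k): dot(queries[q], keys[k]) for q in queries for k in keys}
for (q, k), value in logits.items():
    if value:
        print(f"{q}>{k}: {value}")
```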

We will continue to ignore the positional encodings for the rest of this example, since they're not needed, and don't drastically change the attention matrix. (They would be needed in the case that there are two nouns, in which case the above weights would generate a tie between red pointing to apple and red pointing to the other noun, which we would want to resolve a certain way based on the rule that, in English, an adjective is close to the noun it modifies and before it, with some limited counter-examples.)
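Here is a rough sketch of that tie-breaking. I am assuming, purely for illustration, that the `pos: Q: 0, K: +1` rule adds a small constant `EPS` to the logit whenever the key is exactly one position after the query; the post's actual positional encodings may work differently. The sentence "She saw a red apple and a pear" is my own example, and only red's content logit (1 for each noun) is modeled.

```python
import math

def softmax(xs, lam):
    top = max(lam * x for x in xs)
    exps = [math.exp(lam * x - top) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

tokens = ["SOS", "she", "saw", "a", "red", "apple", "and", "a", "pear", "EOS"]
nouns = {"apple", "pear"}
q_pos = tokens.index("red")
EPS = 0.1     # hypothetical size of the positional contribution
LAM = 100.0   # hypothetical "large" scale factor

content = [1.0 if t in nouns else 0.0 for t in tokens]                    # red's x1 query vs x1 keys
positional = [EPS if j == q_pos + 1 else 0.0 for j in range(len(tokens))]  # key one step after query

without_pos = softmax(content, LAM)
with_pos = softmax([c + p for c, p in zip(content, positional)], LAM)
print(round(without_pos[5], 3), round(without_pos[8], 3))  # apple vs pear, no positions
print(round(with_pos[5], 3), round(with_pos[8], 3))        # apple vs pear, with positions
```

Without the positional term, red splits its attention evenly between the two nouns; with it, nearly all of red's attention goes to the adjacent noun.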

Now let's look at the values! For simplicity in our programs, we actually combine the value projection and what is usually called the out-projection. We call the combined quantity the "interpretant" (a term from semiotics). The interpretants for us are as follows:

SOS: 0
she: 0
saw: +licensed
a: 0
red: +licensed
apple: +licensed
EOS: 0
------------
SOS: 0
she: 0
saw: +licensed
a: 0
red: +licensed
EOS: 0
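Because the value projection maps the hidden space down to the head dimension and the out-projection maps it back up, their product is a single hidden-by-hidden matrix, which is what the `int:` rules describe directly. A sketch, assuming a one-dimensional value space for this head and a cut-down seme list (both my simplifications; `W_V`, `W_O`, and `interpretant` are hypothetical names):

```python
SEMES = ["filler", "pro", "det", "verb", "noun", "adjective", "adverb", "licensed"]

def matmul(A, B):
    """Plain-list matrix product A @ B."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

# W_V (hidden x 1) reads the semes named in the int: rule into one value dimension ...
W_V = [[1.0] if s in ("noun", "verb", "adverb", "adjective") else [0.0] for s in SEMES]
# ... and W_O (1 x hidden) writes that dimension back out as +licensed.
W_O = [[1.0 if s == "licensed" else 0.0 for s in SEMES]]

W_VO = matmul(W_V, W_O)  # hidden x hidden: the combined "interpretant" map

def interpretant(semes):
    x = [1.0 if s in semes else 0.0 for s in SEMES]
    row = matmul([x], W_VO)[0]
    return {s: c for s, c in zip(SEMES, row) if c}

print(interpretant({"verb"}))  # saw
print(interpretant({"pro"}))   # she
```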

Multiplying the interpretants $V$ by the attention matrix $\mathrm{softmax}(\lambda QK^\top)$, we get the following outputs (for large $\lambda$):

SOS: 0
she: 0
saw: 0
a: 0
red: +licensed
apple: 0
EOS: 0
------------
SOS: 0
she: 0
saw: 0
a: 0
red: 0
EOS: 0
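These outputs can be checked numerically. The sketch below (plain Python; `attend` is my own helper, and a scale of 50 stands in for "large") recomputes both sentences from the query, key, and interpretant tables above, with the interpretant recorded as the coefficient of +licensed. It reproduces the first sentence. For the second sentence it surfaces a subtlety: with the positional components omitted, all of red's logits are 0, so red attends uniformly to all six tokens and picks up 1/3 of +licensed (from saw and from itself) rather than exactly 0. Presumably it is the positional rule `pos: Q: 0, K: +1`, ignored here, that sends red's attention to the following token, EOS, whose interpretant is 0.

```python
import math

def attend(queries, keys, values, lam):
    """Each token's output: softmax(lam * q.k) over all tokens, times their interpretants.
    Vectors are sparse dicts over orthonormal directions; values are +licensed coefficients."""
    out = {}
    for qn, q in queries.items():
        logits = [lam * sum(c * k.get(d, 0) for d, c in q.items()) for k in keys.values()]
        top = max(logits)
        weights = [math.exp(z - top) for z in logits]
        total = sum(weights)
        out[qn] = sum(w / total * values[kn] for w, kn in zip(weights, keys))
    return out

# Tables from above, positional components omitted.
q1 = {"SOS": {"x5": 2}, "she": {"x5": 2}, "saw": {"x5": 2}, "a": {"x5": 2},
      "red": {"x1": 1}, "apple": {"x5": 2}, "EOS": {"x5": 2}}
k1 = {"SOS": {"x5": 1}, "she": {}, "saw": {"x2": 1}, "a": {},
      "red": {"x3": 1}, "apple": {"x1": 1}, "EOS": {"x5": 1}}
v1 = {"SOS": 0, "she": 0, "saw": 1, "a": 0, "red": 1, "apple": 1, "EOS": 0}

# Second sentence: the same tables without "apple".
q2 = {n: v for n, v in q1.items() if n != "apple"}
k2 = {n: v for n, v in k1.items() if n != "apple"}
v2 = {n: v for n, v in v1.items() if n != "apple"}

out1 = attend(q1, k1, v1, lam=50.0)
out2 = attend(q2, k2, v2, lam=50.0)
print({n: round(x, 3) for n, x in out1.items()})
print({n: round(x, 3) for n, x in out2.items()})
```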

Using the residual connection that surrounds the self-attention layer, we then receive final outputs:

SOS: +filler +sos
she: +pro +fem +sg +nom
saw: +saw +verb +agentlack +perceptlack
a: +det +sg
red: +red +adjective +licensed
apple: +apple +noun +sg
EOS: +filler +eos
------------
SOS: +filler +sos
she: +pro +fem +sg +nom
saw: +saw +verb +agentlack +perceptlack
a: +det +sg
red: +red +adjective
EOS: +filler +eos

Thus, this self-attention layer has managed to compute the fact that the word "red" is licensed in the first, grammatical sentence, but not licensed in the second, questionable sentence. This can be used by downstream layers to declare the second sentence to be ungrammatical.
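For illustration only, the residual step and the kind of check a downstream layer could perform can be sketched in plain Python. The `has_unlicensed_adjective` function and its 0.5 threshold are hypothetical stand-ins; in the post this decision is made by further Transformer layers, not by a Python function.

```python
def residual(embedding, attn_output):
    """Residual connection: add the head's output to the token's existing semes."""
    total = dict(embedding)
    for seme, c in attn_output.items():
        total[seme] = total.get(seme, 0) + c
    return {s: c for s, c in total.items() if c}

def has_unlicensed_adjective(tokens):
    """Hypothetical downstream check: some adjective lacks (most of) a +licensed seme."""
    return any(t.get("adjective", 0) > 0 and t.get("licensed", 0) < 0.5 for t in tokens)

red = {"red": 1, "adjective": 1}
sentence1_red = residual(red, {"licensed": 1})
sentence2_red = residual(red, {})
print(sentence1_red, has_unlicensed_adjective([sentence1_red]))
print(sentence2_red, has_unlicensed_adjective([sentence2_red]))
```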

Phew!

More to Come!

We'll next see how to hard-code the weights of an actual Transformer. This will involve explaining the structures of Transformer layers, which will take a fair amount of time. In the meantime, please check out The Illustrated Transformer and Transformers from Scratch to get a head start on understanding them, or dive face-first into my 1800-line, poorly commented grammaticality classifier.

evhub

(Moderation note: added to the Alignment Forum from LessWrong.)

I'm confused by your notation for feed-forward layers.

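What justifies re-using the same labels ("apple" etc.) for

  1. the coordinates of the hidden-size embedding space that the layer reads from and writes to, and
  2. the coordinates of the layer's intermediate space, i.e. the basis in which the nonlinearity operates?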

If we want to express what the individual components of basis (2) mean in terms of the original space, we can either talk about which vectors/semes are mapped to them by the first weight matrix, or which vectors/semes they get mapped to by the second weight matrix.

But your labels don't correspond to either of these interpretations. Instead, it looks like you are following rules of the form "the 4th component of every basis is called 'yum'," which leads you to label a coordinate "yum" even though it's neither mapped from "yum" by the first weight matrix, nor mapped to "yum" by the second.

This notation also seems to require the basis (2) to have the same number of elements as (1), which generally will not be the case.  In transformers, (2) is typically larger by a factor of 4.   The logic of your example, meanwhile, can be expressed using a smaller nonlinearity basis of 3 elements:

with some arbitrary choices about which multiplicative constants to absorb into the first-layer weights and biases vs. which to absorb into the second-layer weights.

Thanks for your comments and questions; they're very insightful.

In general, there are as many encoding spaces in a Transformer as there are computational nodes, and a conventionally trained Transformer will have little incentive to use the same semantics for any two of the spaces. (There's a little bit of an incentive because of the residual connections, which will (I think?) kind of tie together the semantics of the various hidden-size embedding spaces.)

In particular, the middle layer of the dense-relu-dense feed-forward layer is usually chosen to be significantly larger (4x) than the hidden size, so it's not even theoretically possible to represent it using the same basis. I've found that it sometimes makes sense to use anonymous seme names like x1, x2, x3, etc. in the feed-forward layer for this reason. In my experience so far, the feed-forward layers are most useful for conjunctions and disjunctions, and the number of possible pairwise conjunctions and disjunctions grows quadratically with the number of neurons, with combinations of 3 or 4 growing faster still. So it seems to me that this might give a tiny hint as to why people have found that the intermediate embedding space of the feed-forward layer needs to be so large.
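To make the conjunction/disjunction point concrete: assuming features are encoded as 0/1 coefficients, a single ReLU unit with a bias of -1 computes AND, while one way to get OR (my own construction, not necessarily how the post's layers do it) uses two hidden units whose outputs are combined with weights +1 and -1. Since even one OR costs two hidden units here, and pairwise combinations alone number n(n-1)/2, this is consistent with the intermediate layer needing to be wider than the hidden size.

```python
import math

def relu(x):
    return max(0.0, x)

def and_neuron(a, b):
    """One hidden unit: fires only when both 0/1 features are present."""
    return relu(a + b - 1.0)

def or_neuron(a, b):
    """Two hidden units combined by output weights +1 and -1: 1 if either feature is present."""
    return relu(a + b) - relu(a + b - 1.0)

for a in (0.0, 1.0):
    for b in (0.0, 1.0):
        print(a, b, and_neuron(a, b), or_neuron(a, b))

# With n binary features there are n choose 2 pairwise conjunctions alone.
for n in (10, 100, 1000):
    print(n, math.comb(n, 2))
```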

Of course, there is a potentially huge gap between what I am clever enough to think of as a use for them and what good old gradient descent is clever enough to think of. We can only easily lower-bound the potential uses of them; upper-bounding the capabilities of a component will prove much more challenging.

I don't fully understand how the embeddings are done.

Can you spell out one of the examples? 

It would be helpful for me to see how the semes map to the actual matrix.

Added an example sentence and its embeddings. Will add more examples overall. Thanks for commenting!

Re: how this interacts with Alignment Research:

I think that these ideas could prove useful in alignment research - if we understand how a language model works in excruciating detail, it seems drastically more likely that we will be able to reason about and predict various misunderstandings rooted in the ambiguity of language.

Another use is for sanity checking existing interpretability techniques. For example, to check if particular neurons identified as curve detectors via interpretability techniques were indeed curve detectors, Chris Olah spent a few hours replacing the curve-detecting neurons with handwritten curve detector neurons. (He found that the interpretability techniques were able to give qualitatively similar results for both the original neurons and the handwritten neurons. More impressively, he also found that replacing the curve-detecting neurons with his handwritten neurons was able to recover ~60% of the drop in accuracy compared to removing the original neurons entirely [reported in footnote 9].)

Very nice post. It is certainly useful to do this exercise of manually encoding language rules into the weights of a transformer in order to better understand the machinery involved.

"The ultimate ambition of this work would be to go toe-to-toe with a comparably-sized Transformer model trained in the traditional way on a modern-sized data set. This might require several people-years of focused effort though."

There is a long history of attempting to parse natural language with hand-designed rules and heuristics. The general consensus now is that hand engineering is insufficient, and some learning from data is necessary. To me it seems that this direction inherits the problems of these old-fashioned language systems, since you are codifying your own hand-designed heuristics and rules into the network weights.

Do you see a way to introduce learning from data without sacrificing the interpretability that your approach provides?

There are a number of ways to combine this approach with learning, but I haven't had time to try any of them yet. Some ideas I have thought of:

  • Use hard-coded weights, plus some random noise, to initialize the weights of a transformer that you then train in the traditional fashion
    • Doesn't really help with interpretability or alignment, but might(???) help with performance
  • Write out all the weight and bias parameters as combinations of semes and outer products of semes, then learn seme embeddings by gradient descent
  • Semantic seme embeddings could be initialized from something like WordNet relationships, or learned with word2vec, to automate those guys
  • You could do smallish amounts of gradient descent to suggest new rules to add, but then add them by hand
    • Still would be very slow
  • Perhaps it is possible to start with a strong learned transformer, gradually identify human-legible rules that it is using, and replace those specific parts with hard-coding
    • Could prove very difficult!!!
    • It seems almost certain to me that hard-coding weights would at least help us build the muscles needed to recognize what is going on, to the extent that we are able to
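The second idea above (writing parameters as combinations of outer products of semes) can be sketched concretely. Below, the query projection of head H1a is written as a sum of outer products, one per (seme, key-direction) pair on the program's `Q:` lines, with every seme and direction represented by a one-hot vector. Under the proposal, those one-hot seme vectors would be replaced by embeddings learned by gradient descent while the list of outer products stays fixed; that training step is not shown, and all names here are hypothetical.

```python
SEMES = ["filler", "sos", "pro", "verb", "noun", "det", "adjective", "adverb"]
DIRS = ["x1", "x2", "x3", "x4", "x5"]

def one_hot(names, name):
    return [1.0 if n == name else 0.0 for n in names]

def outer(u, v):
    return [[ui * vj for vj in v] for ui in u]

def add(A, B):
    return [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]

# One (seme, direction) pair per entry on the Q: lines of head H1a.
PAIRS = [("adjective", "x1"), ("adverb", "x2"), ("adverb", "x3"), ("adverb", "x4")]
PAIRS += [(s, "x5") for s in ["verb", "noun", "det", "filler", "pro"] * 2]

# W_Q is hidden-size x key-size. Seme vectors are one-hot here; under the
# proposal they would be learned embeddings instead.
W_Q = [[0.0] * len(DIRS) for _ in SEMES]
for seme, direction in PAIRS:
    W_Q = add(W_Q, outer(one_hot(SEMES, seme), one_hot(DIRS, direction)))

def project(x, W):
    """Row vector x times matrix W."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]

sos = [a + b for a, b in zip(one_hot(SEMES, "filler"), one_hot(SEMES, "sos"))]
print(dict(zip(DIRS, project(sos, W_Q))))
```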