Can't say much about transformers, but the tensor product definition seems off. There can be many elements in $V \otimes W$ that aren't expressible as $v \otimes w$, only as a linear combination of several such elements. That can be seen by counting parameters: if $v$ and $w$ have dimensions $n$ and $m$, then a pair $(v, w)$ only has $n+m$ degrees of freedom (the Cartesian product has dimension $n+m$), while the full tensor product has dimension $nm$.
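For example: identifying $v \otimes w$ with the outer product $v w^\top$ (as in the post below), every pure tensor has matrix rank at most 1, so an element like
$$e_1 \otimes e_1 + e_2 \otimes e_2 = \begin{pmatrix} 1 & 0 \\ 0 & 1 \end{pmatrix}$$
has rank 2 and cannot be written as a single $v \otimes w$.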
Here's an explanation of tensor products that I came up with some time ago in an attempt to make it "click". Imagine you have a linear function that takes in two vectors and spits out a number. But wait, there are two natural but incompatible ways to make "linear" precise:
1. $f(a,b) + f(c,d) = f(a+c, b+d)$: linear in both arguments combined. The space of such functions has dimension $n+m$ and corresponds to the Cartesian product.
2. $f(a,b) + f(a,c) = f(a, b+c)$ and also $f(a,c) + f(b,c) = f(a+b, c)$: in other words, linear in each argument separately. The space of such functions has dimension $nm$ and corresponds to the tensor product.
It's especially simple to work through the case $n=m=1$. In that case all functions satisfying (1) have the form $f(x,y) = ax + by$, so their space is 2-dimensional, while all functions satisfying (2) have the form $f(x,y) = axy$, so their space is 1-dimensional. Admittedly this case is a bit funny because $nm < n+m$, but you can see how in higher dimensions the space of functions of type (2) becomes much bigger, because it will have terms for $x_1 y_1$, $x_1 y_2$, etc.
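As a quick numpy sketch of the two types (dimensions $n = 2$, $m = 3$ chosen just for illustration): type (1) functions are determined by a pair of vectors ($n+m$ numbers), type (2) functions by a matrix $M$ via $f(x,y) = x^\top M y$ ($nm$ numbers).

```python
import numpy as np

# Sketch of the two notions of linearity, for n = 2, m = 3.
rng = np.random.default_rng(0)
a, c = rng.standard_normal((2, 2))   # first-argument vectors, dimension n = 2
b, d = rng.standard_normal((2, 3))   # second-argument vectors, dimension m = 3

# Type (1): linear in both arguments combined, f(x, y) = u·x + v·y,
# determined by n + m = 5 numbers.
u = rng.standard_normal(2)
v = rng.standard_normal(3)
f1 = lambda x, y: u @ x + v @ y
assert np.isclose(f1(a, b) + f1(c, d), f1(a + c, b + d))

# Type (2): linear in each argument separately, f(x, y) = xᵀ M y,
# determined by n * m = 6 numbers in M.
M = rng.standard_normal((2, 3))
f2 = lambda x, y: x @ M @ y
assert np.isclose(f2(a, b) + f2(a, d), f2(a, b + d))  # linear in the 2nd argument
assert np.isclose(f2(a, b) + f2(c, b), f2(a + c, b))  # linear in the 1st argument
```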
Ah yes that makes sense to me. I'll modify the post accordingly and probably write it in the basis formulation.
ETA: Fixed now; the computation takes a tiny bit longer but is hopefully still readable to everyone.
I was trying to understand the tensor product formulation in the transformer circuits paper, and I had basically forgotten all I ever knew about tensor products, if I ever knew anything. This very brief post is aimed at me from Wednesday the 22nd, when I didn't understand why that formulation of attention was true. It basically just gives a bit more background and includes a few more steps. I hope it will be helpful to someone else, too.
Tensor product
To understand this, it is necessary to understand tensor products. Given two finite-dimensional vector spaces $V, W$ we can construct the tensor product space $V \otimes W$ as the span[1] of all matrices $v \otimes w$, where $v \in V$, $w \in W$, with the property $(v \otimes w)_{ij} = v_i w_j$.[2] We can equivalently define it as the vector space with basis elements $e_i^V \otimes e_j^W$, where $e_i^V$ and $e_j^W$ are the basis elements of $V$ and $W$ respectively.
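Concretely, with the identification of $v \otimes w$ as the outer-product matrix, this looks as follows in numpy (a small sketch, with arbitrarily chosen vectors):

```python
import numpy as np

# v ⊗ w as the outer product v wᵀ, the identification used above.
v = np.array([1.0, 2.0])        # v ∈ V ≅ R^2
w = np.array([3.0, 4.0, 5.0])   # w ∈ W ≅ R^3

vw = np.outer(v, w)             # (v ⊗ w)_ij = v_i * w_j, a 2x3 matrix
assert vw.shape == (2, 3)       # V ⊗ W has dimension 2 * 3 = 6

# The basis element e_i^V ⊗ e_j^W is the matrix with a single 1 at position (i, j),
# so a general element of V ⊗ W is just an arbitrary 2x3 matrix.
e0V, e1W = np.eye(2)[0], np.eye(3)[1]
assert np.allclose(np.outer(e0V, e1W), [[0, 1, 0], [0, 0, 0]])
```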
We can define tensor products not only between vectors, but also between linear maps from one vector space to another (i.e. matrices!):
Given two linear maps (matrices) $A: V \to X$ and $B: W \to Y$, we can define $A \otimes B: V \otimes W \to X \otimes Y$, where each map simply operates on its own vector space, without interacting with the other:

$$(A \otimes B)(v \otimes w) = A(v) \otimes B(w)$$
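We can sanity-check this identity numerically: in the standard bases, $A \otimes B$ is represented by the Kronecker product, which numpy provides as np.kron (a sketch with arbitrary small shapes):

```python
import numpy as np

# Check (A ⊗ B)(v ⊗ w) = A(v) ⊗ B(w), using np.kron for the tensor product
# of maps and of (flattened) vectors.
rng = np.random.default_rng(0)
A = rng.standard_normal((4, 2))   # A : V ≅ R^2 → X ≅ R^4
B = rng.standard_normal((5, 3))   # B : W ≅ R^3 → Y ≅ R^5
v = rng.standard_normal(2)
w = rng.standard_normal(3)

lhs = np.kron(A, B) @ np.kron(v, w)   # (A ⊗ B) applied to v ⊗ w
rhs = np.kron(A @ v, B @ w)           # A(v) ⊗ B(w)
assert np.allclose(lhs, rhs)
```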
For more information on the tensor product, I recommend this intuitive explanation and the Wikipedia entry.
How does this connect to the attention-only transformer?
In the "attention-only" formulation of the transformer we can write the "residual" of a fixed head as AXWVWO, with the values weight matrix WV, the attention matrix A, the output weight matrix WO, and the current embeddings at each position X
Let $E$ be the embedding dimension, $L$ the total context length and $D$ the dimension of the values; then we have $A \in \mathbb{R}^{L \times L}$, $X \in \mathbb{R}^{L \times E}$, $W_V \in \mathbb{R}^{E \times D}$ and $W_O \in \mathbb{R}^{D \times E}$.
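In numpy, a small shape-check of this setup (the particular sizes are made up for illustration):

```python
import numpy as np

# Illustrative sizes for L (context length), E (embedding dim), D (value dim).
L, E, D = 4, 8, 3
rng = np.random.default_rng(0)
A   = rng.standard_normal((L, L))   # attention pattern
X   = rng.standard_normal((L, E))   # embeddings, one row per position
W_V = rng.standard_normal((E, D))   # value weights
W_O = rng.standard_normal((D, E))   # output weights

residual = A @ X @ W_V @ W_O
assert residual.shape == (L, E)     # same shape as X, as expected for the residual
```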
Let's identify the participating vector spaces:
A maps from the "position" space back to the "position" space, which we will call P (and which is isomorphic to RL). Similarly, we have the "embedding" space E≅RE and the "value" space V≅RD.
It might become clear now that we can identify $X$ with an element from $P \otimes E$, i.e. that we can write $X = \sum_{ij} X_{ij} \, (e_i^P \otimes e_j^E)$.
From that lens, we can see that right-multiplying $X$ with $W_V$ is equivalent to multiplying with $\mathrm{Id} \otimes W_V$, which maps an element from $P \otimes E$ to an element from $P \otimes V$ by applying $W_V$ to the $E$-part of the tensor [3]:
$$\begin{aligned}
(\mathrm{Id} \otimes W_V)(X) &= (\mathrm{Id} \otimes W_V) \sum_{ij} X_{ij}\, e_i^P \otimes e_j^E \\
&= \sum_{ij} X_{ij}\, e_i^P \otimes W_V(e_j^E) \\
&= \sum_{ij} X_{ij}\, e_i^P \otimes \sum_k W_{jk}\, e_k^V \\
&= \sum_{ik} \Big(\sum_j X_{ij} W_{jk}\Big)\, e_i^P \otimes e_k^V \\
&= \sum_{ik} (X W_V)_{ik}\, e_i^P \otimes e_k^V \\
&= X W_V
\end{aligned}$$
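A numerical sanity check of this step (a sketch; with numpy's row-major flattening of $X$, the map on the $E$-part is represented by $W_V^T$, exactly as in footnote [3]):

```python
import numpy as np

# Acting with Id ⊗ W_V on X, viewed as a vector in P ⊗ E, is the same as
# right-multiplying the matrix X by W_V. Sizes are made up for illustration.
L, E, D = 4, 8, 3
rng = np.random.default_rng(0)
X   = rng.standard_normal((L, E))
W_V = rng.standard_normal((E, D))

lhs = np.kron(np.eye(L), W_V.T) @ X.flatten()   # (Id ⊗ W_V) acting on X ∈ P ⊗ E
rhs = (X @ W_V).flatten()                       # X W_V, flattened into P ⊗ V
assert np.allclose(lhs, rhs)
```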
Identical arguments hold for $W_O$ and $A$, so that we get the formulation from the paper:
$$A X W_V W_O = (A \otimes W_O W_V) \cdot X,$$
where $W_O W_V$ on the right-hand side is the composition of the maps in the sense of footnote [3] (apply $W_V$ first, then $W_O$), i.e. right-multiplication by the matrix product $W_V W_O$.
Note that there is nothing special about this in terms of what these matrices represent. So a takeaway message is that whenever you have a matrix product of the form $ABC$, you can re-write it as $(A \otimes C) \cdot B$, where $C$ acts on $B$ by right multiplication (sorry to everyone who thought that was blatantly obvious from the get-go ;P).[4]
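The same check works for the general $ABC = (A \otimes C) \cdot B$ takeaway, applied to the attention head (a sketch; with row-major flattening the second Kronecker factor appears as the transpose $(W_V W_O)^T = W_O^T W_V^T$, i.e. the matrix of the composed map from footnote [3]):

```python
import numpy as np

# For matrices, with numpy's row-major flattening of the middle factor,
# vec(A B C) = kron(A, C^T) @ vec(B). Applied to A X (W_V W_O):
L, E, D = 4, 8, 3
rng = np.random.default_rng(0)
A   = rng.standard_normal((L, L))
X   = rng.standard_normal((L, E))
W_V = rng.standard_normal((E, D))
W_O = rng.standard_normal((D, E))

lhs = (A @ X @ W_V @ W_O).flatten()
# Kronecker factors: A on the position space, and the composed map on the
# embedding space, which in coordinates is (W_V W_O)^T = W_O^T W_V^T.
rhs = np.kron(A, (W_V @ W_O).T) @ X.flatten()
assert np.allclose(lhs, rhs)
```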
A previous edition of this post said that it was the space of all such matrices, which is inaccurate: the span of a set of vectors/matrices is the space of all linear combinations of elements from that set. ↩︎
I'm limiting myself to finite-dimensional spaces because that's what is relevant to the transformer circuits paper. The actual formal definition is more general/stricter, but imo doesn't add much to understanding the application in this paper. ↩︎
Note that the 'linear map' that we use here is basically right multiplication with $W_V$, so that it maps $e_k^E \mapsto W_V^T e_k^E$. ↩︎
I should note that this is also what is mentioned in the paper's introduction on tensor products, but it didn't click with me, whereas going through the above steps did. ↩︎