Any thoughts on potential connections with task arithmetic? (later edit: in addition to footnote 2)
with later small networks taking the outputs of earlier small networks as their inputs.
what's the distinction between two small networks connected in series with the first taking the output of the previous one as input and one big network? what defines the boundaries of the networks here?
I’m not sure I understand your question, but are you asking ‘in what sense are there two networks in series rather than just one deeper network’? The answer to that would be: parts of the inputs to a later small network could come from the outputs of many earlier small networks. Provided the later subnetwork is still sparsely used, it could have a different distribution of when it is used to any particular earlier subnetwork. A classic simple example is how the left-orientation dog detector and the right-orientation dog detector in InceptionV1 fire sort of independently, but both their outputs are inputs to the any-orientation dog detector (which in this case is just computing an OR).
I'm confused by the read-in bound:
Sure, each neuron reads from many of the random subspaces. But in all but roughly k of those subspaces, the big network's activations should be negligible, right? So I was expecting a tighter bound in terms of k.
EDIT: Sorry, misunderstood your question at first.
Even if k≪T, all those subspaces will have some nonzero overlap with the activation vectors of the active subnets. The subspaces of the different small networks in the residual stream aren't orthogonal.
Ah, I think I understand. Let me write it out to double-check, and in case it helps others.
Say δl−1=0, for simplicity. Then Al−1=∑Tt=1Etal−1t. This sum has k nonzero terms.
In your construction, Wl,in=∑tVltWl,intEt⊺. Focussing on a single neuron in the big network, the corresponding row of Wl,in is a sum over the small networks connected to that neuron. This sum has roughly TmlogM/M nonzero terms.
So the preactivation of an MLP hidden neuron in the big network is a sum of terms of the form Wl,intEt⊺Esal−1s, over connected networks t and active networks s. This sum has roughly kTmlogM/M nonzero terms.
We only "want" the terms where s=t; the rest (i.e. the majority) are noise. Each noise term in the sum is a random vector, so the different noise terms are roughly orthogonal, and so the norm of the noise scales like √k (times some other factors, but this captures the k-dependence, which is what I was confused about).
Tl;dr: We generalize the mathematical framework for computation in superposition from compressing many boolean logic gates into a neural network, to compressing many small neural networks into a larger neural network. The number of small networks we can fit into the large network depends on the small networks' total parameter count, not their neuron count.
Work done at Apollo Research. The bottom half of this post is just maths that you do not need to read to get the gist.
Introduction
Background
Anthropic's toy model of superposition shows how to compress many sparsely activating variables into a low dimensional vector space and then read them out again. But it doesn't show how to carry out computations on the compressed variables in their native format. The mathematical framework for computation in superposition makes a first stab at closing that gap. It shows how to compute boolean circuits in superposition.
What we do
We show how a network can perform any computation whatsoever in superposition. Specifically, we show how T small residual neural networks, each with n parameters and performing arbitrary tasks, can be compressed into a single larger residual network that performs all T tasks, provided the large network is only evaluated on sparse combinations of tasks — any particular forward pass only asks for k≪T tasks to be carried out. In the limit of T,n going to infinity, this larger network will require N=˜O(kTn) parameters[1].
Crucially, this means that the total number of small networks the larger network can implement scales approximately linearly with the number of weights in the network, not the number of neurons, as would be the case without computation in superposition. For example, if each small network uses m neurons per MLP layer and d dimensions in the residual stream, a large network with M neurons per MLP connected to a D-dimensional residual stream could implement about ˜O(MD/kmd) small networks, not just ˜O(M/m). Qualitatively speaking, our construction works using the same basic trick as the one for boolean circuits in superposition. We just generalize it from boolean AND gates to any operations the neural network could implement.
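To make the counting concrete, here is a toy calculation. All the sizes below are invented for illustration, and the ˜O log factors and constants are dropped:

```python
# Toy parameter-count comparison for the superposition scaling T ~ M*D/(k*m*d).
# All sizes are invented; log factors and constants are ignored.
m, d = 10, 5        # small network: neurons per MLP layer, residual stream width
M, D = 1_000, 500   # large network: neurons per MLP layer, residual stream width
k = 10              # tasks active on any one forward pass

naive = M // m                        # one small network per disjoint block of neurons
superposed = (M * D) // (k * m * d)   # the ~O(MD/kmd) estimate

print(naive, superposed)  # 100 1000
```

With these (made-up) numbers the superposed packing fits ten times as many small networks as neuron-count packing, despite each forward pass still only computing k=10 of them.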
Generalising to circuits
While our derivation here assumes T networks carrying out unrelated tasks in parallel, nothing in the construction stops us from instead chaining the small networks in series, with later small networks taking the outputs of earlier small networks as their inputs. Therefore, the construction in this post can be thought of as a framework for representing arbitrary circuits in superposition.
Some very tentative implications, maybe?
Real neural networks probably don’t work exactly the way this construction does. It's made to be easy for us to prove things about it, not to be efficient in real life. The finite width of real networks might make other constructions better. We're also not dealing with potential correlations between the activations of different circuits, which might change the optimal setup even more. And ultimately, we don't actually know whether the structure of real-world datasets is sparse in the right way to incentivise learning sparsely activating circuits.
Nevertheless, there may be some useful takeaways about real networks, so long as we don't forget that they come with a heavy pinch of salt.
Future work
The Construction
Suppose we have T small neural networks. For simplicity, we will assume that each small network consists of L layers, with m neurons per layer, a fixed elementwise nonlinearity, and a fixed residual stream width d. We require that these small networks are at least somewhat robust to noise: there is some magnitude of random noise ϵmax>0 that we can apply to all the preactivations of any of the small networks' neurons without changing downstream layer activations by more than some small δ.[4]
Then we can create a large network that is also L layers deep, with a residual stream width D≫d, M≫m neurons in each layer and the same activation functions, which can leverage superposition to compute the outputs of all T neural networks in parallel.
This works even for D≪Td and M≪Tm, provided that only k≪T small neural networks are being passed a non-zero input vector on most forward passes. This large network will require on the order of N=˜O(kTn) parameters in total[5].
The core idea behind this construction is similar to that for computing many ANDs of binary inputs in superposition. There may be many other constructions that would also work, but we think that in the limit of very wide neural networks, all constructions would perform more or less the same, and yield the same fundamental limits for how many small networks can be superposed into a network with N parameters[6]. As with all constructions involving superposition, the key to the construction working out is in managing the size of the interference between separate small networks, and making sure that it does not become larger than the size of the signal — the correct output of each small network. In this construction, there are two sources of interference:
Read-in interference
Our T small networks have a combined Td≫D residual stream dimensions. So, activation vectors of different small networks in the large residual stream cannot be completely orthogonal. This means that when a particular small network is passed an input of 0 but other small networks are passed nonzero inputs, the value of the inputs that are read in by the weights that implement the first small network won't be exactly zero. In our construction, this read-in interference is what ends up dominating the constraints on how many small networks we can compute in a single large network.
At a high level, we manage read-in interference by making the residual stream width D larger so the overlap between small networks is smaller, and making the MLP width M larger so the read-in interference can be spread across more neurons.
Read-out interference
Our T small networks have a combined mT≫M neurons per layer. Naively, we could randomly assign every neuron in every small network to one neuron in the big network. But then, if two small networks that happened to share a neuron activated at the same time, that neuron would get conflicting inputs and misfire. So we could only carry out one of the T tasks at a time.
To make the small networks robust to these misfires, we introduce redundancy into the big network, representing each neuron in the small network with many neurons in the big network. This means that each neuron in the big network is assigned to even more small networks than if there was no redundancy, but this cost is worth it: we can now recover the value of any activation of any small network by averaging over the values of every neuron in the large network that represents it. If few enough small networks are active at once, then almost all neurons in the large network assigned to any particular small network's neuron will take on the correct value for that neuron almost all of the time, and in the limit of M→∞, the difference between the value of a small network's neuron and the average of all the neurons in the large network that compute it will go to zero.
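As a minimal illustration of why this averaging helps (a toy model, not the full construction): if each of r redundant copies of a neuron's value is corrupted by independent noise, the error of the average shrinks like 1/√r.

```python
import numpy as np

# Toy model: recover one neuron's value from r noisy redundant copies by averaging.
# The noise stands in for interference from other circuits; sizes are invented.
rng = np.random.default_rng(0)
true_value = 1.7

def mean_abs_error(r: int, n_trials: int = 5000) -> float:
    # each copy is independently corrupted
    copies = true_value + 0.5 * rng.standard_normal((n_trials, r))
    return float(np.abs(copies.mean(axis=1) - true_value).mean())

print(mean_abs_error(4), mean_abs_error(64))  # error shrinks roughly 4x
```

Going from r=4 to r=64 copies shrinks the recovery error by about √16=4, which is the mechanism the construction relies on as M grows.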
Maths
If you don't care about technical details, you can safely skip this section.
Let the input to the t-th small network be denoted by xt∈Rd and the activation vector of small network t in layer l for input xt by alt(xt) or simply alt.
Similarly, denote the activation vector for the large network in layer l by Al.
We also define a set of random matrices with orthonormal columns {Et∈RD×d}, the t-th of which has columns e1t,…,edt:
Et=(e1t⋯edt)
with eit∈RD satisfying eit⋅ejt=δij. Since the matrices are projection matrices to random d-dimensional subspaces of RD, their columns satisfy (eit⋅ejs)2=O(1/D) for t≠s. These matrices define projections from the residual streams of each small network into a random subspace of the larger residual stream. What we want to prove is that if the number of xt that are nonzero is k≪T, then for all l=1,…,L, there exist terms δl satisfying ||δl||2≪||∑Tt=1Etalt||2, such that:
Al=∑Tt=1Etalt+δl.
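A quick numerical sketch of these projection matrices (sizes invented): QR decomposition of a Gaussian matrix gives orthonormal columns, and the cross-network overlaps come out at the stated O(1/D) scale.

```python
import numpy as np

# Build random projections E_t (D x d) with orthonormal columns via QR,
# then check E_t^T E_t = I_d and that cross-overlaps (e_t^i . e_s^j)^2 ~ 1/D.
# All sizes are invented for illustration.
rng = np.random.default_rng(0)
D, d, T = 256, 4, 20

Es = [np.linalg.qr(rng.standard_normal((D, d)))[0] for _ in range(T)]

assert np.allclose(Es[0].T @ Es[0], np.eye(d))  # orthonormal columns

cross = np.mean([np.mean((Es[t].T @ Es[s]) ** 2)
                 for t in range(T) for s in range(T) if t != s])
print(cross, 1 / D)  # same order of magnitude
```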
We'll (sort-of) prove this using induction.
Embedding Matrix
The base case for the induction is just the embedding in layer 0. The input to the large network is the concatenated vector X=(x1,x2,…,xT)∈RTd. The embedding matrix[7] WE∈RD×Td is constructed by directly projecting each xt into the residual stream using Et, which we can do by stacking the projection matrices next to each other:
WE=(E1⋯ET).
Then, the residual stream activation vector at layer zero, A0:=WEX, is equal to A0=∑Tt=1Etxt as required.
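The embedding step above can be sketched directly in code (sizes invented): stacking the Et side by side means multiplying by WE is the same as summing the per-network embeddings.

```python
import numpy as np

# Sketch of the embedding: stack the projections E_t into W_E and check
# that W_E @ X = sum_t E_t @ x_t. All sizes are invented for illustration.
rng = np.random.default_rng(1)
D, d, T, k = 128, 4, 30, 3

Es = [np.linalg.qr(rng.standard_normal((D, d)))[0] for _ in range(T)]
W_E = np.hstack(Es)  # D x (T*d), the E_t stacked next to each other

xs = [np.zeros(d) for _ in range(T)]
for t in rng.choice(T, size=k, replace=False):  # only k tasks get nonzero inputs
    xs[t] = rng.standard_normal(d)

X = np.concatenate(xs)  # the concatenated input in R^{T*d}
A0 = W_E @ X
assert np.allclose(A0, sum(E @ x for E, x in zip(Es, xs)))
print(A0.shape)  # (128,)
```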
Other layers
We'd now like to assume that Al−1=∑Tt=1Etal−1t+δl−1 is satisfied in layer l−1, and demonstrate that the corresponding equation is satisfied in layer l. To do so, we need to work out what the matrices Wl,in,Wl,out should be.
Reading from the residual stream
To start, we need a way to compute the outputs of Wl,in1,…,Wl,inT∈Rm×d all at once with the larger matrix Wl,in∈RM×D. If we had D≥Td and M≥Tm we could do this by making Wl,in block diagonal, but we are looking for a construction with D≪Td,M≪Tm. To make progress, we start by noting that
Wl,intEt⊺Al−1=Wl,intal−1t+Wl,intEt⊺δl−1+Wl,int∑s≠tEt⊺Esal−1s,
where we have used that Et⊺Et=Id, the d-dimensional identity. We want the read-in interference
ϵl,int:=∑s≠tEt⊺Esal−1s
introduced to network t in layer l to be sufficiently small, staying below the ϵmax noise level we assume the subnetworks to be robust to. The justification for ϵl,int being small is based on the fact that for t≠s, Et⊺Es is approximately a matrix of Gaussians with variance 1/D. Details are in Section Read-in interference.
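Numerically, this interference term really is small when D≫kd; a sketch with invented sizes:

```python
import numpy as np

# Read-in interference: sum over active networks s != t of E_t^T E_s a_s.
# With these invented sizes its norm comes out around sqrt(k*d/D) times the signal.
rng = np.random.default_rng(2)
D, d, k = 1024, 8, 5

Es = [np.linalg.qr(rng.standard_normal((D, d)))[0] for _ in range(k + 1)]
E_t, active = Es[0], Es[1:]

a = rng.standard_normal(d)
signal = np.linalg.norm(E_t.T @ (E_t @ a))  # = |a|, since E_t^T E_t = I
noise = np.linalg.norm(sum(E_t.T @ (E_s @ a) for E_s in active))

print(noise / signal)  # roughly sqrt(k*d/D) ~ 0.2
```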
Writing to the neurons
We can't just connect the outputs of this multiplication to neurons in layer l of the large network, even if the interference is small. This is because mT≫M, so we'd have to share neurons between many circuits, and we couldn't tell whether a neuron i fires because circuit t activated, or because some other circuit that connects to that neuron activated instead. Instead, we need to introduce some redundancy to the representations of the activations of each small network[8]. We do this by multiplying by a distributing matrix Vl∈RM×mT, built from blocks Vlt∈RM×m, one per small network. Each Vlt is defined as follows: divide the M neurons of the large network into M/m sets of m, and independently assign each set either a random m×m permutation matrix (with probability mlogM/M) or the m×m zero matrix.
For the t-th small network, the neurons that are in sets which are assigned a permutation matrix are called connected to that small network, and the neurons that are in sets assigned the zero matrix are called unconnected. We denote the set of all sets of neurons in the large network that are connected to the tth small network in layer l by Slt (a subset of the powerset of {1,…,M}), and the set of all neurons in the large network that are connected to the ith neuron of the tth small network in layer l by Slt,i. Every small network will on average connect its weights Wl,int to r=E[|Slt|]=logM sets of m neurons in the big network. So, we set
Wl,in=∑tVltWl,intEt⊺.
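One way to realize such a distributing block Vlt in code (a sketch with invented sizes; whether logs are natural or base-2 only shifts constants):

```python
import numpy as np

# Sketch of one distributing matrix V_t (M x m): M/m stacked m x m blocks, each an
# independent random permutation matrix with probability m*log(M)/M (so ~log(M)
# blocks connect to this small network), otherwise the zero matrix. Sizes invented.
rng = np.random.default_rng(3)
M, m = 512, 8
p_block = m * np.log(M) / M  # expected number of permutation blocks is log(M)

blocks = [np.eye(m)[rng.permutation(m)] if rng.random() < p_block
          else np.zeros((m, m)) for _ in range(M // m)]
V_t = np.vstack(blocks)  # M x m

# every column has exactly one 1 per connected block, so column sums are equal:
col_sums = V_t.sum(axis=0)
print(V_t.shape, col_sums[0])
```

Each column of V_t marks the set S of big-network neurons that redundantly compute one small-network neuron, and the column sum is |S|, the redundancy factor used later for averaging.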
Writing back to the residual stream
To write back to the residual stream from the neurons, we first recover the value of each neuron of each small network by averaging over all the neurons in the large network that are connected to it. We do this by multiplying the activations of the big network with (1/|Slt|)(Vlt)⊺:
(1/|Slt|)(Vlt)⊺ReLU(Wl,inAl−1)=ReLU(Wl,intal−1t)+ϵl,outt.
Then we can apply each Wl,outt to recover each small network's MLP output, and embed these outputs back into the residual stream using Et:
Wl,out=∑t1|Slt|EtWl,outt(Vlt)⊺.
If ϵl,outt is small enough (which requires ϵl,int to be small as well), then we are done, and Al will have the correct form.
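Putting the pieces together, here is an end-to-end sketch of one superposed MLP layer with invented sizes: assemble Wl,in=∑tVltWl,intEt⊺ and Wl,out=∑t(1/|Slt|)EtWl,outt(Vlt)⊺, run one forward pass with k active small networks, and compare the recovered update for each active network to the truth.

```python
import numpy as np

# End-to-end one-layer sketch of the construction. All sizes are invented,
# and the recovery is only approximate: the errors here are the epsilon terms.
rng = np.random.default_rng(4)
T, k = 40, 2          # small networks total / active this forward pass
d, m = 4, 6           # small network: residual width / neurons per layer
D, M = 1024, 1200     # large network: residual width / neurons per layer

def rand_E():
    return np.linalg.qr(rng.standard_normal((D, d)))[0]

def rand_V():
    p = m * np.log(M) / M
    blocks = [np.eye(m)[rng.permutation(m)] if rng.random() < p
              else np.zeros((m, m)) for _ in range(M // m)]
    V = np.vstack(blocks)
    return V if V.sum() else rand_V()  # re-draw if no block was assigned

Es = [rand_E() for _ in range(T)]
Vs = [rand_V() for _ in range(T)]
W_in = [0.3 * rng.standard_normal((m, d)) for _ in range(T)]
W_out = [0.3 * rng.standard_normal((d, m)) for _ in range(T)]

big_in = sum(V @ Wi @ E.T for V, Wi, E in zip(Vs, W_in, Es))        # M x D
big_out = sum(E @ Wo @ V.T / V[:, 0].sum()                          # D x M
              for E, Wo, V in zip(Es, W_out, Vs))

relu = lambda z: np.maximum(z, 0.0)

active = range(k)
a = {t: rng.standard_normal(d) for t in active}
A = sum(Es[t] @ a[t] for t in active)      # residual stream, delta = 0

update = big_out @ relu(big_in @ A)        # the large network's MLP output

for t in active:
    truth = W_out[t] @ relu(W_in[t] @ a[t])
    recovered = Es[t].T @ update
    print(t, np.linalg.norm(recovered - truth))  # error shrinks as D, M grow
```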
Error analysis
Let a,w∈R+ be upper bounds on the L2 norm of the small networks' activations in the residual stream, and the operator norm of their MLP input matrices, respectively:
||alt||2≤a and ||Wl,int||op≤w for all l and all t∈{1,…,T}.
In the analysis below, we find that the L2 size of the total interference added to a subnet in an MLP layer will be
ϵ=O(wa√(kTmd/(MD))logM).
For this noise to stay below the ϵmax we assumed the small networks to be robust to at every layer, our large network needs at least
N=O((w2a2/ϵ2max)kTnlogM)
parameters in total. Any fewer than that, and the interference will begin to overwhelm the signal. Assuming the noise tolerance ϵmax isn't larger than the maximum size of the small networks' neuron activations, we'll have w2a2/ϵ2max≥1. So we need N=˜O(kTn) parameters in total.
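Plugging invented numbers into this bound (with the O(⋅) prefactor set to 1) gives a feel for the scale:

```python
from math import log

# Evaluate N = (w*a/eps_max)^2 * k*T*n * log(M) for made-up values,
# treating the O(.) constant as 1.
w, a, eps_max = 1.0, 1.0, 0.5   # bounds on weights, activations, noise tolerance
k, T, n, M = 10, 10_000, 1_000, 1_000_000

N = (w * a / eps_max) ** 2 * k * T * n * log(M)
print(f"{N:.3g}")  # linear in T, up to the log factor
```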
Read-in interference
In this construction, we find that our total error term is dominated by read-in interference.
The noise from an activation vector als of a circuit s being multiplied by weight matrix Wint of a different circuit t will be
ϵl,int,s=WintEt⊺Esals.
The entries of the matrix Et⊺Es∈Rd×d will have approximate size O(1/√D). Since the d entries of a row of Et⊺Es are randomly distributed, the entries of Et⊺Esals will then have average size O(√(d/D)). So, the noise ϵl,int,s from activation als of small network s being partially projected into preactivations of neurons in small network t will be on the order of
ϵl,int,s=O(√(d/D)||Wl,int||op||als||2).
On average, each neuron has Tp=(Tm/M)logM weight rows of small networks connecting to it. Using ||als||2≤a and ||Wl,int||op≤w, if there are k circuits active at a given time, the total read-in interference ϵl,int=∑s≠tϵl,int,s on the preactivation of any one neuron in any small network t will be bounded by
ϵl,int=O(wa√(kTmd/(MD))logM)
because the noise sources are independent. This noise dominates the total error term.
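A quick Monte Carlo check of the √k part of this scaling (sizes invented): because the k noise sources are independent, they add in quadrature, so 16 times as many active networks gives about 4 times the noise norm.

```python
import numpy as np

# Check that the read-in noise norm grows like sqrt(k): independent sources
# add in quadrature. Sizes invented; averaged over a few random draws.
rng = np.random.default_rng(5)
D, d = 2048, 8

def noise_norm(k: int) -> float:
    Es = [np.linalg.qr(rng.standard_normal((D, d)))[0] for _ in range(k + 1)]
    a = np.ones(d)
    eps = sum(Es[0].T @ (E @ a) for E in Es[1:])
    return float(np.linalg.norm(eps))

n1 = np.mean([noise_norm(1) for _ in range(20)])
n16 = np.mean([noise_norm(16) for _ in range(20)])
print(n16 / n1)  # close to sqrt(16) = 4
```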
Read-out interference
In our construction, we find that read-out interference ϵl,outt from multiple circuits using the same neuron is subdominant and vanishes in the limit of large networks. For the read-out of a small network from the MLP of the large network to become inaccurate, some fraction of the ≈logM neurons playing the role of one neuron in the original small network have to all 'misfire', activating when they shouldn't, or with incorrect magnitude even when they do fire. Since we assumed that our activation functions are Lipschitz continuous, we can bound the size of any 'misfire' by some constant K∈R.
We'll assume that there is some critical fraction 0<c<1, the maximum fraction of misfires we can tolerate, which depends on the error tolerance of our small networks: clogM misfires would give us an error ϵl,outt,i≤cK on the averaged read-out of neuron i in small network t, which we require to be smaller than the maximum error tolerance ϵmax of the small networks.
One neuron: Consider a specific neuron i in small network s. This neuron is assigned a set Sls,i of size approximately logM of neurons to compute it in the large network.
k=1: Suppose that only small network t≠s is active on the current forward pass. The chance of any circuit t connecting to a given neuron is p=(m/M)logM. So, if c≪1, the probability that there are clogM misfirings in the set Sls,i follows a binomial distribution:
P(clogM misfirings in Sls,i)=(logM choose clogM)(mlogM/M)clogM(1−mlogM/M)(1−c)logM.
The last factor is approximately equal to 1 and can be ignored.
k>1: Suppose there are k>1 small networks active at once. Each neuron in Sls,i can be used in multiple active networks. We can imagine a matrix with k rows and logM columns, with a 1 in the (i,j) position if the ith neuron in Sls,i is connected to the jth active small network, and a zero otherwise. The entries of this matrix are i.i.d Bernoulli random variables with probability p, and the number of nonzero entries in this matrix is the total number of misfirings in Sls,i. Again assuming c≪1, the probability Sls,i has clogM misfirings will be:
P(clogM misfirings in Sls,i)=(klogM choose clogM)(mlogM/M)clogM.
Using Stirling's formula[9], we can write this as:
P(clogM misfirings in Sls,i)<(ekmlogM/(cM))clogM.
We can approximate P(clogM+x misfirings in Sls,i) as a decaying geometric series in x, with initial value P0=P(clogM misfirings in Sls,i) and ratio r=Px+1/Px≃(klogM)p/(clogM)=kmlogM/(cM)≪1.
Therefore, we have
P(at least clogM misfirings in Sls,i)≈P0/(1−r)<(ekmlogM/(cM))clogM.
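This bound can be checked against the exact binomial tail for one made-up parameter setting (taking logM as an integer for simplicity; the specific values below are invented):

```python
from math import comb, e

# Compare the exact binomial tail P(>= c*logM misfirings) with the bound
# (e*k*m*logM / (c*M))^(c*logM), for invented parameters with logM := 10.
M, m, k = 1024, 4, 4
logM = 10
c = 0.5
p = m * logM / M          # per-neuron connection probability
n = k * logM              # Bernoulli trials relevant to one set S
thr = int(c * logM)       # misfire threshold

tail = sum(comb(n, x) * p**x * (1 - p)**(n - x) for x in range(thr, n + 1))
bound = (e * k * m * logM / (c * M)) ** thr

print(tail < bound)  # True
```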
One forward pass: We have Tm sets of neurons Sls,i. We want the chance of more than clogM misfirings for any of them on a forward pass to be vanishingly small for all c in the large width limit. That is, we want to scale M with the number of small networks T, the size of small networks m, and the number of active small networks k such that:
limM,T→∞Tm(ekmlogM/(cM))clogM=0.
This condition is satisfied for any c≪1 so long as T grows at most polynomially with M, and kmlogM≪M. The read-in error already imposes MD=O(Tmkd), so the former condition is not an additional constraint, except in that it precludes making the residual stream exponentially wider than the MLP width M. The latter condition is fulfilled if the small networks activate sparsely.
So, in the large width limit M→∞, ϵl,outt will vanish. Thus, the total error is dominated by ϵl,int.
Acknowledgements
Thanks to Dan Braun, Stefan Heimersheim, Lee Sharkey, and Bilal Chughtai for lots of discussions that shaped our thinking about this idea. Thanks also to Kaarel Hanni, Dmitry Vaintrob and Lawrence Chan for previous work that this idea builds on heavily, and for helping shape our thinking about this kind of thing.
N=˜O(kTn) basically means 'N=O(kTn) up to log factors'.
Put differently, we can't have an overcomplete basis of task vectors.
This limit is already suggested by information theory: Every operation we want the network to implement takes some minimum number of bits in its parameters to specify. So, in general, the minimum description length of the large network in bits can't be smaller than the minimum description lengths of the small networks summed together.
The more imprecision we're willing to tolerate in the final result, the larger ϵmax will be. If small networks vary in how noise robust they are, we pick the ϵmax of the least robust one to be conservative.
These simplifications primarily serve to avoid obfuscating the ideas in the construction. We are pretty confident that the derivations go through if you allow the number of neurons, residual stream width, and number of layers per small network to vary. That is, suppose we are given a set of neural networks indexed by t=1,…,T. For the t-th network, denote the number of neurons per layer as mt, residual stream width dt, and number of layers ℓt. Then, there exists a large residual neural network with depth L, number of neurons per layer M, and residual stream width D which satisfies ∀t∈{1,…,T}: mt≪M, dt≪D, ℓt≤L, and ∑tmt≫M, ∑tdt≫D, which can compute the outputs of all T circuits in parallel by leveraging superposition.
We think some additional tinkering might remove the log term, and constant prefactors could likely be improved, but we doubt anything will break the limit N≥∑Tt=1nt. We can't specify more operations than we have bits to specify them in.
Using the convention of left multiplication by matrices.
This is essentially the same idea that is referred to as superpositional codes in this essay.
Which applies because p≪1, and the expected number of misfirings is pklogM=kmlog2M/M≪clogM.