I'd be curious if you have any ideas for how it can be applied in more advanced cases, e.g. what if we want to find the natural latents in Llama?
I expect the typical case will look like:
... which is not what this post is about.
The material in this post is useful mainly in cases where we want to be able to rule out any "better" natural latents, which is a somewhat atypical use case. It would be relevant, for instance, if I want to design a toy environment with known natural latents in which to train some system.
(Aside: this is something I updated about relatively recently; I had previously thought of the sort of thing this post is doing as the central use-case.)
Would the checks of the naturality conditions you have in mind primarily be empirical (e.g. sampling a bunch of data points and running some statistical independence checks), or might they just as often be mechanistic? I'm not sure how the latter would work for complex models like Llama, but for a Bayes net, for instance, you already have a factorization that makes robust independence checks much easier.
Asking because the idea of "in some model" (plus the desire for e.g. adversarial robustness) suggests to me that we'd want to have a more mechanistic idea of whether the naturality conditions hold, but they seem easier to check empirically.
So you’ve read some of our previous natural latents posts, and you’re sold on the value proposition. But there are some big foundational questions still unanswered. For example: how do we find these natural latents in some model, if we don’t know in advance what they are? Examples in previous posts conceptually involved picking some latent out of the ether (e.g. the bias of a die), and then verifying the naturality of that latent.
This post is about one way to calculate natural latents, in principle, when we don’t already know what they are. The basic idea is to resample all the variables once simultaneously, conditional on the others, like a step in an MCMC algorithm. The resampled variables turn out to be a competitively optimal approximate natural latent over the original variables (as we’ll prove in the post). Toward the end, we’ll use this technique to calculate an approximate natural latent for a normal distribution, and quantify the approximations.
The proofs will use the graphical notation introduced in Some Rules For An Algebra Of Bayes Nets.
Some Conceptual Foundations
What Are We Even Computing?
First things first: what even is “a latent”, and what does it even mean to “calculate a natural latent”? If we had a function to “calculate natural latents”, what would its inputs be, and what would its outputs be?
The way we use the term, any conditional distribution
(λ,x↦P[Λ=λ|X=x])
defines a “latent” variable Λ over the “observables” X, given the distribution P[X]. Together P[X] and P[Λ|X] specify the full joint distribution P[Λ,X]. We typically think of the latent variable as some unobservable-to-the-agent “generator” of the observables, but a latent can be defined by any extension of the distribution over X to a distribution over Λ and X.
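As a minimal sketch of that definition, the snippet below builds a joint distribution P[Λ,X] from a distribution P[X] and a conditional P[Λ|X]. The specific numbers, and the choice of Λ as a noisy copy of X1, are hypothetical and only there to make the arrays concrete.

```python
import numpy as np

# Hypothetical toy distribution over two binary observables X = (X1, X2).
# p_x[x1, x2] = P[X1=x1, X2=x2]
p_x = np.array([[0.4, 0.1],
                [0.1, 0.4]])

# A latent is defined by a conditional distribution P[Lambda | X].
# p_lam_given_x[lam, x1, x2] = P[Lambda=lam | X1=x1, X2=x2]
# Here (arbitrarily) Lambda is a noisy copy of X1.
p_lam_given_x = np.empty((2, 2, 2))
p_lam_given_x[0] = np.array([[0.9, 0.9],
                             [0.2, 0.2]])
p_lam_given_x[1] = 1.0 - p_lam_given_x[0]

# Together they specify the full joint: P[Lambda, X] = P[Lambda | X] * P[X]
p_joint = p_lam_given_x * p_x[None, :, :]

print(p_joint.sum())        # total probability mass of the joint
print(p_joint.sum(axis=0))  # marginalizing out Lambda recovers P[X]
```

Any array of the same shape whose slices sum to 1 over the first axis would define a different latent over the same P[X].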
Natural latents are latents which (approximately) satisfy some specific conditions, namely that the distribution P[X,Λ] (approximately) factors over these Bayes nets:
Intuitively, the first says that Λ mediates between the Xi’s, and the second says that any one Xi gives approximately the same information about Λ as all of X. (This is a stronger redundancy condition than we used in previous posts; we’ll talk about that change below.)
So, a function which “calculates natural latents” takes in some representation of a distribution (x↦P[X=x]) over “observables”, and spits out some representation of a conditional distribution (λ,x↦P[Λ=λ|X=x]), such that the joint distribution (approximately) factors over the Bayes nets above.
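To make "approximately factors" concrete, here is a rough sketch for a discrete joint over a latent and two observables. It assumes the approximation error of each condition is measured as the KL divergence (in bits) between the joint and the corresponding factorization; the function names and the toy distribution are illustrative, not taken from the post's code.

```python
import numpy as np

def kl_bits(p, q):
    """D_KL(p || q) in bits; entries with p == 0 contribute nothing."""
    mask = p > 0
    return float(np.sum(p[mask] * np.log2(p[mask] / q[mask])))

def mediation_error(p):
    """p[lam, x1, x2]. KL from the joint to P[Lam] P[X1|Lam] P[X2|Lam]."""
    p_lam = p.sum(axis=(1, 2))
    p_lam_x1 = p.sum(axis=2)
    p_lam_x2 = p.sum(axis=1)
    q = p_lam_x1[:, :, None] * p_lam_x2[:, None, :] / p_lam[:, None, None]
    return kl_bits(p, q)

def redundancy_error(p, i):
    """KL from the joint to P[X] P[Lam | X_i], for i in {1, 2}."""
    p_x = p.sum(axis=0)
    if i == 1:
        p_lam_xi = p.sum(axis=2)                       # indexed [lam, x1]
        cond = p_lam_xi / p_lam_xi.sum(axis=0, keepdims=True)
        q = cond[:, :, None] * p_x[None, :, :]
    else:
        p_lam_xi = p.sum(axis=1)                       # indexed [lam, x2]
        cond = p_lam_xi / p_lam_xi.sum(axis=0, keepdims=True)
        q = cond[:, None, :] * p_x[None, :, :]
    return kl_bits(p, q)

# Hypothetical example: Lambda is a fair bit; X1 and X2 are independent
# noisy copies of Lambda, each flipped with probability 0.05.
f = np.array([[0.95, 0.05],
              [0.05, 0.95]])                           # f[lam, x] = P[X_i=x | Lam=lam]
p = 0.5 * np.einsum('la,lb->lab', f, f)

print("mediation error (bits):", mediation_error(p))
print("redundancy error for X1 (bits):", redundancy_error(p, 1))
print("redundancy error for X2 (bits):", redundancy_error(p, 2))
```

In this toy example the mediation condition holds exactly by construction, while the redundancy conditions hold only approximately because each Xi is a noisy readout of Λ.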
For example, in the last section of this post, we’ll compute a natural latent for a normal distribution. The function to compute that latent:
Why Do We Want That, Again?
Our previous posts talk more about the motivation, but briefly: two different agents could use two different models with totally different internal (i.e. latent) variables to represent the same predictive distribution P[X]. Insofar as they both use natural latents, there’s a correspondence between their internal variables - two latents over the same P[X] which both approximately satisfy the naturality conditions must contain approximately the same information about X. So, insofar as the two agents both use natural latents internally, we have reason to expect that the internal latents of one can be faithfully translated into the internal latents of the other - meaning that things like e.g. language (between two humans) or interpretability (of a net’s internals to a human) are fundamentally possible to do in a robust way. The internal latents of two such agents are not mutually alien or incomprehensible, insofar as they approximately satisfy naturality conditions and the two agents agree predictively.
Approximate “Uniqueness” and Competitive Optimality
There will typically be more than one different latent which approximately satisfies the naturality conditions (i.e. more than one conditional distribution (λ,x↦P[Λ=λ|X=x]) such that the joint distribution of Λ and X approximately factors over the Bayes nets in the previous section). They all “contain approximately the same information about X”, in the sense that any one approximate natural latent approximately mediates between X and any other approximate natural latent. In that sense, we can approximately talk as though the natural latent is unique, for many purposes. But that still leaves room for better or worse approximations.
When calculating, we’d ideally like to find a natural latent which is a “best possible approximate natural latent” in some sense. Really we want a pareto-best approximation, since we want to achieve the best approximation we can on each of the naturality conditions, and those approximations can trade off against each other.
… but there’s a whole pareto surface, and it’s a pain to get an actual pareto optimum. So instead, we’ll settle for the next best thing: a competitively optimal approximate natural latent. Competitive optimality means that the natural latent we’ll calculate approximates the naturality conditions to within some bounds of any pareto-best approximate natural latent; it can only do so much worse than “the best”. Crucially, competitive optimality means that when we don’t find a very good approximate natural latent, we can rule out the possibility of some better approximate natural latent.
Strong Redundancy
Our previous posts on natural latents used a relatively weak redundancy condition: all-but-one Xi gives approximately the same information about Λ as all of X. (Example: 999 rolls of a biased die give approximately the same information about the bias as 1000 rolls.) The upside of this condition is that it’s relatively general; the downside is that it gives pretty weak quantitative bounds, and in practice we’ve found that a stronger redundancy condition is usually more useful. So in this post, we’ll require “strong redundancy”: any one Xi must give approximately the same information about Λ as all of X. (Example: sticking a thermometer into any one part of a bucket of water at equilibrium gives the same information about the water’s temperature.)
If we want to turn weak redundancy into strong redundancy, e.g. to apply the methods of this post to the biased die example, the usual trick is to chunk together the Xi’s into two or three chunks. For instance, with 1000 die rolls, we could chunk together the first 500 and the second 500, and either of those two subsets gives us roughly the same information about the bias (insofar as 500 rolls is enough to get a reasonably-precise estimate of the bias).
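A quick numerical illustration of the chunking trick, simplified to a two-sided "die" (a coin) so that the posterior is a Beta distribution. The true bias, the uniform prior, and the seed are all assumptions made for the sketch.

```python
import numpy as np

rng = np.random.default_rng(0)
true_bias = 0.3                        # hypothetical bias of a 2-sided "die"
rolls = rng.random(1000) < true_bias   # 1000 flips; True = heads

def beta_posterior(flips):
    """Posterior mean and std of the bias under a uniform Beta(1, 1) prior."""
    heads = int(flips.sum())
    tails = len(flips) - heads
    a, b = 1 + heads, 1 + tails
    mean = a / (a + b)
    std = np.sqrt(a * b / ((a + b) ** 2 * (a + b + 1)))
    return mean, std

chunks = [("first 500", rolls[:500]),
          ("second 500", rolls[500:]),
          ("all 1000", rolls)]
for name, chunk in chunks:
    mean, std = beta_posterior(chunk)
    print(f"{name}: posterior mean {mean:.3f}, std {std:.3f}")
```

Each chunk of 500 pins down the bias to within a few percent, which is the sense in which either chunk alone carries roughly the same information about the bias as both together.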
Conceptually, with strong redundancy, all of the X-relevant information in a natural latent is represented in every single one of the Xi’s. For purposes of establishing that e.g. natural latents of two different agents contain the same information about X, that means strong redundancy gives us “way more than we need” - we only really need strong redundancy over two or three variables in order to establish that the latents “match”.
The Resampling Construction
We start with a distribution P[X] over the variables X1…Xn. We want to construct a latent which is competitively optimal - i.e. if any latent exists over X1…Xn which satisfies the natural latent conditions to within some approximation, then our latent satisfies the natural latent conditions to within some boundedly-worse approximation (with reasonable bounds). We will call our competitively optimal latent X′ (pronounced “X prime”), for reasons which will hopefully become clear shortly.
Here’s how we construct X′.
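Take “X”, then add an apostrophe, “’”, like so -> X’… and that was how David died. Anyway, to construct X′: for each i, sample X′i from the distribution of Xi conditional on all the other variables X¯i, with a separate independent draw for each i. Mathematically, that means the defining distribution of the latent X′ is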
$$P[X' = x' \mid X = x] = \prod_i P[X_i = x'_i \mid X_{\bar{i}} = x_{\bar{i}}]$$
Conceptually, we can think of this as a single resample step of the sort one might use for MCMC, in which we resample every variable simultaneously conditional on all other variables.
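For a small discrete distribution, the defining distribution of X′ can be computed exactly by brute force. The sketch below assumes P[X] is given as an n-dimensional array with strictly positive entries; `resample_kernel` and the toy distribution are illustrative names and numbers, not the post's code.

```python
import itertools
import numpy as np

def resample_kernel(p_x):
    """Return K with K[x'_1..x'_n, x_1..x_n] = prod_i P[X_i = x'_i | X_-i = x_-i]."""
    n = p_x.ndim
    # conds[i][x_1..x_n] = P[X_i = x_i | all other coordinates of x]
    conds = [p_x / p_x.sum(axis=i, keepdims=True) for i in range(n)]
    kernel = np.ones(p_x.shape + p_x.shape)
    states = list(itertools.product(*(range(d) for d in p_x.shape)))
    for x_new in states:
        for x in states:
            prob = 1.0
            for i in range(n):
                idx = list(x)
                idx[i] = x_new[i]   # swap in the resampled value of X_i only
                prob *= conds[i][tuple(idx)]
            kernel[x_new + x] = prob
    return kernel

# Hypothetical P[X]: three binary variables, each a noisy copy of a hidden fair bit.
f = np.array([[0.9, 0.1],
              [0.1, 0.9]])
p_x = 0.5 * np.einsum('la,lb,lc->abc', f, f, f)

K = resample_kernel(p_x)
# For every x, the kernel is a probability distribution over x'.
print(K.sum(axis=(0, 1, 2)))
```

The cost is exponential in the number of variables, so this is only useful for checking small examples by hand.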
Example: suppose X1 is 500 rolls of a biased die, X2 is another 500 rolls of the same die, and X3 is yet another 500 rolls of the same die. Then to calculate X′1, I sample the bias of the die conditional on the 1000 rolls in X2 and X3, then generate 500 new rolls of a die with my newly-sampled bias, and those new rolls are X′1. Likewise for X′2 and X′3 (noting that I’ll need to sample a new bias for each of them). Then, I put all those 1500 rolls together to get X′.
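The die example can be sketched as a sampler. This assumes a uniform Dirichlet prior over the die's bias (the post does not specify a prior), so the posterior given observed counts is Dirichlet(1 + counts); the seed and names are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
n_sides, n_rolls = 6, 500

# Generate the original data: three chunks of 500 rolls from one biased die.
true_bias = rng.dirichlet(np.ones(n_sides))
X = [rng.choice(n_sides, size=n_rolls, p=true_bias) for _ in range(3)]

def resample_chunk(other_chunks):
    """Draw a new chunk from P[X_i | X_-i] under the assumed Dirichlet(1,...,1) prior."""
    counts = np.bincount(np.concatenate(other_chunks), minlength=n_sides)
    bias = rng.dirichlet(1.0 + counts)              # fresh bias sample for this chunk
    return rng.choice(n_sides, size=n_rolls, p=bias)

# Each X'_i is resampled conditional on the *original* other two chunks.
X_prime = [resample_chunk([X[j] for j in range(3) if j != i]) for i in range(3)]

for i, chunk in enumerate(X_prime, start=1):
    print(f"X'_{i} empirical frequencies:",
          np.round(np.bincount(chunk, minlength=n_sides) / n_rolls, 3))
```

Note that every chunk conditions on the original rolls, never on another resampled chunk; that is what makes this a single simultaneous resample step rather than a sequential Gibbs sweep.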
Why would X′ be a competitively optimal natural latent? Intuitively, if there exists a natural latent (with strong redundancy), then each Xi encodes the value of the natural latent (approximately) as well as some “noise” independent of all the other Xi’s. When we resample, the natural latent part is kept the same, but the noise is resampled to be independent of the other Xi’s. So, the only information which X′ contains about X is the value of the natural latent. Of course, that story doesn’t give approximation bounds; that’s what we’ll need all the fancy math for.
In the rest of this section, we’ll show that X′ satisfies the naturality conditions competitively optimally: if there exists any latent Λ which is natural to within some approximation, then X′ is natural to within a boundedly-worse approximation.
Theorem 1: Strong Redundancy => Naturality
Normally, a latent must approximately satisfy two (sets of) conditions in order to be natural: mediation, and redundancy. The latent must encode approximately all the information correlating the Xi’s (mediation), and each Xi must give approximately the same information about the latent (redundancy). Theorem 1 says that, for X′ specifically, the approximation error on the (strong) redundancy conditions upper bounds the approximation error on the mediation condition. So, for X′ specifically, “redundancy is all you need” in order to establish naturality.
Some of the proof will be graphical, but we’ll need to start with one key algebraic step. The key step is this:
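$$\begin{aligned}
D_{KL}\left(P[X, X'] \,\|\, P[X', X_j]\, P[X_{\bar{j}} \mid X_j]\right) &= E\left[\ln P[X_{\bar{j}} \mid X_j, X'] - \ln P[X_{\bar{j}} \mid X_j]\right] \\
&= E\left[\ln P[X_{\bar{j}} \mid X_j, X'] - \ln P[X_{\bar{j}} \mid X'_j]\right] \\
&= D_{KL}\left(P[X, X'] \,\|\, P[X', X_j]\, P[X_{\bar{j}} \mid X'_j]\right)
\end{aligned}$$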
The magic piece is the replacement of E[lnP[X¯j|Xj]] with E[lnP[X¯j|X′j]]; this is allowed because, by construction, (X¯j,Xj) have the exact same joint distribution as (X¯j,X′j). Graphically, that tells us:
Note that the left diagram is the strong redundancy condition for Xi.
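The distributional fact behind the algebraic step, that (X¯j, X′j) has the same joint distribution as (X¯j, Xj), can be checked numerically on a small example. The sketch below builds the full joint P[X, X′] for three binary variables; the toy P[X] is hypothetical.

```python
import numpy as np

# Hypothetical P[X]: three binary variables, each a noisy copy of a hidden fair bit.
f = np.array([[0.9, 0.1],
              [0.1, 0.9]])
p = 0.5 * np.einsum('la,lb,lc->abc', f, f, f)   # p[x1, x2, x3]

# c_i[x1, x2, x3] = P[X_i = x_i | the other two coordinates]
c1 = p / p.sum(axis=0, keepdims=True)
c2 = p / p.sum(axis=1, keepdims=True)
c3 = p / p.sum(axis=2, keepdims=True)

# joint[x1, x2, x3, y1, y2, y3] = P[X = x] * prod_i P[X_i = y_i | X_-i = x_-i],
# where y plays the role of x'.
joint = np.einsum('abc,dbc,aec,abf->abcdef', p, c1, c2, c3)

# Marginals over (X'_j, X_-j), arranged in the same axis order as p.
m1 = np.einsum('abcdef->dbc', joint)   # (X'_1, X_2, X_3)
m2 = np.einsum('abcdef->aec', joint)   # (X_1, X'_2, X_3)
m3 = np.einsum('abcdef->abf', joint)   # (X_1, X_2, X'_3)

print(np.allclose(m1, p), np.allclose(m2, p), np.allclose(m3, p))
```

The check works for any P[X] with positive conditionals, since summing out Xj from P[X] P[X′j | X¯j] leaves P[X¯j] P[X′j | X¯j], which is just P[X] with X′j in place of Xj.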
The rest of the proof is just a bookkeeping step:
So X′ mediates between Xi and X¯i, for all i.
Theorem 2: Competitive Optimality
To prove competitive optimality, we first assume that there exists some latent Λ over X which satisfies the (strong) natural latent conditions to within some bounds. Using that assumption, we want to prove that X′ satisfies the (strong) natural latent conditions to within some not-much-worse bounds. And since Theorem 1 showed that, for X′ specifically, the strong redundancy approximation error bounds the mediation approximation error, all that’s left is to bound the strong redundancy approximation error for X′.
Outline:
Step 1: Xi Mediates Between Xj and X′k
The two naturality conditions of Λ over X (using just one of the N redundancy conditions) easily show that Xi mediates between Xj and Xk (for distinct i, j, k). Since (Xi,Xj,X′k) has the same joint distribution as (Xi,Xj,Xk) (by construction of X′), we can replace Xk in the factorization with X′k. Then, we get the result we were looking for.
Step 2: X′ has Weak Redundancy over X
In the first line, we use the definition of X′ and the result from Step 1 to establish mediation of X1 between X2 and X′3 and so we can remove the outgoing edge X2→X′3. In the second line, we do the same thing for the remaining outgoing edge of X2, establishing X2 as unconditionally independent of X′. Having done so, (X1,X3) trivially mediates between X2 and X′.
Step 3: Λ Mediates between X and X′
The intermediates here are much more easily understood in graphical form, but in words: In lines 1 and 2, we combine the result of Step 2 with the mediation condition of Λ and the definition of X′ to stitch together a combined factorization of the joint distribution of X, X′, and Λ where X mediates between Λ and X′, and in particular it’s the components (X1,X2) which mediate while X3 is independent conditional on Λ. With some minor bookkeeping, we flip the arrow between (X1,X2) and X′, and add an arrow from X′ to Λ. Since this produces no cycles nor colliders, this is a valid move.
In line 3, we use the result of line 2 in all 3 permutations of the X components and Frankenstein the graphs together to show that, since each component of X has Λ mediating between it and X′, Λ mediates between all of X and X′.
Step 4: Strong Redundancy of X′
The result from Step 3, along with the strong redundancy of Λ, allows us to stitch the graphs together and finally obtain our desired result: Strong Redundancy of X′.
The full proof of (Approximate) Natural Latent => (Approximate) Strongly Redundant X′ in one picture:
Can You Do Better?
Note that the bounds derived here are fine in a big-O sense, but a little… unaesthetic. The numbers 9 and 7 are notably not, like, 1 or 2 or even 3. Also, we had to assume a strong approximate natural latent over at least three variables in order for the proof to work; the proof actually doesn’t handle the 2-variable case!
Could we do better? In particular, a proof which works for two variables would likely improve on the bounds considerably. We haven’t figured out how to do that yet, but we haven’t spent that much time on it, and intuitively it seems like it should work.
So if you’re good at this sort of thing, please do improve on our proof!
Empirical Results (Spot Check)
As an empirical check, we coded up relevant calculations for normal distributions. We are not going to go through all the linear algebra in this post, but you can see the code here, if for some reason you want to inflict that upon yourself. The main pieces are:
The big thing we want to check here is that X′ in fact yields approximations within the bounds proven above, when we start from a distribution with a known approximate natural latent.
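For intuition about what the resampling step looks like in the Gaussian case, here is a minimal sketch (not the post's linked code) for scalar components. It uses the standard fact that, for X ~ N(μ, Σ) with precision matrix Q = Σ⁻¹, the conditional of Xi given the rest has mean μi − (1/Qii) Σj≠i Qij (xj − μj) and variance 1/Qii. The covariance below is a random positive-definite matrix chosen for illustration; it is not the post's test system.

```python
import numpy as np

def gaussian_resample_params(cov):
    """Return (A, noise_var) such that X' = mu + A @ (X - mu) + noise,
    with noise ~ N(0, diag(noise_var)) and each X'_i drawn from P[X_i | X_-i]."""
    prec = np.linalg.inv(cov)
    d = np.diag(prec)
    A = np.eye(cov.shape[0]) - prec / d[:, None]   # row i: -Q_ij / Q_ii, zero diagonal
    noise_var = 1.0 / d                            # conditional variances 1 / Q_ii
    return A, noise_var

rng = np.random.default_rng(0)
n = 4
B = rng.normal(size=(n, n))
cov = B @ B.T + n * np.eye(n)                      # hypothetical covariance of X

A, noise_var = gaussian_resample_params(cov)
cov_xp_x = A @ cov                                 # Cov(X', X)
cov_xp = A @ cov @ A.T + np.diag(noise_var)        # Cov(X', X')

# Sanity checks: X'_i has the same variance as X_i, and the same covariance
# with every other X_j, as the construction requires.
off_diag = ~np.eye(n, dtype=bool)
print(np.allclose(np.diag(cov_xp), np.diag(cov)))
print(np.allclose(cov_xp_x[off_diag], cov[off_diag]))
```

Since (X, X′) is jointly Gaussian with these covariance blocks, the KL divergences in the naturality conditions reduce to linear algebra on the blocks.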
The test system:
θ itself is the known approximate natural latent, with strong redundancy when α is relatively small. We compute X′ from only the distribution P[X], and then the table below shows how well the naturality bounds compare to the bounds for our known natural latent θ.
N=24, alpha=0.5
(All numbers in the above table are DKL’s, measured in bits.)
Testing the actual summary stats / parameters (Known Latent) which generated the distributions as a natural latent, we see that the mediation condition is satisfied perfectly (numerically zero), while a strong redundancy condition (just for one Xi, randomly chosen) is ~2.139. So it looks like there is indeed at least one approximate natural latent in this system.
Calculating X′ and then testing it against the naturality conditions, we see that the mediation condition is no longer numerically zero but remains small. The strong redundancy condition (again, for one randomly chosen Xi) is ~2.131 which is a hair better than the known latent. Overall, naturality of X′ is well within the bounds given by the theorems. Note that the theorems now allow us to rule out any approximate natural latent for this system with 9 ϵmed + 7 ϵred < 2.13 bits.
Nice. 😎
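The exclusion claim above can be spelled out as a small arithmetic check. This sketch takes the combination 9 ϵmed + 7 ϵred and the 2.13-bit threshold directly from the sentence above; the function name and the second candidate's error values are made up for illustration.

```python
def ruled_out(eps_med, eps_red, observed_bits=2.13):
    """True if a latent with these errors (in bits) would contradict the
    observed redundancy error of X', assuming the bound 9*eps_med + 7*eps_red."""
    return 9 * eps_med + 7 * eps_red < observed_bits

# The known latent from the spot check: mediation ~0, redundancy ~2.139 bits.
print(ruled_out(0.0, 2.139))

# A hypothetical much-better latent with both errors at 0.1 bits.
print(ruled_out(0.1, 0.1))
```

The known latent is not ruled out, as it must not be, while a latent with errors as small as the hypothetical one cannot exist for this system.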