This post walks through the math for a theorem. It’s intended to be a reference post, which we’ll link back to as needed from future posts. The question which first motivated this theorem for us was: “Redness of a marker seems like maybe a natural latent over a bunch of parts of the marker, and redness of a car seems like maybe a natural latent over a bunch of parts of the car, but what makes redness of the marker ‘the same as’ redness of the car? How are they both instances of one natural thing, i.e. redness? (or ‘color’?)”. But we’re not going to explain in this post how the math might connect to that use-case; this post is just the math.
Suppose we have multiple distributions $P_1, \dots, P_m$ over the same random variables $X_1, \dots, X_n$. (Speaking somewhat more precisely: the distributions are over the same set, and an element of that set is represented by values $(x_1, \dots, x_n)$.) We take a mixture of the distributions: $P[X] := \sum_k \alpha_k P_k[X]$, where $\sum_k \alpha_k = 1$ and $\alpha$ is nonnegative. Then our theorem says: if an approximate natural latent exists over $X$ under $P$, and that latent is robustly natural under changing the mixture weights $\alpha$, then the same latent is approximately natural over $X$ under $P_k$ for all $k$.
Mathematically: the natural latent $\Lambda$ over $X$ is defined by $P[\Lambda|X]$, and naturality means that the distribution $P[\Lambda|X]\,P[X]$ satisfies the naturality conditions (mediation and redundancy). The theorem says that, if the joint distribution $P[\Lambda|X]\sum_k \alpha_k P_k[X]$ satisfies the naturality conditions robustly with respect to changes in $\alpha$, then $P[\Lambda|X]\,P_k[X]$ satisfies the naturality conditions for all $k$. “Robustness” here can be interpreted in multiple ways. We’ll cover two here, one for which the theorem is trivial and another more substantive, but we expect there are probably more notions of “robustness” which also make the theorem work.
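To make the setup concrete, here is a minimal Python sketch of the objects above for small finite variables. It is an illustration under assumptions of ours: binary $X_i$, a finite-valued $\Lambda$, distributions stored as dicts, and the standard KL forms of the two errors (mediation against $P[\Lambda]\prod_i P[X_i|\Lambda]$, redundancy against $P[X]\,P[\Lambda|X_{\bar i}]$, where $X_{\bar i}$ is all of $X$ except $X_i$). All function names are hypothetical helpers, not an existing library.

```python
import math
import random
from itertools import product

def random_dist(outcomes, rng):
    """Random strictly positive distribution over a list of outcomes."""
    w = [rng.random() + 0.01 for _ in outcomes]
    total = sum(w)
    return {o: wi / total for o, wi in zip(outcomes, w)}

def mix(components, alpha):
    """Mixture P[X] = sum_k alpha_k P_k[X]."""
    out = {}
    for a, comp in zip(alpha, components):
        for x, p in comp.items():
            out[x] = out.get(x, 0.0) + a * p
    return out

def joint(latent, px):
    """P[Lambda, X] = P[Lambda | X] P[X]; latent[x] is a dict over Lambda values."""
    return {(lam, x): pl * px[x] for x in px for lam, pl in latent[x].items()}

def marginal(pj, key):
    """Sum the joint over everything except key(lam, x)."""
    out = {}
    for (lam, x), p in pj.items():
        out[key(lam, x)] = out.get(key(lam, x), 0.0) + p
    return out

def mediation_error(pj, n):
    """D_KL(P[Lambda, X] || P[Lambda] prod_i P[X_i | Lambda])."""
    p_lam = marginal(pj, lambda lam, x: lam)
    p_lam_xi = [marginal(pj, lambda lam, x, i=i: (lam, x[i])) for i in range(n)]
    err = 0.0
    for (lam, x), p in pj.items():
        if p > 0:
            q = p_lam[lam]
            for i in range(n):
                q *= p_lam_xi[i][(lam, x[i])] / p_lam[lam]
            err += p * math.log(p / q)
    return err

def redundancy_error(pj, i):
    """D_KL(P[Lambda, X] || P[X] P[Lambda | X_bar_i]), X_bar_i = X without X_i."""
    drop = lambda x: x[:i] + x[i + 1:]
    p_x = marginal(pj, lambda lam, x: x)
    p_lam_xbar = marginal(pj, lambda lam, x: (lam, drop(x)))
    p_xbar = marginal(pj, lambda lam, x: drop(x))
    err = 0.0
    for (lam, x), p in pj.items():
        if p > 0:
            q = p_x[x] * p_lam_xbar[(lam, drop(x))] / p_xbar[drop(x)]
            err += p * math.log(p / q)
    return err

# Usage: two components over three binary variables, a two-valued latent.
rng = random.Random(0)
xs = list(product([0, 1], repeat=3))
components = [random_dist(xs, rng) for _ in range(2)]
latent = {x: random_dist([0, 1], rng) for x in xs}
alpha = [0.3, 0.7]
pj = joint(latent, mix(components, alpha))
print(mediation_error(pj, 3), [redundancy_error(pj, i) for i in range(3)])
```

The same `latent` table is reused for every component, which mirrors the point that $P[\Lambda|X]$ does not depend on $k$.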
Trivial Version
First notion of robustness: the joint distribution $P[\Lambda|X]\sum_k \alpha_k P_k[X]$ satisfies the naturality conditions to within $\epsilon$ for all values of $\alpha$ (subject to $\sum_k \alpha_k = 1$ and $\alpha$ nonnegative).
Then: the joint distribution satisfies the naturality conditions to within $\epsilon$ specifically for $\alpha = e(k)$, i.e. the $\alpha$ which is 0 in all entries except a 1 in entry $k$. In that case, the joint distribution is $P[\Lambda|X]\,P_k[X]$, therefore $\Lambda$ is natural (to within $\epsilon$) over $X$ under $P_k$. Invoke this for each $k$, and the theorem is proven.
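The trivial version rests on one observation: at a vertex $e(k)$ of the simplex, the mixture is exactly the $k$-th component. Below is a minimal Python sketch of that observation. It stores distributions as dicts and uses hypothetical helper names `mix` and `vertex`.

```python
import random

def mix(components, alpha):
    """P[X] = sum_k alpha_k P_k[X] for distributions stored as dicts."""
    out = {}
    for a, comp in zip(alpha, components):
        for x, p in comp.items():
            out[x] = out.get(x, 0.0) + a * p
    return out

def vertex(k, m):
    """The weight vector e(k): zero everywhere except a 1 in entry k."""
    return [1.0 if j == k else 0.0 for j in range(m)]

rng = random.Random(1)
outcomes = ["a", "b", "c"]
components = []
for _ in range(3):
    w = [rng.random() for _ in outcomes]
    components.append({o: wi / sum(w) for o, wi in zip(outcomes, w)})

# At alpha = e(k) the mixture is exactly P_k, so the joint P[Lambda|X] P[X]
# equals P[Lambda|X] P_k[X], and any naturality bound for the mixture applies to P_k.
for k in range(len(components)):
    assert mix(components, vertex(k, len(components))) == components[k]
```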
... but that's just abusing an overly strong notion of robustness. Let's do a more interesting one.
Nontrivial Version
Second notion of robustness: the joint distribution $P[\Lambda|X]\sum_k \alpha_k P_k[X]$ satisfies the naturality conditions to within $\epsilon$, and the gradient of the approximation error with respect to (allowed) changes in $\alpha$ is (locally) zero.
We need to prove that the joint distributions $P[\Lambda|X]\,P_k[X]$ satisfy both the mediation and redundancy conditions for each $k$. We’ll start with redundancy, because it’s simpler.
Redundancy
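We can express the approximation error of the redundancy condition with respect to $X_{\bar i}$ (all components of $X$ except $X_i$) under the mixed distribution as
$$\epsilon_i := D_{KL}\left(P[\Lambda,X] \,\|\, P[X]\,P[\Lambda|X_{\bar i}]\right) = \sum_X P[X] \sum_\Lambda P[\Lambda|X] \log\frac{P[\Lambda|X]}{P[\Lambda|X_{\bar i}]}$$
where, recall, $P[X] = \sum_k \alpha_k P_k[X]$.
We can rewrite that approximation error as:
$$\epsilon_i = \sum_k \alpha_k \sum_X P_k[X] \sum_\Lambda P[\Lambda|X] \log\frac{P[\Lambda|X]}{P[\Lambda|X_{\bar i}]}$$
Note that $P[\Lambda|X]$ is the same under all the distributions (by definition), so:
$$\epsilon_i = \sum_k \alpha_k\, D_{KL}\left(P[\Lambda|X]\,P_k[X] \,\|\, P_k[X]\,P[\Lambda|X_{\bar i}]\right)$$
and by factorization transfer:
$$\epsilon_i \geq \sum_k \alpha_k\, D_{KL}\left(P[\Lambda|X]\,P_k[X] \,\|\, P_k[X]\,P_k[\Lambda|X_{\bar i}]\right)$$
where $P_k[\Lambda|X_{\bar i}]$ is the conditional computed from the joint $P[\Lambda|X]\,P_k[X]$, while $P[\Lambda|X_{\bar i}]$ (no subscript) is computed from the mixed joint.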
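The factorization transfer step is used without proof. The special case needed for redundancy follows from the decomposition below. This is our own short derivation: $P$ stands for any joint over $(\Lambda, X)$, $Q[\Lambda|X_{\bar i}]$ for any conditional, and $X_{\bar i}$ for all of $X$ except $X_i$. The second term is an expected KL divergence and hence nonnegative, so the conditional built from $P$ itself gives the smallest error. Taking $P = P[\Lambda|X]\,P_k[X]$ and $Q$ equal to the mixed distribution's $P[\Lambda|X_{\bar i}]$ gives the factorization transfer step of the redundancy argument.

$$D_{KL}\left(P[\Lambda,X] \,\|\, P[X]\,Q[\Lambda|X_{\bar i}]\right) = D_{KL}\left(P[\Lambda,X] \,\|\, P[X]\,P[\Lambda|X_{\bar i}]\right) + \mathbb{E}_{X_{\bar i}}\left[D_{KL}\left(P[\Lambda|X_{\bar i}] \,\|\, Q[\Lambda|X_{\bar i}]\right)\right]$$

To see it, split $\log\frac{P[\Lambda|X]}{Q[\Lambda|X_{\bar i}]} = \log\frac{P[\Lambda|X]}{P[\Lambda|X_{\bar i}]} + \log\frac{P[\Lambda|X_{\bar i}]}{Q[\Lambda|X_{\bar i}]}$ and take the expectation under $P[\Lambda,X]$. The second piece only depends on $(\Lambda, X_{\bar i})$.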
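In other words: if $\epsilon^k_i$ is the redundancy error with respect to $X_{\bar i}$ under distribution $P_k$, and $\epsilon_i$ is the redundancy error with respect to $X_{\bar i}$ under the mixed distribution $P$, then
$$\epsilon_i \geq \sum_k \alpha_k \epsilon^k_i$$
The redundancy error of the mixed distribution is at least the weighted average of the redundancy errors of the individual distributions.
Since the $\epsilon^k_i$ terms are nonnegative, that also means
$$\epsilon_i \geq \alpha_k \epsilon^k_i \quad\Longrightarrow\quad \epsilon^k_i \leq \frac{\epsilon_i}{\alpha_k}$$
which bounds the approximation error for the redundancy condition under distribution $P_k$ (whenever $\alpha_k > 0$). Also note that, insofar as the latent is natural across multiple $\alpha$ values, we can use the $\alpha$ value with largest $\alpha_k$ to get the best bound for $\epsilon^k_i$.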
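The redundancy inequality can be spot-checked numerically. This minimal Python sketch makes several assumptions: binary $X_i$, a three-valued $\Lambda$, random strictly positive distributions, and the KL form of the redundancy error with respect to $X_{\bar i}$. The helper names are hypothetical. The derivation above never uses the gradient condition, so the check should pass for arbitrary weights. A random example is a sanity check, not a proof.

```python
import math
import random
from itertools import product

def random_dist(outcomes, rng):
    """Random strictly positive distribution over a list of outcomes."""
    w = [rng.random() + 0.01 for _ in outcomes]
    total = sum(w)
    return {o: wi / total for o, wi in zip(outcomes, w)}

def mix(components, alpha):
    """P[X] = sum_k alpha_k P_k[X]."""
    out = {}
    for a, comp in zip(alpha, components):
        for x, p in comp.items():
            out[x] = out.get(x, 0.0) + a * p
    return out

def joint(latent, px):
    """P[Lambda, X] = P[Lambda | X] P[X], with the same P[Lambda | X] for every px."""
    return {(lam, x): pl * px[x] for x in px for lam, pl in latent[x].items()}

def marginal(pj, key):
    """Sum the joint over everything except key(lam, x)."""
    out = {}
    for (lam, x), p in pj.items():
        out[key(lam, x)] = out.get(key(lam, x), 0.0) + p
    return out

def redundancy_error(pj, i):
    """D_KL(P[Lambda, X] || P[X] P[Lambda | X_bar_i]), X_bar_i = X without X_i."""
    drop = lambda x: x[:i] + x[i + 1:]
    p_x = marginal(pj, lambda lam, x: x)
    p_lam_xbar = marginal(pj, lambda lam, x: (lam, drop(x)))
    p_xbar = marginal(pj, lambda lam, x: drop(x))
    return sum(p * math.log(p / (p_x[x] * p_lam_xbar[(lam, drop(x))] / p_xbar[drop(x)]))
               for (lam, x), p in pj.items() if p > 0)

rng = random.Random(2)
n = 3
xs = list(product([0, 1], repeat=n))
components = [random_dist(xs, rng) for _ in range(3)]
latent = {x: random_dist([0, 1, 2], rng) for x in xs}
alpha = [0.2, 0.5, 0.3]

for i in range(n):
    eps_mixed = redundancy_error(joint(latent, mix(components, alpha)), i)
    eps_k = [redundancy_error(joint(latent, comp), i) for comp in components]
    # Mixed-distribution error is at least the alpha-weighted average of component errors,
    assert eps_mixed >= sum(a * e for a, e in zip(alpha, eps_k)) - 1e-12
    # hence each component's error is at most eps_mixed / alpha_k.
    assert all(e <= eps_mixed / a + 1e-12 for a, e in zip(alpha, eps_k))
    print(i, eps_mixed, eps_k)
```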
Mediation
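Mediation relies more heavily on the robustness of naturality to changes in $\alpha$. The gradient of the mediation approximation error $D_{KL}\left(P[\Lambda,X] \,\|\, P[\Lambda]\prod_i P[X_i|\Lambda]\right)$ with respect to $\alpha_k$ is:
$$\frac{\partial}{\partial \alpha_k} D_{KL}\left(P[\Lambda,X] \,\|\, P[\Lambda]\prod_i P[X_i|\Lambda]\right) = \sum_{X,\Lambda} P[\Lambda|X]\,P_k[X] \log\frac{P[\Lambda,X]}{P[\Lambda]\prod_i P[X_i|\Lambda]}$$
(Note: it’s a nontrivial but handy fact that, in general, the change in approximation error $D_{KL}\left(P[Y] \,\|\, \prod_j P[Y_j|Y_{pa(j)}]\right)$ of a distribution $P[Y]$ over some DAG under a normalization-preserving change $dP[Y]$ is $\sum_Y dP[Y] \log\frac{P[Y]}{\prod_j P[Y_j|Y_{pa(j)}]}$. Here the change is $dP[\Lambda,X] = P[\Lambda|X]\,P_k[X]\,d\alpha_k$.)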
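As a numerical sanity check of the handy fact and the gradient formula, the sketch below compares two quantities. The first is a central finite difference of the mixture's mediation error along an allowed direction $d$ (entries summing to zero). The second is $\sum_k d_k\, \partial/\partial\alpha_k$ computed from the formula above. It is a minimal sketch under assumptions of ours: binary $X_i$, binary $\Lambda$, random strictly positive distributions, and hypothetical helper names. It checks the identity on one example; it does not prove it.

```python
import math
import random
from itertools import product

def random_dist(outcomes, rng):
    w = [rng.random() + 0.01 for _ in outcomes]
    total = sum(w)
    return {o: wi / total for o, wi in zip(outcomes, w)}

def mix(components, alpha):
    """P[X] = sum_k alpha_k P_k[X]."""
    out = {}
    for a, comp in zip(alpha, components):
        for x, p in comp.items():
            out[x] = out.get(x, 0.0) + a * p
    return out

def joint(latent, px):
    """P[Lambda, X] = P[Lambda | X] P[X]."""
    return {(lam, x): pl * px[x] for x in px for lam, pl in latent[x].items()}

def marginal(pj, key):
    out = {}
    for (lam, x), p in pj.items():
        out[key(lam, x)] = out.get(key(lam, x), 0.0) + p
    return out

def mediation_log_ratio(pj, n):
    """log( P[Lambda, X] / (P[Lambda] prod_i P[X_i | Lambda]) ) for every (Lambda, X)."""
    p_lam = marginal(pj, lambda lam, x: lam)
    p_lam_xi = [marginal(pj, lambda lam, x, i=i: (lam, x[i])) for i in range(n)]
    ratio = {}
    for (lam, x), p in pj.items():
        q = p_lam[lam]
        for i in range(n):
            q *= p_lam_xi[i][(lam, x[i])] / p_lam[lam]
        ratio[(lam, x)] = math.log(p / q)
    return ratio

def mediation_error(pj, n):
    r = mediation_log_ratio(pj, n)
    return sum(p * r[key] for key, p in pj.items())

def gradient(latent, components, alpha, n):
    """g_k = sum_{X, Lambda} P[Lambda|X] P_k[X] log(ratio), with the ratio taken under the mixture."""
    r = mediation_log_ratio(joint(latent, mix(components, alpha)), n)
    return [sum(p * r[key] for key, p in joint(latent, comp).items()) for comp in components]

rng = random.Random(3)
n = 2
xs = list(product([0, 1], repeat=n))
components = [random_dist(xs, rng) for _ in range(3)]
latent = {x: random_dist([0, 1], rng) for x in xs}
alpha = [0.2, 0.5, 0.3]
d = [1.0, -0.4, -0.6]  # an allowed direction: its entries sum to zero

def eps_along(t):
    """Mediation error of the mixture at weights alpha + t * d."""
    a = [ak + t * dk for ak, dk in zip(alpha, d)]
    return mediation_error(joint(latent, mix(components, a)), n)

h = 1e-5
numeric = (eps_along(h) - eps_along(-h)) / (2 * h)
predicted = sum(dk * gk for dk, gk in zip(d, gradient(latent, components, alpha, n)))
print(numeric, predicted)
```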
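Note that this gradient must be zero along allowed changes in $\alpha$, which means the changes must respect $\sum_k \alpha_k = 1$, i.e. $\sum_k d\alpha_k = 0$. We take every $\alpha_k > 0$, so nonnegativity does not restrict the direction of small changes. That means the gradient must be constant across indices $k$:
$$\sum_{X,\Lambda} P[\Lambda|X]\,P_k[X] \log\frac{P[\Lambda,X]}{P[\Lambda]\prod_i P[X_i|\Lambda]} = c \quad \text{for all } k$$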
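To find that constant, we can take a sum weighted by $\alpha_k$ on both sides:
$$D_{KL}\left(P[\Lambda,X] \,\|\, P[\Lambda]\prod_i P[X_i|\Lambda]\right) = \sum_{X,\Lambda} P[\Lambda,X] \log\frac{P[\Lambda,X]}{P[\Lambda]\prod_i P[X_i|\Lambda]} = \sum_k \alpha_k \sum_{X,\Lambda} P[\Lambda|X]\,P_k[X] \log\frac{P[\Lambda,X]}{P[\Lambda]\prod_i P[X_i|\Lambda]} = \sum_k \alpha_k c = c$$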
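So, robustness tells us that the approximation error under the mixed distribution can be written as
$$D_{KL}\left(P[\Lambda,X] \,\|\, P[\Lambda]\prod_i P[X_i|\Lambda]\right) = \sum_{X,\Lambda} P[\Lambda|X]\,P_k[X] \log\frac{P[\Lambda,X]}{P[\Lambda]\prod_i P[X_i|\Lambda]}$$
for any $k$.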
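The weighted-sum step can be checked numerically in the same way. The sketch below uses the same assumptions as before: hypothetical helpers and random toy distributions, a minimal illustration rather than a general tool. The $\alpha$-weighted sum of the gradient entries reproduces the mixture's mediation error for any weights, robust or not. Robustness is the separate condition that the gradient's component along allowed directions, $g_k - \frac{1}{m}\sum_j g_j$, vanishes, so that every entry equals that same error.

```python
import math
import random
from itertools import product

def random_dist(outcomes, rng):
    w = [rng.random() + 0.01 for _ in outcomes]
    total = sum(w)
    return {o: wi / total for o, wi in zip(outcomes, w)}

def mix(components, alpha):
    """P[X] = sum_k alpha_k P_k[X]."""
    out = {}
    for a, comp in zip(alpha, components):
        for x, p in comp.items():
            out[x] = out.get(x, 0.0) + a * p
    return out

def joint(latent, px):
    """P[Lambda, X] = P[Lambda | X] P[X]."""
    return {(lam, x): pl * px[x] for x in px for lam, pl in latent[x].items()}

def marginal(pj, key):
    out = {}
    for (lam, x), p in pj.items():
        out[key(lam, x)] = out.get(key(lam, x), 0.0) + p
    return out

def mediation_log_ratio(pj, n):
    """log( P[Lambda, X] / (P[Lambda] prod_i P[X_i | Lambda]) ) for every (Lambda, X)."""
    p_lam = marginal(pj, lambda lam, x: lam)
    p_lam_xi = [marginal(pj, lambda lam, x, i=i: (lam, x[i])) for i in range(n)]
    ratio = {}
    for (lam, x), p in pj.items():
        q = p_lam[lam]
        for i in range(n):
            q *= p_lam_xi[i][(lam, x[i])] / p_lam[lam]
        ratio[(lam, x)] = math.log(p / q)
    return ratio

def mediation_error(pj, n):
    r = mediation_log_ratio(pj, n)
    return sum(p * r[key] for key, p in pj.items())

def gradient(latent, components, alpha, n):
    """g_k = sum_{X, Lambda} P[Lambda|X] P_k[X] log(ratio), with the ratio taken under the mixture."""
    r = mediation_log_ratio(joint(latent, mix(components, alpha)), n)
    return [sum(p * r[key] for key, p in joint(latent, comp).items()) for comp in components]

rng = random.Random(4)
n = 3
xs = list(product([0, 1], repeat=n))
components = [random_dist(xs, rng) for _ in range(4)]
latent = {x: random_dist([0, 1], rng) for x in xs}
alpha = [0.1, 0.2, 0.3, 0.4]

g = gradient(latent, components, alpha, n)
eps = mediation_error(joint(latent, mix(components, alpha)), n)
# The alpha-weighted sum of the gradient entries equals the mixture's mediation error.
weighted = sum(a * gk for a, gk in zip(alpha, g))
# Part of g along allowed directions (sum_k d_k = 0); robustness means this is zero,
# i.e. every g_k equals one constant, which must then be eps.
mean_g = sum(g) / len(g)
projected = [gk - mean_g for gk in g]
print(weighted, eps, projected)
```

Projecting onto $\{d : \sum_k d_k = 0\}$ just subtracts the mean. That is why "zero gradient along allowed changes" and "gradient constant across $k$" are the same condition.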
Next, we’ll write out $P[\Lambda,X] = \sum_j \alpha_j P[\Lambda|X]\,P_j[X]$ as a mixture weighted by $\alpha$, and use Jensen’s inequality on that mixture and the logarithm:
Then factorization transfer gives:
Much like redundancy, if $\epsilon^k$ is the mediation error under distribution $P_k$ (note that we’re overloading notation; $\epsilon$ is no longer the redundancy error), and $\epsilon$ is the mediation error under the mixed distribution $P$, then the above says
$$\epsilon \geq \sum_k \alpha_k \epsilon^k$$
Since the $\epsilon^k$ terms are nonnegative, that also means
$$\epsilon^k \leq \frac{\epsilon}{\alpha_k}$$
which bounds the approximation error for the mediation condition under distribution $P_k$.