I'm finishing up my PhD on tensor network algorithms at the University of Queensland, Australia, under Ian McCulloch. I've also proposed a new definition of wavefunction branches using quantum circuit complexity.
Predictably, I'm moving into AI safety work. See my post on graphical tensor notation for interpretability. I also attended the Machine Learning for Alignment Bootcamp in Berkeley in 2022, did a machine learning / neuroscience internship in 2020/2021, and wrote a post exploring the potential counterfactual impact of AI safety work.
My website: https://sites.google.com/view/jordantensor/
Contact me: jordantensor [at] gmail [dot] com. Also see my CV, LinkedIn, or Twitter.
Re. making this more efficient, I can think of a few options.
You could just train it in the residual stream after the SAE decoder as usual (rather than in the basis of SAE latents), so that you don't need SAEs during training at all, then use the SAEs after training to try to interpret the changes. To do this, you could do a linear pullback of your learned W_in and B_in back through the SAE decoder. That is, interpret (SAE_decoder)@(W_in), etc. Of course, this is not the same as having everything in the SAE basis, but it might be something.
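Roughly, the pullback could look like this (the shapes, names, and row-vector convention here are placeholders for whatever your actual setup uses):

```python
import torch

# Placeholder dimensions
d_model, d_sae, d_adapter = 768, 24576, 16

W_dec = torch.randn(d_sae, d_model)      # SAE decoder: reconstruction = latents @ W_dec
W_in  = torch.randn(d_model, d_adapter)  # adapter input weights trained in the residual stream

# Express each adapter input direction as a combination of SAE latents:
W_in_latent = W_dec @ W_in               # shape (d_sae, d_adapter)

# Then inspect the top-weighted SAE latents for each adapter direction:
top_vals, top_idx = W_in_latent.abs().topk(k=10, dim=0)
```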
Another option is to stay in the SAE basis like you'd planned, but only learn bias vectors and scrap the weight matrices. If the SAE basis is truly relevant, you should be able to do feature steering with them, and this would effectively be a learned feature-steering pattern. A middle ground between this extreme and your proposed method would be to learn very sparse and/or very rectangular weight matrices. Preferably both.
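For illustration, the bias-only version might look something like this (sae_encode / sae_decode and the shapes are placeholders, not a real API):

```python
import torch

d_sae = 24576
steer_bias = torch.zeros(d_sae, requires_grad=True)  # the only trainable parameters

def adapted_forward(resid, sae_encode, sae_decode):
    latents = sae_encode(resid)
    # Adding a learned bias to the latents is effectively a learned
    # feature-steering pattern over the SAE basis.
    steered = sae_decode(latents + steer_bias)
    # Keep the SAE reconstruction error: output = resid + (steering contribution)
    return resid - sae_decode(latents) + steered

# Train steer_bias on the downstream objective, optionally with an L1 penalty
# (steer_bias.abs().sum()) if you also want the steering pattern itself to be sparse.
```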
It might actually work OK as you've got it, though, since conceivably you could get away with lower-rank adaptors (more rectangular weight matrices) in the SAE basis than you could in the residual stream, because you get more expressive power from the high-dimensional space. But my gut says you won't actually be able to get away with much lower rank than usual, and that the thing you really want to exploit in the SAE basis is something like sparsity (as a full-rank bias vector does), not low rank.
I'm keen to hear how you think your work relates to "Activation plateaus and sensitive directions in LLMs". Presumably the perturbation norm should be chosen just large enough to get out of an activation plateau? Perhaps it might also explain why gradient-based methods for MELBO alone might not work nearly as well as methods with a finite step size, because the effect is reversed if the step size is too small?
Couldn't you do something like fit a Gaussian to the model's activations, then restrict the steered activations to be high likelihood (low Mahalanobis distance)? Or (almost) equivalently, you could just do a whitening transformation to activation space before you constrain the L2 distance of the perturbation.
(If a Gaussian isn't expressive enough, you could model the manifold in some other way, e.g. with a VAE anomaly detector or a mixture of Gaussians or whatever.)
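Rough sketch of the basic Gaussian / whitening version (placeholder shapes; a Cholesky factor stands in for the whitening transform):

```python
import torch

acts = torch.randn(10000, 768)   # stand-in for collected clean activations

mu = acts.mean(dim=0)
cov = torch.cov(acts.T) + 1e-4 * torch.eye(acts.shape[1])  # regularise for stability
L = torch.linalg.cholesky(cov)                             # cov = L @ L.T

def mahalanobis(x):
    # Distance of an activation from the fitted Gaussian.
    d = torch.linalg.solve_triangular(L, (x - mu).unsqueeze(-1), upper=False).squeeze(-1)
    return d.norm(dim=-1)

def constrain_perturbation(delta, radius):
    # Project the perturbation onto an L2 ball *in whitened coordinates*
    # (i.e. after the transform delta -> L^{-1} delta), then map back.
    white = torch.linalg.solve_triangular(L, delta.unsqueeze(-1), upper=False).squeeze(-1)
    white = white * (radius / white.norm(dim=-1, keepdim=True).clamp(min=radius))
    return (L @ white.unsqueeze(-1)).squeeze(-1)
```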
There are many articles on quantum cellular automata. See for example "A review of Quantum Cellular Automata", or "Quantum Cellular Automata, Tensor Networks, and Area Laws".
I think that, compared to the literature, you're using an overly restrictive and nonstandard definition of quantum cellular automata. Specifically, it only makes sense to me to write the evolution operator as a product of local operators like you have if all of the terms act on spatially disjoint regions.
Consider instead defining a quantum cellular automaton as a local quantum circuit composed of identical two-site unitary operators applied everywhere.
If you define them like this, then basically any kind of energy- and momentum-conserving local quantum dynamics can be discretized into a quantum cellular automaton, because any time- and space-independent Hamiltonian built from identical two-site terms can be decomposed into steps of identical unitaries like this using the Suzuki-Trotter decomposition.
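To make the Suzuki-Trotter point concrete, here's a rough numpy sketch of one such brickwork step (the Heisenberg coupling and the first-order splitting are just illustrative choices):

```python
import numpy as np
from scipy.linalg import expm

# A single two-site Hamiltonian term (Heisenberg, as an example)
X = np.array([[0, 1], [1, 0]], dtype=complex)
Y = np.array([[0, -1j], [1j, 0]])
Z = np.array([[1, 0], [0, -1]], dtype=complex)
h2 = np.kron(X, X) + np.kron(Y, Y) + np.kron(Z, Z)

dt = 0.05
U2 = expm(-1j * dt * h2)   # the single repeated two-site gate

def apply_gate(psi, gate, i, n):
    """Apply a two-site gate to sites (i, i+1) of an n-site state vector."""
    psi = psi.reshape((2,) * n)
    psi = np.moveaxis(psi, [i, i + 1], [0, 1]).reshape(4, -1)
    psi = gate @ psi
    psi = np.moveaxis(psi.reshape((2, 2) + (2,) * (n - 2)), [0, 1], [i, i + 1])
    return psi.reshape(-1)

def qca_step(psi, n):
    # One time step: identical gates on even bonds, then on odd bonds.
    for i in range(0, n - 1, 2):
        psi = apply_gate(psi, U2, i, n)
    for i in range(1, n - 1, 2):
        psi = apply_gate(psi, U2, i, n)
    return psi
```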
This seems easy to try and a potential point to iterate from, so you should give it a go. But I worry that the resulting factor matrices will be dense and very uninterpretable.
I'm keen to see stuff in this direction though! I certainly think you could construct some matrix or tensor of SAE activations such that some decomposition of it is interpretable in an interesting way.
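For example, something like this (purely an illustrative sketch; the token × latent activation matrix and the choice of NMF are my assumptions about what such a decomposition could look like):

```python
import numpy as np
from sklearn.decomposition import NMF

# Stand-in for a (tokens x SAE-latents) matrix of non-negative SAE activations
acts = np.abs(np.random.randn(2000, 4096))

model = NMF(n_components=32, init="nndsvd", max_iter=200)
token_factors = model.fit_transform(acts)   # (tokens, components)
latent_factors = model.components_          # (components, latents)

# For each component, inspect the SAE latents it weights most heavily:
top_latents = np.argsort(-latent_factors, axis=1)[:, :10]
```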
What is the original SAE like, and why discard it? Because it's co-evolved with the model, and therefore likely to seem more interpretable than it actually is?