Math & CS Undergraduate at MIT
https://amudide.github.io/
Thanks for the question -- is calculated over an entire batch of inputs, not a single . Figure 1 shows how the Switch SAE processes a single residual stream activation .
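For readers who want something concrete next to Figure 1, here is a minimal sketch of how a Switch SAE might process one activation. It assumes a linear softmax router on the centered input, routing to the single highest-probability expert, a TopK encoder inside that expert, and a reconstruction scaled by the router probability (so the router gets gradient during training). The function names, the per-expert weight layout, and the plain-list math are hypothetical choices for illustration, not the actual training code.

```python
import math
import random

def matvec(W, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    s = sum(exps)
    return [e / s for e in exps]

def topk_relu(pre, k):
    """Keep the k largest pre-activations (passed through ReLU), zero the rest."""
    keep = sorted(range(len(pre)), key=lambda j: pre[j], reverse=True)[:k]
    out = [0.0] * len(pre)
    for j in keep:
        out[j] = max(pre[j], 0.0)
    return out

def switch_sae_forward(x, b_pre, W_router, experts, k):
    """Reconstruct one activation x using only the expert the router picks.

    experts: hypothetical list of (W_enc, b_enc, W_dec) tuples, one per expert.
    W_enc has m rows of length d; W_dec has d rows of length m.
    """
    centered = [xi - bi for xi, bi in zip(x, b_pre)]
    probs = softmax(matvec(W_router, centered))
    i = max(range(len(probs)), key=lambda e: probs[e])  # route to a single expert
    W_enc, b_enc, W_dec = experts[i]
    pre = [p + b for p, b in zip(matvec(W_enc, centered), b_enc)]
    z = topk_relu(pre, k)  # only this expert's m latents are ever computed
    recon = matvec(W_dec, z)
    # Scale by the router probability, then add the shared bias back.
    x_hat = [probs[i] * r + b for r, b in zip(recon, b_pre)]
    return x_hat, i, z

def rand_matrix(rows, cols):
    """Random Gaussian matrix as a list of rows (illustration only)."""
    return [[random.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]

# Usage: 4 experts with 16 latents each, TopK with k=3, residual width 8.
random.seed(0)
d, n_experts, m, k = 8, 4, 16, 3
experts = [(rand_matrix(m, d), [0.0] * m, rand_matrix(d, m)) for _ in range(n_experts)]
x = [random.gauss(0, 1) for _ in range(d)]
x_hat, expert, z = switch_sae_forward(x, [0.0] * d, rand_matrix(n_experts, d), experts, k)
```

The point of the sketch is that the encoder matrix multiply touches only one expert's `m` rows rather than the full dictionary, which is where the compute and memory savings come from.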
Hi Lee and Arthur, thanks for the feedback! I agree that routing each input to a single expert will force redundant features across experts, and I will experiment with Arthur's suggestion. I haven't taken a close look at the router/expert geometry yet, but I plan to do so soon.
Thanks for the comment -- I trained TopK SAEs of various widths (all fitting on a single GPU) and observed that wider SAEs take substantially longer to train, which leads me to believe that the encoder forward pass is a major bottleneck for wall-clock time. The Switch SAE also improves memory efficiency: because each input is routed to a single expert, we do not need to store all latents.
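A rough back-of-the-envelope count makes the encoder-bottleneck argument concrete. This sketch assumes equal-size experts and a linear router, and ignores biases, activations, and the TopK selection itself; the sizes in the usage line are illustrative, not taken from my experiments. Multiply-add counts are only a proxy, since real wall-clock time also depends on kernels and hardware utilization.

```python
def encoder_cost(d_model, n_latents, n_experts=1):
    """Approximate per-token encoder multiply-adds and latents materialized.

    n_experts=1 models a dense TopK SAE (no router). For n_experts > 1,
    we pay for a d_model x n_experts router plus one expert's encoder.
    """
    latents_per_expert = n_latents // n_experts
    router = d_model * n_experts if n_experts > 1 else 0
    flops = router + d_model * latents_per_expert
    return flops, latents_per_expert

dense = encoder_cost(768, 32768)                # (25165824, 32768)
switch = encoder_cost(768, 32768, n_experts=8)  # (3151872, 4096)
print(dense, switch)
```

Under these assumptions, splitting the same dictionary into 8 experts cuts encoder work and stored latents per token by roughly 8x, with the router adding only a small overhead.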
I'm currently working on implementing expert parallelism, which I hope will substantially improve wall-clock time.
Thanks for your comment! I believe your concern was echoed by Lee and Arthur in their comments, and it is completely valid. This work is primarily a proof of concept that we can successfully scale SAEs by directly applying MoE, but I suspect that we will need to make tweaks to the architecture.