A short project on Mamba: grokking & interpretability
Epistemic status: I've worked on this project for ~20h, in my free time and using only a Colab notebook.

Executive summary

I trained a minimalistic implementation of Mamba (details below) on the modular addition task. I found that:

1. This non-transformer-based model can also exhibit grokking (i.e., the model learns to generalise after overfitting to the training data).
2. There are tools we can import from neuroscience that help us interpret how the network representation changes as grokking takes place over training epochs.

Introduction

Almost all of the Mechanistic Interpretability (MI) efforts I've seen people excited about, and the great majority of the techniques I've learned, are related to Transformer-based architectures. At the same time, a competitive alternative (Mamba) was recently introduced and later scaled. Coupling these two facts together, a giant gap between capabilities and safety emerges. Thus, I think Mamba provides an interesting use case where we can test whether the more conceptual foundations of MI are solid (i.e., somewhat model-agnostic) and, therefore, whether MI can potentially survive another transformer-like paradigm shift in the race towards AGI.

For a bit more context, Mamba is based on a special version of State Space Models (SSMs): add another S (for Structured) and you have one of its essential components. The actual architecture is slightly more complex than the S-SSM layer, as you can see in this awesome post, but for this project I wrote up a minimal implementation that could get the job done (a rough sketch of such a layer is included at the end of this section).

A simple-yet-interesting enough task

The task the model has to solve is: given two input integers (x and y), return whether their sum is divisible by a big prime number (p = 113, in this case). This is mapped into a setup that autoregressive token predictors can deal with: one input example consists of three tokens, 'x', 'y' and '=', and the only output token is either '0' (if (x + y) mod p ≠ 0) or '1' (otherwise).
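To make this concrete, here is a minimal sketch of how such a dataset could be built. This is an illustration rather than the project's actual code: the token id chosen for '=', the train fraction, and the use of PyTorch are my assumptions.

```python
import torch

p = 113  # the prime modulus used in the task

# Assumed vocabulary: integers 0..p-1 for x and y, plus a dedicated '=' token.
EQUALS = p  # hypothetical token id for '='

# Enumerate every (x, y) pair and build the 3-token inputs (x, y, '=').
pairs = torch.cartesian_prod(torch.arange(p), torch.arange(p))      # (p*p, 2)
inputs = torch.cat([pairs, torch.full((p * p, 1), EQUALS)], dim=1)  # (p*p, 3)

# Label is 1 iff (x + y) is divisible by p, and 0 otherwise.
labels = ((pairs[:, 0] + pairs[:, 1]) % p == 0).long()              # (p*p,)

# Random train/test split; grokking setups usually keep only a fraction for training.
train_frac = 0.3  # assumed value, not taken from the original write-up
perm = torch.randperm(p * p)
n_train = int(train_frac * p * p)
train_idx, test_idx = perm[:n_train], perm[n_train:]
train_inputs, train_labels = inputs[train_idx], labels[train_idx]
test_inputs, test_labels = inputs[test_idx], labels[test_idx]
```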
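For completeness, here is a rough sketch of the kind of minimal selective-SSM layer referred to in the introduction. This is not the project's actual implementation: the parameterisation of A, the Euler-style discretisation of B, and the sequential (non-parallel) scan are simplifying assumptions made purely for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MinimalSelectiveSSM(nn.Module):
    """Sketch of the selective SSM recurrence at the core of Mamba.

    Simplifications (assumptions): no convolution, no gating branch, and a plain
    Python loop instead of the hardware-aware parallel scan.
    """
    def __init__(self, d_model: int, d_state: int = 16):
        super().__init__()
        self.d_model, self.d_state = d_model, d_state
        # A is stored in log space and kept negative for a stable recurrence.
        self.A_log = nn.Parameter(
            torch.log(torch.arange(1, d_state + 1, dtype=torch.float32)).repeat(d_model, 1)
        )
        self.D = nn.Parameter(torch.ones(d_model))  # skip connection
        # Input-dependent ("selective") projections for B, C and the step size Δ.
        self.x_to_B = nn.Linear(d_model, d_state, bias=False)
        self.x_to_C = nn.Linear(d_model, d_state, bias=False)
        self.x_to_dt = nn.Linear(d_model, d_model, bias=True)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        batch, seq_len, _ = x.shape
        A = -torch.exp(self.A_log)                # (d_model, d_state)
        B = self.x_to_B(x)                        # (batch, seq_len, d_state)
        C = self.x_to_C(x)                        # (batch, seq_len, d_state)
        dt = F.softplus(self.x_to_dt(x))          # (batch, seq_len, d_model)

        h = torch.zeros(batch, self.d_model, self.d_state, device=x.device)
        outputs = []
        for t in range(seq_len):
            # Discretise: A_bar = exp(Δ·A), B_bar ≈ Δ·B (Euler approximation).
            A_bar = torch.exp(dt[:, t].unsqueeze(-1) * A)
            B_bar = dt[:, t].unsqueeze(-1) * B[:, t].unsqueeze(1)
            h = A_bar * h + B_bar * x[:, t].unsqueeze(-1)             # state update
            y_t = (h * C[:, t].unsqueeze(1)).sum(-1) + self.D * x[:, t]
            outputs.append(y_t)
        return torch.stack(outputs, dim=1)        # (batch, seq_len, d_model)
```

In the full Mamba block this recurrence sits inside a gated branch together with a short convolution and input/output projections, and the scan is computed in parallel; the explicit loop above is just the easiest way to see the state update.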