Thanks for the elaboration. I'll follow up offline.
Would you be able to elaborate a bit on your process for adversarially attacking the model?
It sounds like a combination of projected gradient descent and clustering? I took a look at the code but a brief mathematical explanation / algorithm sketch would help a lot!
A couple of colleagues and I are thinking about using this approach to demonstrate some robustness failures in LLMs, and it would be great to build on your work.
The figures remind me of Figures 3 and 4 from "Meta-learning of Sequential Strategies" (Ortega et al., 2019), which also studies how autoregressive models (RNNs) infer underlying structure. Could be a good reference to check out!