Super cool, thanks!
Image interpretability seems mostly so easy because humans are already really good at interpreting 2D images with local structure. But thinking about this does suggest an idea for language model interpretability: how practical is it to find text that a) has high probability according to the prior distribution, b) strongly activates one attention head or feed-forward neuron or something, and c) only weakly activates other parts of the transformer (within some reference class)? Probably this has already been tried somewhere and gotten middling results.
On priors, I wouldn't worry too much about c), since I would expect a 'super stimulus' for head A to not be a super stimulus for head B.
I think one of the problems is the discrete input space, i.e. how do you parameterize the sequence that is being optimized?
One idea I just had was to fine-tune an LLM with a reward signal given by, for example, the magnitude of the residual delta coming from a particular head (we probably want something else here, maybe net logit change?). The LLM already encodes a prior over "sensible" sequences and will then try to find one of those that activates the head strongly (however we want to operationalize that).
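To make criteria a) to c) concrete, here is a toy sketch of one possible scoring rule. It is paired with a greedy single-token-swap search, which is one simple answer to the discrete-parameterization question. This is not the RL fine-tuning idea from the comment; it is a swapped-in, much cruder search. Everything here is hypothetical. `stimulus_score`, `greedy_token_search` and `toy_score` are illustrative names, the log-prior and per-head "activations" are fake stand-ins for measurements from a real transformer, and the weights are arbitrary.

```python
from typing import Callable, Dict, List, Sequence


def stimulus_score(
    log_prior: float,
    activations: Dict[str, float],
    target: str,
    prior_weight: float = 1.0,
    off_target_weight: float = 1.0,
) -> float:
    """Hypothetical objective combining criteria a), b) and c).

    a) reward plausible text via its log-probability under a prior,
    b) reward a strong activation of the target component,
    c) penalise the mean activation of all other components.
    """
    others = [value for name, value in activations.items() if name != target]
    mean_others = sum(others) / len(others) if others else 0.0
    return (prior_weight * log_prior
            + activations[target]
            - off_target_weight * mean_others)


def greedy_token_search(
    tokens: List[str],
    vocab: Sequence[str],
    score_fn: Callable[[List[str]], float],
    max_rounds: int = 10,
) -> List[str]:
    """Discrete hill climbing: try every single-token swap, keep any
    improvement immediately, and stop once a full pass finds none."""
    best = list(tokens)
    best_score = score_fn(best)
    for _ in range(max_rounds):
        improved = False
        for i in range(len(best)):
            for candidate in vocab:
                trial = best[:i] + [candidate] + best[i + 1:]
                trial_score = score_fn(trial)
                if trial_score > best_score:
                    best, best_score, improved = trial, trial_score, True
        if not improved:
            break
    return best


# Toy stand-ins (hypothetical, no real model): "head_A" counts "cat",
# "head_B" counts "dog", and the "prior" penalises repeated tokens.
def toy_score(tokens: List[str]) -> float:
    activations = {
        "head_A": float(tokens.count("cat")),
        "head_B": float(tokens.count("dog")),
    }
    log_prior = -0.5 * (len(tokens) - len(set(tokens)))
    return stimulus_score(log_prior, activations, target="head_A")


print(greedy_token_search(["the", "dog", "sat"], ["the", "cat", "dog", "sat"], toy_score))
# ['cat', 'cat', 'cat']
```

With this weak prior, the search settles on the degenerate "cat cat cat". Each repeated token costs only 0.5 in log-prior but gains 1 in target activation, which is exactly the failure mode criterion a) is meant to rule out. In a real version, every call to the score function would be a transformer forward pass, so exhaustive swaps over a full vocabulary get expensive quickly.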
"Image interpretability seems mostly so easy because humans are already really good..."
Thank you, this is a good point! I wonder how much of this is humans "doing the hard work" of interpreting the features. It raises the question of whether we will be able to interpret more advanced networks, especially if they evolve features that don't overlap with the way humans process inputs.
The language model idea sounds cool! I don't know language models well enough yet but I might come back to this once I get to work on transformers.
Very cool to see new people joining the interpretability field!
Some resource suggestions:
If you didn't know already, there is a TF2 port of Lucid, called Luna:
There is also Lucent, which is Lucid for PyTorch: (Some docs written by me for a slightly different version)
For transformer interpretability you might want to check out Anthropic's work on transformer circuits, Redwood Research's interpretability tool, or (shameless plug) Unseal.
[My project for AGI Safety Fundamentals programme (~Oct 2021). Code & pictures below.]
To me, reading about Feature Visualization felt like one of the most revealing insights into CNNs of recent years. Seeing the idea "this node finds eyes, this node finds mouths, the combination detects faces" (oversimplified) actually implemented by the CNN was a pleasant surprise, in that it suggests we might actually be able to understand how NNs work. There's more reading on the programme website here; I can highly recommend the articles by Chris Olah's group on Distill.
Seeing this, I think many of us immediately want to try it out and play around with it. There is of course the OpenAI Microscope for looking at results, and the Lucid library, but I wanted to actually reproduce the idea myself without relying on a somewhat black box (a big library or the OpenAI Microscope).
Almost all tutorials I found, however, used Lucid, and this really cool write-up, "How to visualize convolutional features in 40 lines of code", unfortunately starts with
from fastai.conv_learner import *
In retrospect I think I could understand it now, but at the time I didn't, and finding out which parts were fastai functions and what they did was rather tricky. I also didn't manage to install the required (older) version of fastai. So I decided to have a go myself and, luckily, found that "DeepDream" is based on a very similar idea, so I could adapt most of the code from this notebook from Google AI. This isn't actually too complicated, especially when broken down to the bare minimum:
The whole code runs in about a minute on my laptop (no GPU).
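For reference, here is a minimal sketch of such a bare-minimum loop. It assumes TensorFlow 2 and Keras' pretrained VGG16, and it is not the code from the GitHub repo. The function names, the noise start, the pixel-space step size and the step count are illustrative assumptions. The sketch does plain gradient ascent on the input pixels to increase the mean activation of one channel of one layer. It normalises the gradient by its standard deviation, as Google's DeepDream tutorial does. It leaves out the jitter, multi-scale "octaves" and regularisation tricks that DeepDream and Lucid use to get cleaner images, and it runs eagerly, without the optional speed-up decorator.

```python
import tensorflow as tf


def build_feature_model(layer_name="block4_conv1", weights="imagenet"):
    """Return a function mapping RGB images in [0, 255] to one VGG16 layer.

    Assumption: Keras' pretrained VGG16 (weights download on first use).
    """
    base = tf.keras.applications.VGG16(include_top=False, weights=weights)
    extractor = tf.keras.Model(base.input, base.get_layer(layer_name).output)

    def feature_model(images):
        # VGG16 expects its own ("caffe"-style) preprocessing; applying it
        # here lets us optimise the image directly in pixel space.
        return extractor(tf.keras.applications.vgg16.preprocess_input(images))

    return feature_model


def gradient_ascent(feature_model, image, channel, step_size):
    """One ascent step on the *input image* for a single channel."""
    with tf.GradientTape() as tape:
        tape.watch(image)
        activations = feature_model(image[tf.newaxis])  # (1, H', W', C)
        loss = tf.reduce_mean(activations[..., channel])  # mean over positions
    grads = tape.gradient(loss, image)
    # Normalise the gradient (as in the DeepDream tutorial) so that step_size
    # sets the update size regardless of the layer's activation scale.
    grads /= tf.math.reduce_std(grads) + 1e-8
    image = tf.clip_by_value(image + step_size * grads, 0.0, 255.0)
    return image, loss


def visualize_channel(feature_model, channel, size=128, steps=100,
                      step_size=1.0, init=None, seed=0):
    """Start from grey-ish noise (or `init`, ignoring `size`) and ascend."""
    if init is None:
        image = tf.random.uniform((size, size, 3), 100.0, 156.0, seed=seed)
    else:
        image = tf.convert_to_tensor(init, dtype=tf.float32)
    losses = []  # activation before each step, useful for checking progress
    for _ in range(steps):
        image, loss = gradient_ascent(feature_model, image, channel, step_size)
        losses.append(float(loss))
    return image, losses
```

A run for one channel of block4_conv1 might then look like image, losses = visualize_channel(build_feature_model("block4_conv1"), channel=4). This assumes the Microscope's unit numbering matches the channel index. tf.keras.utils.save_img("unit4.png", image) writes the result out. Passing init=tf.zeros((128, 128, 3)) starts from a black image instead of noise.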
… for explanations.
And here we go!
Looks like features. Now let's try to reproduce one of the OpenAI Microscope images, node 4 of layer block4_conv1. Here is my version:
And the OpenAI Microscope image:
Not identical, but clearly the same feature in both visualizations!
Finally, here is a run with InceptionV3, just for the pretty pictures, this time starting from a non-random (black) image, along with an animation of the image after every iteration.
Note: there's an optional bit to improve the speed (by about a factor of 2 on my laptop): just add this decorator in front of the gradient_ascent function:
This is basically how far I got in the time available. The code can be found on my GitHub (link). I do plan to look at more interpretability techniques (maybe something for transformers or RL?) or more general AGI Safety ideas in the future!
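On the optional speed-up from the note above, here is a small sketch. It assumes the decorator meant there is TensorFlow's @tf.function, which traces a Python function into a graph once and reuses that graph on later calls. Wrapping with tf.function(...) as below is equivalent to writing the decorator above the def. The input_signature fixes shape and dtype so that new image sizes do not trigger re-tracing. The feature model is a hypothetical stand-in with random weights, so nothing is downloaded. The timing helper only measures; it makes no claim about the factor of 2.

```python
import time

import tensorflow as tf

# Hypothetical stand-in for a real feature extractor (e.g. a VGG16 layer).
feature_model = tf.keras.Sequential([
    tf.keras.Input(shape=(None, None, 3)),
    tf.keras.layers.Conv2D(8, 3),
])
CHANNEL, STEP_SIZE = 0, 1.0


def gradient_ascent(image):
    """Eager version of one ascent step on the input image."""
    with tf.GradientTape() as tape:
        tape.watch(image)
        loss = tf.reduce_mean(feature_model(image[tf.newaxis])[..., CHANNEL])
    grads = tape.gradient(loss, image)
    grads /= tf.math.reduce_std(grads) + 1e-8
    return tf.clip_by_value(image + STEP_SIZE * grads, 0.0, 255.0)


# Same as decorating the def with @tf.function(input_signature=...):
# the Python function is traced once into a graph that is then reused.
gradient_ascent_compiled = tf.function(
    gradient_ascent,
    input_signature=[tf.TensorSpec(shape=(None, None, 3), dtype=tf.float32)],
)


def time_steps(step_fn, image, steps=20):
    """Wall-clock seconds for `steps` ascent steps, after one warm-up call."""
    step_fn(image)  # warm-up: triggers tracing for the compiled version
    start = time.perf_counter()
    for _ in range(steps):
        image = step_fn(image)
    return time.perf_counter() - start


if __name__ == "__main__":
    img = tf.random.uniform((128, 128, 3), 100.0, 156.0, seed=0)
    print("eager:   ", time_steps(gradient_ascent, img))
    print("compiled:", time_steps(gradient_ascent_compiled, img))
```

The eager and compiled versions compute the same update. How much the graph version saves depends on the model size and on how much Python overhead each step carries.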
Feel free to post a comment or send me a message if you have any questions, or anything really; I'm happy to chat about these things!
I want to thank the organizers of the AGI Safety Fundamentals programme again for setting up the programme and for all their support. I can highly recommend the programme, as well as the well-curated curriculum here if you just want to read through it yourself.