Exciting stuff, thanks!
It's a little surprising to me how bad the logit lens is for earlier layers.
I think working on mechanistic interpretability in a variety of domains, architectures, and modalities seems like a reasonable research diversification bet.
However, it feels pretty odd to me to describe branching out into other modalities as crucial when we haven't yet really done anything useful with mechanistic interpretability in any domain or for any task.
You say:
With recent rapid releases of multimodal models, including Sora, Gemini, and Claude 3, it is crucial that interpretability and safety efforts remain in tandem. While language mechanistic interpretability already has strong conceptual foundations, many research papers, and a thriving community, research in non-language modalities lags behind. Given that multimodal capabilities will be part of AGI, field-building in mechanistic interpretability for non-language modalities is crucial for safety and alignment.
And on X/twitter:
Frontier models are multimodal, and it's increasingly clear that mechanistic interpretability can't only study language models.
But, I feel like the situation is relatively analogous to:
Fusion power plants will need to be built in many countries, and it's increasingly clear that fusion power plant construction can't only study building fusion power in the US.
Like yeah, you'll eventually need to handle non-language modalities and you should probably sanity check that they aren't key additional blockers with the methodology, but also why would there be key methodologies that mean it can solve our problems in the language case but not the vision/multimodal case? And the main obstacle is demonstrating basic technical feasibility, not branching out?
Again, I'd like to stress that studying a variety of cases with mech interp seems like a reasonable research diversification bet.
(And I don't want to be the language police here, just pushing back a bit on the implicit vibes.)
There is another argument that could be made for working on other modalities now: there could be insights which generalize across modalities, but which are easier to discover when working on some modalities vs. others.
I've actually been thinking, for a while now, that people should do more image model interpretability for this sort of reason. I never got around to posting this opinion, but FWIW it is the main reason I'm personally excited by the sort of work reported here. (I have mostly been thinking about generative or autoencoding image models here, rather than classifiers, but the OP says they're building toward that.)
Why would we expect there to be transferable insights that are easier to discover in visual domains than textual domains? I have two thoughts in mind:
First thought:
The tradeoff curve between "model does something impressive/useful that we want to understand" and "model is conveniently small/simple/etc." looks more appealing in the image domain.
Most obviously: if you pick a generative image model and an LLM which do "comparably impressive" things in their respective domains, the image model is going to be way smaller (cf.). So there are, in a very literal way, fewer things we have to interpret -- and a smaller gap between the smallest toy models we can make and the impressive models which are our holy grails.
Like, Stable Diffusion is definitely not a toy model, and does lots of humanlike things very well. Yet it's pretty tiny by LLM standards. Moreover, the SD autoencoder is really tiny, and yet it would be a huge deal if we could come to understand it pretty well.
Beyond mere parameter count, image models have another advantage, which is the relative ease of constructing non-toy input data for which we know the optimal output. For example, this is true of:
By contrast, in language modeling and classification, we really have no idea what the optimal logits are. So we are limited to making coarse qualitative judgments of logit effects ("it makes this token more likely, which makes sense"), ignoring the important fine-grained quantitative stuff that the model is doing.
None of that is intrinsically about the image domain, I suppose; for instance, one can make text autoencoders too (and people do). But in the image domain, these nice properties come for free with some of the "real" / impressive models we ultimately want to interpret. We don't have to compromise on the realism/relevance of the models we choose for ease of interpretation; sometimes the realistic/relevant models are already convenient for interpretability, as a happy accident. The capabilities people just make them that way, for their own reasons.
The hope, I guess, is that if we came pretty close to "fully understanding" one of these more convenient models, we'd learn a lot of stuff along the way about how to interpret models in general, and that would transfer back to the language domain. Stuff like "we don't know what the logits should be" would no longer be a blocker to making progress on other fronts, even if we do eventually have to surmount that challenge to interpret LLMs. (If we had a much better understanding of everything else, a challenge like that might be more tractable in isolation.)
Second thought:
I have a hunch that the apparent intuitive transparency of language (and tasks expressed in language) might be holding back LLM interpretability.
If we force ourselves to do interpretability in a domain which doesn't have so much pre-existing taxonomical/terminological baggage -- a domain where we no longer feel it's intuitively clear what the "right" concepts are, or even what any breakdown into concepts could look like -- we may learn useful lessons about how to make sense of LLMs when they aren't "merely" breaking language and the world down into conceptual blocks we find familiar and immediately legible.
When I say that "apparent intuitive transparency" affects LLM interpretability work, I'm thinking of choices like:
In both of these lines of work, there's a temptation to try to parse out the LLM computation into operations on parts we already have names for -- and, in cases where this doesn't work, to chalk it up either to our methods failing, or to the LLM doing something "bizarre" or "inhuman" or "heuristic / unsystematic."
But I expect that much of what LLMs do will not be parseable in this way. I expect that the edge that LLMs have over pre-DL AI is not just about more accurate extractors for familiar, "interpretable" features; it's about inventing a decomposition of language/reality into features that is richer, better than anything humans have come up with. Such a decomposition will contain lots of valuable-but-unfamiliar "frobnoloid"-type stuff, and we'll have to cope with it.
To loop back to images: relative to text, with images we have very little in the way of pre-conceived ideas about how the domain should be broken down conceptually.
Like, what even is an "interpretable image feature"?
Maybe this question has some obvious answers when we're talking about image classifiers, where we expect features related to the (familiar-by-design) class taxonomy -- cf. the "floppy ear detectors" and so forth in the original Circuits work.
But once we move to generative / autoencoding / etc. models, we have a relative dearth of pre-conceived concepts. Insofar as these models are doing tasks that humans also do, they are doing tasks which humans have not extensively "theorized" and parsed into concept taxonomies, unlike language and math/code and so on. Some of this conceptual work has been done by visual artists, or photographers, or lighting experts, or scientists who study the visual system ... but those separate expert vocabularies don't live on any single familiar map, and I expect that they cover relatively little of the full territory.
When I prompt a generative image model, and inspect the results, I become immediately aware of a large gap between the amount of structure I recognize and the amount of structure I have names for. I find myself wanting to say, over and over, "ooh, it knows how to do that, and that!" -- while knowing that, if someone were to ask, I would not be able to spell out what I mean by each of these "that"s.
Maybe I am just showing my own ignorance of art, and optics, and so forth, here; maybe a person with the right background would look at the "features" I notice in these images, and find them as familiar and easy to name as the standout interpretable features from a recent LM SAE. But I doubt that's the whole of the story. I think image tasks really do involve a larger fraction of nameless-but-useful, frobnoloid-style concepts. And the sooner we learn how to deal with those concepts -- as represented and used within NNs -- the better.
Thanks for your comment. Some follow-up thoughts, especially regarding your second point:
There is sometimes an implicit zeitgeist in the mech interp community that other modalities will simply be an extension or subcase of language.
I want to flip the frame, and consider the case where other modalities may actually be a more general case for mech interp than language. As a loose analogy, the relationship between language mech interp and multimodal mech interp may be like the relationship between algebra and abstract algebra. I have two points here.
Alien modalities and alien world models
The reason that I’m personally so excited by non-language mech interp is due to the philosophy of language (Chomsky/Wittgenstein). I’ve been having similar intuitions to your second point. Language is an abstraction layer on top of perception. It is largely optimized by culture, social norms, and language games. Modern English is not the only way to discretize reality, but the way our current culture happens to discretize reality.
To present my point in a more sci-fi way, non-language mech interp may be more general because now we must develop machinery to deal with alien modalities. And I suspect many of these AI models will have very alien world models! Looking at the animal world, animals sense the world through all sorts of modalities: bees see ultraviolet light, turtles navigate with magnetic fields, birds predict weather changes by sensing barometric pressure, aquatic animals sense dissolved gases in the water, etc. Various AGIs may have sensors to take in all sorts of “alien” data that human language may not be equipped for. I am imagining a scenario in which a superintelligence discretizes the world in seemingly arbitrary ways, or maybe following a hidden logic based on its objective function.
Language is already optimized by humans to modularize reality in a nice clean way. Perception already filtered through language is by definition human interpretable, so the deck is already largely stacked in our favor. You allude to this with your point about photographers, dancers, etc. developing their own language to describe subtle patterns in perception that the average human does not have language for. Wine connoisseurs develop vocabulary to discretize complex wine-tasting percepts into words like “bouquet” and “mouth-feel.” Make-up artists coin new vocabulary around contouring, highlighting, cutting the crease, etc. to describe subtle artistry that may be imperceptible to the average human.
I can imagine a hypothetical sci-fi scenario where the only jobs available are apprenticing yourself to a foundation model at a young age for life, deeply understanding its world model, and communicating its unique and alien world model to the very human realm of your local community (maybe through developing jargon or dialect, or even through some kind of art, like poetry, or dance, communication forms humans currently use to bypass the limitations of language).
Self-supervised vision models like DINO are free of a lot of human biases but may not have as interpretable of a world model as CLIP, which is co-optimized with language. I believe DINO’s lack of language bias to be either a safety issue or a superpower, depending on the context (safety in that we may not understand this “alien” world model, but superpower in that DINO may be freer from human biases that may be, in many contexts, unwanted!).
As a toy example, in this post, the above vision transformer classifies the children playing with the lion as “abaya.” This is an ethnically biased classification, but the ViT only has 1k ImageNet concepts. The limits of its dictionary are quite literally the limits of its world (in a Wittgenstein sense)! But there are so many other concepts we can create to describe the image!
Text-perception manifolds
Earlier, I mentioned that English is currently the way our culture happens to discretize reality, and there may be other coherent ways to discretize the same reality.
Consider the scene of a fruit bowl on a table. You can start asking questions such as: How many ways are there to carve up this scene into language? How many ways can we describe this fruit bowl in English? In all human languages, including languages that don’t have the concepts of fruit or bowls? In all possible languages (which takes us to Chomsky)? These questions have a real analysis flavor to them, in that you’re mapping continuous perception to discrete language (yes, perception is represented discretely on a computer, but there may be advantages to framing this in a continuous way). This manifold may be very useful in navigating alignment problems.
For example, there was a certain diffusion model that would always generate salads in conjunction with women due to the spurious correlation. One question I’m deeply interested in: is there a way to represent the model’s text-to-perception world model as a manifold, and then modify it? Can you then modify this manifold to decorrelate women and salad?
A text-image manifold formalization could further answer useful questions about granularity. For example, consider a man holding an object, where object can map to anything from a teddy bear to a gun. By representing the mapping between the text/semantics of the word "object" and the perceptual space of teddy bears, guns, and other pixel blobs that humans might label as objects as a manifold, we could capture the model's language-to-perception world model in a formal mathematical structure.
—
The above two points are currently just intuitions pending formalization. I have a draft post on why I’m so drawn to non-language interp for these reasons, which I can share soon.
Noted, and thank you for flagging. I mostly agree, and do not have much to add (as we seem mostly in agreement that diverse, bluesky research is good), other than that this may shape the way I present this project going forward.
However, it feels pretty odd to me to describe branching out into other modalities as crucial when we haven't yet really done anything useful with mechanistic interpretability in any domain or for any task.
I think the objective of interpretability research is to demystify the mechanisms of AI models, and not pushing the boundaries in terms of achieving tangible results / state of the art performance (I do think that interpretability research indirectly contributes in pushing the boundaries as well, because we'd design better architectures, and train the models in a better way as we understand them better). I see it being very crucial, especially as we delve into models with emergent abilities. For instance, the phenomenon of in-context learning by language models used to be considered a black box, now it has been explained through interpretability efforts. This progress is not trivial; it lays the groundwork for safer and more aligned AI systems by ensuring we have a clearer grasp of how these models make decisions and adapt.
I also think there are key differences in how these architectures function across modalities, such as attention being causal for language, while being bidirectional for vision, and how even though tokens and image patches are analogous, they are consumed and processed differently, and so much more. These subtle differences change how the models operate across modalities even though the underlying architecture is the same, and this is exactly what necessitates the mech interp efforts across modalities.
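To make the causal vs. bidirectional distinction concrete, here is a minimal NumPy sketch. It is only an illustration of the masking difference, not code from any particular model; the function name and the uniform toy scores are made up for the example.

```python
import numpy as np

def attention_pattern(scores, causal):
    """Turn raw attention scores (n_tokens, n_tokens) into attention weights.

    causal=True  -> each position only attends to itself and earlier positions
                    (the usual setup for autoregressive language models).
    causal=False -> every position attends to every other position
                    (the usual setup for ViT image patches).
    """
    scores = np.array(scores, dtype=float)  # copy so the input is not modified
    if causal:
        n = scores.shape[-1]
        future = np.triu(np.ones((n, n), dtype=bool), k=1)  # strictly above the diagonal
        scores[future] = -np.inf
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)

toy_scores = np.zeros((4, 4))  # toy input: every pair of tokens scores the same
causal_pattern = attention_pattern(toy_scores, causal=True)
bidirectional_pattern = attention_pattern(toy_scores, causal=False)
```

With identical scores, the first row of the causal pattern puts all of its weight on the first token, while every row of the bidirectional pattern spreads weight evenly over all four tokens.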
For instance, the phenomenon of in-context learning by language models used to be considered a black box, now it has been explained through interpretability efforts.
Has it? I'm quite skeptical. (Separately, only a small fraction of the efforts you're talking about are well described as mech interp or would require tools like these.)
I don't really want to get into this, just registering my skepticism. I'm quite familiar with the in-context learning and induction heads work, assuming that's what you're referring to.
+1 that I'm still fairly confused about in context learning, induction heads seem like a big part of the story but we're still confused about those too!
I agree that in-context learning is not entirely explainable yet, but we're not completely in the dark about it. We have some understanding and direction or explainability regarding where this ability might stem from, and it's only going to get much clearer from here.
I think the objective of interpretability research is to demystify the mechanisms of AI models, and not pushing the boundaries in terms of achieving tangible results
Insofar as the objective of interpretability research was to do something useful (e.g. detect misalignment or remove it in extremely powerful future AI systems), I think it should also aim to be useful for solving some problems that feel roughly analogous now. Most useful things can be empirically demonstrated to accomplish specific things better than existing methods, at least in test beds.
(Notably, I think things well described as "interpretability" often pass this bar, but I think "mech interp" basically never does. I'm using the definition of mech interp from here.)
To be clear, it's fine if mech interp isn't there yet. Many things haven't demonstrated any clear useful applications and can still be worth investing in.
For more discussion see here.
I greatly appreciated the time invested in coding the interactive demos. They help clarify the insights into the underlying concepts – it reminds me of Colah's posts.
Questions:
Right now, there's a lot to exploit with CLIP and ViTs so that will be the focus for a while. We may expand to Flamingo or other models if there is demand.
Other modalities would be fascinating. I imagine they have their own idiosyncrasies. I would be interested in audio in the future but not at the expense of first exploiting vision.
Ideally, yes; a unified interp framework for any modality is the north star. I do think this will be a community effort. Research in language built off findings from many different groups and institutions. Vision and other modalities are currently just not in the same place.
Join our Discord here.
This article was written by Sonia Joseph, in collaboration with Neel Nanda, and incubated in Blake Richards’s lab at Mila and in the MATS community. Thank you to the Prisma core contributors, including Praneet Suresh, Rob Graham, and Yash Vadi.
Full acknowledgements of contributors are at the end. I am grateful to my collaborators for their guidance and feedback.
Outline
Introducing the Prisma Library for Multimodal Mechanistic Interpretability
I am excited to share with the mechanistic interpretability and alignment communities a project I’ve been working on for the last few months. Prisma is a multimodal mechanistic interpretability library based on TransformerLens, currently supporting vanilla vision transformers (ViTs) and their vision-text counterpart, CLIP.
With recent rapid releases of multimodal models, including Sora, Gemini, and Claude 3, it is crucial that interpretability and safety efforts remain in tandem. While language mechanistic interpretability already has strong conceptual foundations, many research papers, and a thriving community, research in non-language modalities lags behind. Given that multimodal capabilities will be part of AGI, field-building in mechanistic interpretability for non-language modalities is crucial for safety and alignment.
The goal of Prisma is to make research in mechanistic interpretability for multimodal models both easy and fun. We are also building a strong and collaborative open source research community around Prisma. You can join our Discord here.
This post includes a brief overview of the library, fleshes out some concrete problems, and gives steps for people to get started.
Prisma Goals
Tutorial Notebooks
To get started, you can check out three tutorial notebooks that show how Prisma works.
Brief ViT Overview
A vision transformer (ViT) is an architecture designed for image classification tasks, similar to the classic transformer architecture used in language models. A ViT consists of transformer blocks; each block consists of an Attention layer and an MLP layer.
Unlike language models, vision transformers do not have a dictionary-style embedding and unembedding matrix. Instead, images are divided into non-overlapping patches, similar to tokens in language models. These patches are flattened and linearly projected to embeddings via a Conv2D layer, akin to word embeddings in language models. A learnable class token (CLS token) is prepended at the start of the sequence, which accrues global information throughout the network. A linear position embedding is added to the patches.
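As a rough illustration of the patch-embedding step described above, here is a NumPy sketch. It is not Prisma's implementation: the function name, the 224x224 image, the patch size of 32, and the random weights are assumptions chosen so that the sequence has 50 tokens (49 patches plus CLS). The linear projection of flattened patches plays the role of the Conv2D layer, whose kernel size and stride both equal the patch size.

```python
import numpy as np

def embed_patches(image, patch_size, w_proj, cls_token, pos_embed):
    """image: (H, W, C); w_proj: (patch_size * patch_size * C, d_model)."""
    h, w, c = image.shape
    grid_h, grid_w = h // patch_size, w // patch_size
    # Cut the image into a (grid_h, grid_w) grid of non-overlapping patches.
    patches = image.reshape(grid_h, patch_size, grid_w, patch_size, c)
    patches = patches.transpose(0, 2, 1, 3, 4)  # (grid_h, grid_w, p, p, c)
    flat = patches.reshape(grid_h * grid_w, patch_size * patch_size * c)
    tokens = flat @ w_proj                       # (n_patches, d_model)
    # Prepend the CLS token, then add one position embedding per token.
    tokens = np.concatenate([cls_token[None, :], tokens], axis=0)
    return tokens + pos_embed                    # (n_patches + 1, d_model)

rng = np.random.default_rng(0)
patch_size, d_model = 32, 64                     # illustrative sizes
image = rng.normal(size=(224, 224, 3))
w_proj = rng.normal(size=(patch_size * patch_size * 3, d_model)) * 0.02
cls_token = np.zeros(d_model)
pos_embed = rng.normal(size=(50, d_model)) * 0.02

tokens = embed_patches(image, patch_size, w_proj, cls_token, pos_embed)
```

Here `tokens` has one row for the CLS token followed by 49 patch rows in row-major order, which is the ordering assumed in the rest of these sketches.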
The patch embeddings then pass through the transformer blocks (each block consists of a LayerNorm, an Attention layer, another LayerNorm, and an MLP layer). The output of each block is added back to the previous input. The sum of the block’s output and its previous input is called the residual stream.
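The block structure above can be sketched in a few lines of NumPy. This is a simplified stand-in rather than the real architecture: it assumes a single attention head, no biases, no learned LayerNorm scale or shift, and ReLU in place of GELU, and all names and sizes are hypothetical.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def attention(x, w_q, w_k, w_v, w_o):
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    pattern = softmax(q @ k.T / np.sqrt(q.shape[-1]))  # bidirectional: no mask
    return pattern @ v @ w_o

def mlp(x, w_in, w_out):
    return np.maximum(x @ w_in, 0.0) @ w_out  # ReLU stands in for GELU

def block(resid, p):
    # Each sublayer reads a normalized copy of the residual stream
    # and adds its output back into it.
    resid = resid + attention(layer_norm(resid), p["w_q"], p["w_k"], p["w_v"], p["w_o"])
    resid = resid + mlp(layer_norm(resid), p["w_in"], p["w_out"])
    return resid

rng = np.random.default_rng(0)
n_tokens, d_model, d_mlp = 50, 16, 64  # illustrative sizes
params = {
    "w_q": rng.normal(size=(d_model, d_model)) * 0.1,
    "w_k": rng.normal(size=(d_model, d_model)) * 0.1,
    "w_v": rng.normal(size=(d_model, d_model)) * 0.1,
    "w_o": rng.normal(size=(d_model, d_model)) * 0.1,
    "w_in": rng.normal(size=(d_model, d_mlp)) * 0.1,
    "w_out": rng.normal(size=(d_mlp, d_model)) * 0.1,
}
resid_pre = rng.normal(size=(n_tokens, d_model))
resid_post = block(resid_pre, params)
```

Because every sublayer only adds to the stream, zeroing a sublayer's output weights leaves the residual stream untouched, which is what makes it natural to ask what each layer contributed.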
The final layer of this vision transformer is a classification head with 1000 logit values for ImageNet's 1000 classes. The CLS token is fed into the final layer for 1000-way classification. Adapting TransformerLens, we designed HookedViT to easily capture intermediate activations with custom hook functions, instead of dealing with PyTorch's normal hook functionality.
Prisma Functionality
We’ll demonstrate the functionality with some preliminary research results. The plots are all interactive but the LW site does not let me render HTML. See the original post for interactive graphs.
Emoji Logit Lens
The emoji logit lens is a convenient way to visualize patch-level predictions for each layer of the net.
We treat every patch like the CLS token, and feed it into the ViT’s 1000-way classification head that’s pre-trained on ImageNet, without fine-tuning. This is equivalent to deleting all layers between the layer of your choice and the output classification head.
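In code, the idea is roughly the following. This is a NumPy sketch with random stand-in weights, not the library's API; the function name is hypothetical. It also assumes the head is applied directly to the residual stream, whereas many ViT implementations apply a final normalization before the head, in which case that step would come first.

```python
import numpy as np

def patch_logit_lens(resid, w_head, b_head):
    """resid: (n_tokens, d_model) residual stream at the chosen layer.

    Returns (n_tokens, n_classes) logits: one row per token, computed by
    applying the classification head to every token, not just CLS.
    """
    return resid @ w_head + b_head

rng = np.random.default_rng(0)
n_tokens, d_model, n_classes = 50, 64, 1000  # 49 patches + CLS, 1000 ImageNet classes
resid = rng.normal(size=(n_tokens, d_model))  # stand-in for a cached activation
w_head = rng.normal(size=(d_model, n_classes))
b_head = np.zeros(n_classes)

logits = patch_logit_lens(resid, w_head, b_head)
patch_logits = logits[1:]                  # drop the CLS row (assumed to be token 0)
top_class = patch_logits.argmax(axis=-1)   # one class id per patch
top_logit = patch_logits.max(axis=-1)      # its logit, e.g. for the yellow/blue coloring
```

Each entry of `top_class` would then be looked up in the emoji dictionary, and `top_logit` is the kind of value that could drive the color scale.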
For convenience, we represent the ImageNet prediction of that patch with its corresponding emoji, drawing from our ImageNet-Emoji Dictionary.
Below are the patch-level predictions of the final layer of a ViT for an image of a cat sitting inside a toilet. The yellow means that the logit prediction was high and blue means the logit prediction was low (see the Emoji Logit Lens notebook for more details).
Emergent Segmentation
One of my favorite findings so far is that the patch-level logit lens on the image basically acts as a segmentation map. For the image above, the cat patches get classified as cat, and the toilet patches get classified as a toilet!
This is not an obvious result, as vision transformers are optimized to predict a single class with the CLS token, and not segment the image. The segmentation is an emergent property. See the Emoji Logit Lens Notebook for more details and an interactive visualization.
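To read the per-patch predictions as a segmentation map, it is enough to put them back on the patch grid. A small sketch, assuming a 7x7 grid in row-major order and using made-up class ids and patch positions:

```python
import numpy as np

def segmentation_grid(patch_classes, grid_size):
    """Reshape a flat list of per-patch class ids into a (grid_size, grid_size) map."""
    return np.asarray(patch_classes).reshape(grid_size, grid_size)

def class_mask(grid, class_id):
    """Boolean mask of the patches predicted as class_id."""
    return grid == class_id

CAT, TOILET = 0, 1                   # hypothetical ids, for illustration only
patch_classes = [TOILET] * 49
for idx in (16, 17, 23, 24):         # pretend a 2x2 block of patches was labeled cat
    patch_classes[idx] = CAT

grid = segmentation_grid(patch_classes, grid_size=7)
cat_mask = class_mask(grid, CAT)
```

Upsampling `cat_mask` by the patch size would give a coarse pixel-level mask for the cat.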
Similar emergent segmentation capabilities were recently reported by Gandelsman, Efros, and Steinhardt (2024), who found that decomposing CLIP's image representation across spatial locations allowed obtaining zero-shot semantic segmentation masks that outperformed prior methods. Our results extend this finding to vanilla vision transformers and provide an intuitive visualization using the emoji logit lens.
We can see similar results on other images.
(Note: For visualization purposes, I’ve changed the coloring to be by emoji class instead of logit value like above; see Emoji Logit Lens notebook for details.)
Interestingly, the net has some biased predictions (“abaya” for the children, perhaps due to their ethnicity), one consequence of only having a 1000-class vocabulary to span concept-space.
Funnily, the net thinks that the center of the green apple (above image, bottom left) is a bagel.
When we do a layer-by-layer logit lens, we see the net’s evolving predictions:
Interestingly, the net picks up on the “animal” at 9_pre (the residual stream before the 9th transformer block) but classifies the cat as a dog. The net only catches onto the cat at 10_pre.
This layer-wise analysis builds upon the work of Gandelsman, Efros, and Steinhardt (2024), who used mean ablations to identify which layers in CLIP have the most significant direct effect on the final representation. Our emoji logit lens provides a complementary view, visualizing how the patch-level predictions evolve across the model's depth.
Interactive code here.
We can also visualize the evolving per-patch predictions for the above cat/toilet image for all the layers at once:
Direct Logit Attribution
The library supports direct logit attribution, including at the layer-level and attention-level.
Below, the net starts making a distinction between tabby/collie and banana at the eighth layer. See the ViT Prisma Main Demo for the interactive graph.
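For readers new to the technique, here is a minimal sketch of layer-level direct logit attribution for a logit difference. It assumes the final CLS residual stream is just the sum of what each layer wrote into it, and it ignores the final normalization and the head bias; the class indices and weights are placeholders, not real ImageNet ids or model weights.

```python
import numpy as np

def direct_logit_attribution(component_outputs, w_head, class_a, class_b):
    """component_outputs: (n_components, d_model), each component's write to
    the CLS residual stream. Returns each component's contribution to
    logit[class_a] - logit[class_b]."""
    direction = w_head[:, class_a] - w_head[:, class_b]
    return component_outputs @ direction

rng = np.random.default_rng(0)
n_layers, d_model, n_classes = 12, 64, 1000
layer_outputs = rng.normal(size=(n_layers, d_model))  # stand-in for cached layer outputs
w_head = rng.normal(size=(d_model, n_classes))
CLASS_A, CLASS_B = 3, 7                                # placeholder class indices

contributions = direct_logit_attribution(layer_outputs, w_head, CLASS_A, CLASS_B)

# Sanity check: the per-layer contributions add up to the full logit difference.
final_resid = layer_outputs.sum(axis=0)
final_logits = final_resid @ w_head
logit_diff = final_logits[CLASS_A] - final_logits[CLASS_B]
```

Plotting `contributions` against the layer index gives the kind of graph described above, where a jump at some layer marks where the net starts to separate the two classes.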
Attention Heads
I wrote an interactive JavaScript visualizer so we can see what each vision attention head is attending to on the image.
The x and y axes of the attention head are the flattened image. The image is 50 patches in total, including the CLS token, which means the total attention head is a 50x50 square.
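When reading the 50x50 pattern, it helps to map a token index back to a position on the patch grid. The helper below is a sketch under two assumptions: the CLS token sits at index 0, and the remaining 49 patches form a 7x7 grid flattened in row-major order. The function names are hypothetical.

```python
def token_to_grid(index, grid_size=7):
    """Map a token index in the attention pattern to a patch (row, col).

    Index 0 is assumed to be the CLS token, which has no grid position.
    """
    if index == 0:
        return None
    if not 1 <= index <= grid_size * grid_size:
        raise ValueError(f"token index {index} is outside the patch grid")
    return divmod(index - 1, grid_size)

def corner_tokens(grid_size=7):
    """Token indices of the four corner patches of the image."""
    last = grid_size - 1
    corners = [(0, 0), (0, last), (last, 0), (last, last)]
    return [1 + row * grid_size + col for row, col in corners]

corners = corner_tokens()  # [1, 7, 43, 49]
```

Under these assumptions, a head that puts most of its attention on columns 1, 7, 43, and 49 of the pattern is attending to the four corners of the image.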
Upon initial inspection, the first layer’s attention heads are extremely geometric.
Corner Head, Edges Head, and Modulus Head
We can see attention heads’ scores specializing for specific patterns in the data, including what we call a Corner Head, an Edges Head, and a Modulus Head. This is fascinating because the flattened image does not explicitly contain corner, edge, or row/column information; detecting these patterns is emergent from training.
These findings echo the recent work of Gandelsman, Efros, and Steinhardt (2024) who identified property-specific attention heads in CLIP that specialize in concepts like colors, locations, and shapes. Our results suggest that such specialization is a more general property of vision transformer architectures, including vanilla models trained solely on image classification, and includes even more basic geometric properties like the coordinates of the image.
Interactive code here.
Video of the Corner Head
Activation Patching
Prisma has the activation patching functionality of TransformerLens.
The Cat-Dog Switch
I found a single attention head (Layer 11, Head 4) wherein patching the CLS token of the z-matrix flips the computation from tabby cat to Border Collie. The CLS token in that z-matrix aggregates patch-level cat ear/face information from the attention pattern.
Our activation patching results demonstrate that this technique can be used to flip the model's prediction by targeting specific heads, providing a powerful tool for understanding and manipulating the model's decision-making process.
This result resonates with Gandelsman, Efros, and Steinhardt (2024), who showed that knowledge of head-specific roles in CLIP can be used to manually intervene in the model's computation, such as removing heads associated with spurious cues.
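The mechanics of this kind of patch can be sketched with a toy model. This is not the Prisma or TransformerLens API: it assumes a single layer whose heads each write `z[h] @ w_o[h]` into the residual stream, with the logits read off the CLS token (index 0), and it uses random arrays in place of cached activations from a cat run and a dog run.

```python
import numpy as np

def logits_from_z(z, w_o, w_head):
    """z: (n_heads, n_tokens, d_head); w_o: (n_heads, d_head, d_model)."""
    # Sum every head's write at the CLS position, then apply the head.
    cls_resid = np.einsum("hd,hdm->m", z[:, 0, :], w_o)
    return cls_resid @ w_head

def patch_head_cls(z_target, z_source, head):
    """Copy one head's CLS-position z vector from the source run into the target run."""
    patched = z_target.copy()
    patched[head, 0, :] = z_source[head, 0, :]
    return patched

rng = np.random.default_rng(0)
n_heads, n_tokens, d_head, d_model, n_classes = 12, 50, 8, 32, 10  # toy sizes
w_o = rng.normal(size=(n_heads, d_head, d_model))
w_head = rng.normal(size=(d_model, n_classes))
z_cat_run = rng.normal(size=(n_heads, n_tokens, d_head))  # stand-in for the cat image
z_dog_run = rng.normal(size=(n_heads, n_tokens, d_head))  # stand-in for the dog image

HEAD = 4
z_patched = patch_head_cls(z_cat_run, z_dog_run, HEAD)
logit_change = logits_from_z(z_patched, w_o, w_head) - logits_from_z(z_cat_run, w_o, w_head)
```

In this toy setting the change in logits is exactly the patched head's contribution. In a real model, one would compare the cat and dog logits before and after the patch to see whether the prediction flips.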
Interactive code here.
Toy Vision Transformers
We are releasing nine tiny ViTs for testing (equivalent to TransformerLens’ gelu-1l) to better isolate behavior. These tiny ViTs were trained by Yash Vadi and Praneet Suresh.
The repo also contains training code to quickly train custom toy ViTs.
HookedViT
We currently support timm’s vanilla ViTs, TinyCLIP, the video vision transformer, and our own custom tiny transformers. More models will come soon based on demand!
FAQ
Is multimodal mechanistic interpretability really that different from language?
Yes and no. Vision mech interpretability is like language mechanistic interpretability, but in a fun-house mirror. Both architectures are transformers, so many LLM techniques carry over. However, there are a few twists:
If there is demand, I may write up a post giving a deeper and more theoretical take on the differences between language and non-language mechanistic interpretability.
Why start with vision transformers?
Vision transformers have an extremely similar architecture to language transformers, so many of the existing techniques transfer over cleanly.
Diffusion models are the next obvious frontier, but there will be a larger conceptual leap in designing mechanistic interpretability techniques, largely due to their iterative denoising process. I’d be happy to collaborate on this with anyone who is serious about building strong conceptual foundations here.
Getting Started with Vision Mechanistic Interpretability
How to get involved
Open Problems in Vision Mechanistic Interpretability
Here are some Open Problems to get started. If inspired, you are encouraged to post your own in the comments, or comment on the ideas that most grab your attention.
Easy and Exploratory
Expanding Techniques to New Architectures and Datasets
Deeper Investigations
Advanced Investigations
Acknowledgements
Thank you to this most excellent mosaic of communities.
Thank you to Praneet Suresh, Rob Graham, and Yash Vadi, and the other core contributors to the Prisma Repo. Thank you to my PI, Blake Richards, and the rest of our lab at Mila for their support and feedback.
Thank you to Neel Nanda for guidance in bringing mechanistic interpretability to another modality, to Joseph Bloom for your advice on building a repo, to Arthur Conmy for coining the term “dogit lens,” and to the rest of the MATS community for your feedback.
Thank you to the Prisma group at Mila for your feedback, including Santoshi Ravichandran, Ali Kuwajerwala, Mats L. Richter, and Luca Scimeca; members of LiNCLab, including Arna Ghosh and Dan Levenstein; and members of CERC-AAI lab, including Irina Rish and Ethan Caballero. Thank you to Karolis Ramanauskas, Noah MacCallum, Rob Graham, and Romeo Valentin for your feedback on the tutorial notebooks.
Finally, thank you to the South Park Commons community for your support, including Ker Lee Yap, Abhay Kashyap, Jonathan Brebner, and Ruchi Sanghvi and Aditya Agarwal.
This research was generously supported by Blake Richards’s lab, which was funded by the Bank of Montreal; NSERC (Discovery Grant: RGPIN-2020-05105; Discovery Accelerator Supplement: RGPAS-2020-00031; Arthur B. McDonald Fellowship: 566355-2022); CIFAR (Canada AI Chair; Learning in Machine and Brains Fellowship); and a Canada Excellence Research Chair Award to Prof. Irina Rish; and by South Park Commons. This research was enabled in part by support provided by Calcul Québec and the Digital Research Alliance of Canada. We acknowledge the material support of NVIDIA in the form of computational resources.