This post is an announcement for a software library. It is likely only relevant to those working, or looking to start working, in mechanistic interpretability.
What is graphpatch?
graphpatch is a Python library for activation patching on arbitrary PyTorch neural network models. It is designed to minimize the boilerplate needed to run experiments that make causal interventions on a model's intermediate activations. It provides an intuitive API based on the structure of a torch.fx.Graph representation compiled automatically from the original model. As a somewhat silly example, I can make Llama play Taboo by zero-ablating its output logit for the token representing "Paris":
```python
with patchable_llama.patch(
    {"lm_head.output": ZeroPatch(slice=(slice(None), slice(None), 3681))}
):
    print(
        ...
    )
```
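For readers unfamiliar with what a patch like this amounts to mechanically: the equivalent operation in vanilla PyTorch is a forward hook that zeroes one slice of a module's output. Here is a minimal sketch on a toy stand-in for `lm_head` (the layer shapes and `TOKEN_ID` are placeholders, not Llama's real dimensions or the actual "Paris" token index):

```python
import torch
import torch.nn as nn

# Toy stand-in for an LM head; the real example patches Llama's lm_head,
# whose output has shape (batch, seq_len, vocab_size).
lm_head = nn.Linear(8, 10, bias=False)

TOKEN_ID = 3  # hypothetical vocab index to ablate (3681 for "Paris" in the post)

def zero_ablate(module, inputs, output):
    # Zero the logit for one vocabulary entry across the whole batch/sequence,
    # mirroring ZeroPatch(slice=(slice(None), slice(None), TOKEN_ID)).
    output = output.clone()
    output[:, :, TOKEN_ID] = 0.0
    return output  # returning a value from a forward hook replaces the output

handle = lm_head.register_forward_hook(zero_ablate)
logits = lm_head(torch.randn(2, 5, 8))
handle.remove()

assert torch.all(logits[:, :, TOKEN_ID] == 0.0)
```

The point of graphpatch is that you don't have to write hook plumbing like this yourself, and (as discussed below) hooks only reach values that happen to cross a module boundary.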
Thanks! You’re correct that you can implement ROME with vanilla hooks, since these give you access to module inputs in addition to outputs. But the fact that this works is contingent on both the specific interventions ROME makes and the way Llama/GPT-2 happen to be implemented. To get perhaps overly concrete: in the relevant line of Llama's MLP forward pass, ROME wants the result of the multiplication, which isn’t the output of any individual submodule. You happen to be able to access it as the input of down_proj, because that happens to be a module, but it didn’t have to be implemented this way. (This would be even worse if we wanted to patch the value…)
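The MLP line being referenced is presumably Hugging Face's LlamaMLP.forward. A minimal sketch of that structure (simplified, with arbitrary dimensions — not the actual transformers code) shows why the multiplication result is only reachable as down_proj's input:

```python
import torch
import torch.nn as nn

# Sketch mirroring the structure of Hugging Face's LlamaMLP (simplified).
class LlamaStyleMLP(nn.Module):
    def __init__(self, hidden=8, intermediate=16):
        super().__init__()
        self.gate_proj = nn.Linear(hidden, intermediate, bias=False)
        self.up_proj = nn.Linear(hidden, intermediate, bias=False)
        self.down_proj = nn.Linear(intermediate, hidden, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        # The multiplication result below is not the output of any submodule,
        # but it *is* the input to down_proj -- so a hook can reach it only
        # because down_proj happens to be a module.
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

mlp = LlamaStyleMLP()
x = torch.randn(2, 8)

captured = {}
def grab_input(module, inputs, output):
    captured["down_proj_in"] = inputs[0]

handle = mlp.down_proj.register_forward_hook(grab_input)
mlp(x)
handle.remove()

# The captured input equals the intermediate multiplication result.
expected = mlp.act_fn(mlp.gate_proj(x)) * mlp.up_proj(x)
assert torch.allclose(captured["down_proj_in"], expected)
```

Had the implementation fused that line differently (e.g. no separate down_proj module), there would be no module boundary to hook, which is the contingency the comment above is pointing at.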