Cool project. I'm not sure if it's interesting to me for alignment, since it's such a toy model. What do you think would change when trying to do similar interpretability on less-toy models? What would change about finding adversarial examples? Directly intervening on features seems like it might stay the same though.
I'm not sure if it's interesting to me for alignment, since it's such a toy model.
Cruxes here are things like whether you think toy models are governed by the same rules as larger models, whether studying them helps you understand those general principles, and whether understanding those principles is valuable. This model in particular shares many similarities in architecture and training with LLMs and has over a million parameters, so it's not nearly as much of a toy model as others, and we have particular reasons to expect insights to transfer (both are transformers / next-token predictors).
What do you think would change when trying to do similar interpretability on less-toy models?
The recipe stays mostly the same but scale increases and you know less about the training distribution.
What would change about finding adversarial examples?
This is a very complicated/broad question, and there are a number of ways you could approach it. I'd probably look at identifying critical features in the language model and see whether we can develop automatic techniques for flipping them. This could be done recursively if you are able to find the features most important for those features (etc.). Understanding why existing adversaries like jail-breaking techniques / initial affirmative responses work (mechanistically) might tell us a lot about how to automate a more general search for adversaries. My guess is that finding adversaries using a white-box approach may be fairly tractable: the search space is much smaller once you know the features, and there are many search strategies that might work to flip them (possibly working recursively through features in each layer, guided by some regularization designed to keep adversaries naturalistic/plausible).
Directly intervening on features seems like it might stay the same though.
This doesn't seem super obvious if features aren't orthogonal, or may exist in subspaces or manifolds rather than individual directions. The fact that this transition isn't trivial is one reason it would be better to understand some simple models very well (so that when we go to larger models, we're on surer scientific footing).
The agent's context includes the reward-to-go, state (i.e., an observation of the agent's view of the world) and action taken for nine timesteps. So: R1, S1, A1, ..., R9, S9, A9 (Figure 2 explains this a bit more). If the agent hasn't made nine steps yet, some of the S's are blank. So S5 is the state at the fifth timestep. Why is this important?
If the agent has made four steps so far, S5 is the initial state, which lets it see the instruction. Four is the number of steps it takes to reach the corridor where the agent has to make the decision to go left or right. This is the key decision for the agent to make, and the agent only sees the instruction at S5, so S5 is important for this reason.
Figure 1 visually shows this process - the static images in this figure show possible S5's, whereas S9 is animation_frame=4 in the GIF - it's fast, so it's hard to see, but it's the step before the agent turns.
The first frame, apologies. This is a detail of how we number trajectories that I've tried to avoid dealing with in this post. We left-pad within a context window of 10 timesteps, so the first observation frame is S5. I've updated the text not to refer to S5.
Keywords: Mechanistic Interpretability, Adversarial Examples, GridWorlds, Activation Engineering
This is part 2 of A Mechanistic Interpretability Analysis of a GridWorld Agent Simulator
Links: Repository, Model/Training, Task.
Epistemic status: I think the basic results are pretty solid, but I’m less sure about how these results relate to broader phenomena such as superposition or other modalities such as language models. I've erred on the side of discussing connections with other investigations to make it more obvious how gridworld decision transformers may be useful.
TLDR
We analyse the embedding space of a gridworld decision transformer, showing that it has developed an extensive structure that reflects the properties of the model, the gridworld environment and the task. We can identify linear feature representations for task-relevant concepts and show the distribution of these features in the embedding space. We use these insights to predict several adversarial inputs (observations with “distractor” items) that trick the model about what it is seeing. We show that these adversaries work as effectively as changing the feature (in the environment). However, we can also intervene directly on the underlying linear feature representation to achieve the same effects. Whilst methodologically simple, this analysis shows that mechanistic investigation of gridworld models is tractable and touches on many different areas of fundamental mechanistic interpretability research and its application to AI alignment.
We recommend reading the following sections for readers short on time:
Key Results
Object Level Results
Broader Connections
While this post summarises relatively few experiments on just one model, we find our results connect with many other ideas which go beyond the details of just this model.
Introduction
Why study GridWorld Decision Transformers?
Decision Transformers are a form of offline RL (reinforcement learning) that enables us to use transformers to solve traditional RL tasks. While traditional "online" RL trains a model to receive reward by interacting with a task, offline RL is analogous to language model training, with the model trained to predict the next token of recorded trajectories rather than receiving reward directly.
Decision Transformers are trained on recorded trajectories which are labelled with the reward achieved, the Reward-to-Go (RTG). RTG is the time-discounted reward stream that the agent should be getting; i.e., if it's set close to 1, the model will be incentivised to do well because it will take actions consistent with the reference class of agents that achieved this reward. RTG isn't critical to this post but will be discussed in more detail in subsequent posts.
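To make the sequence structure concrete, here is a minimal sketch (in PyTorch) of how the interleaved R, S, A context might be assembled. The module names, sizes and padding details are illustrative assumptions rather than the repository's actual implementation.

```python
import torch
import torch.nn as nn

# Hypothetical embedding modules; the real model's names and sizes may differ.
d_model, n_actions, max_timesteps = 256, 7, 9
rtg_embed = nn.Linear(1, d_model)            # embeds the scalar reward-to-go R_t
action_embed = nn.Embedding(n_actions, d_model)

def build_context(rtgs, obs_embeddings, action_ids):
    """Interleave (R_t, S_t, A_t) tokens for the decision transformer.

    rtgs:           (T,) float tensor of reward-to-go values
    obs_embeddings: (T, d_model) pre-computed observation tokens
    action_ids:     (T,) long tensor of action indices
    """
    tokens = []
    for t in range(rtgs.shape[0]):
        tokens.append(rtg_embed(rtgs[t].view(1)))   # R_t
        tokens.append(obs_embeddings[t])            # S_t (partial observation)
        tokens.append(action_embed(action_ids[t]))  # A_t
    context = torch.stack(tokens)                   # (3*T, d_model)
    # Left-pad with zeros when fewer than max_timesteps steps have been taken.
    pad = 3 * max_timesteps - context.shape[0]
    if pad > 0:
        context = torch.cat([torch.zeros(pad, d_model), context])
    return context
```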
We’re interested in gridworld decision transformers for the following reasons.
AI Alignment and the Linear Representation Hypothesis
The linear representation hypothesis proposes that neural networks represent features of the input as directions in latent space.
This post focuses on linear representations for 3 reasons:
Diagram from Towards Monosemanticity: Decomposing Language Models With Dictionary Learning
The MiniGrid Memory Task
MemoryDT is trained to predict actions in trajectories produced by a policy that solves the MiniGrid Memory task. In this task, the agent is spawned next to an object (a ball or a key) and is rewarded for walking to the matching object at the end of the corridor. We refer to the first object as the “instruction” and the latter two objects as the “targets”.
Figure 1 shows all four variations of the environment. Please note:
Figure 1: MiniGrid Memory Task Partial Observations. Above: All 4 Variations of the MiniGrid Memory Task as seen from the starting position. Below: A recording of high-performing trajectories.
This task is interesting for three reasons:
Figure 2 shows how the decision transformer architecture interacts with the gridworld observations and the central decision. We discuss tokenisation of the observation in the next section.
Figure 2: Decision Transformer Diagram with Gridworld Observations. R corresponds to the tokenised Reward-to-Go, and S stands in for state (replaced with O in practice; we have partial observations). A corresponds to action tokens.
MemoryDT Observation Embeddings are constructed via a Compositional Code.
To adapt the Decision Transformer architecture to gridworld tasks, we tokenise the observations using a compositional code whose components are "object at (x,y)" or "colour at (x,y)". For example, Key at (2,3) will have its own embedding, and so will Green at (2,3), etc. Figure 3 shows example observations with important vocabulary items annotated.
Figure 3. Example Observations with annotated target/instruction vocabulary items.
For each present vocabulary item, we learn an embedding vector. The observation token is then the sum of the embeddings for all present vocabulary items:

O_t = Σ_{i,j,c} I(i,j,c) · f_{i,j,c}

where O_t is the observation embedding (a vector of length 256), i is the horizontal position, j is the vertical position, c is the channel (colour, object or state), and f_{i,j,c} is the corresponding learned embedding with the same dimension as the observation embedding. I(i,j,c) is an indicator function; for example, I(2,6,key) = 1 means that there is a key at position (2,6).
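As a rough sketch of this compositional code (with hypothetical names and channel indices, not the repository's actual API), the observation token can be built by summing the embeddings of whichever vocabulary items are present:

```python
import torch

# Hypothetical embedding table: one learned vector per (x, y, channel) vocabulary item.
d_model = 256
grid_w, grid_h, n_channels = 7, 7, 20
vocab_embeddings = torch.nn.Parameter(0.02 * torch.randn(grid_w, grid_h, n_channels, d_model))

def observation_token(present_items):
    """O_t = sum of f_{i,j,c} over all vocabulary items present in the observation.

    present_items: iterable of (i, j, c) triples for which the indicator
    I(i, j, c) equals 1, e.g. (2, 6, KEY_CHANNEL) for a key at position (2, 6).
    """
    token = torch.zeros(d_model)
    for i, j, c in present_items:
        token = token + vocab_embeddings[i, j, c]
    return token

# Example: a key at (2, 6) rendered in green (the channel indices are illustrative).
KEY_CHANNEL, GREEN_CHANNEL = 5, 11
o_t = observation_token([(2, 6, KEY_CHANNEL), (2, 6, GREEN_CHANNEL)])
```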
Figure 4 illustrates how the observation tokens are made of embeddings, which might themselves be made of features, which match the task-relevant concepts.
Figure 4: Diagram showing how concepts, features, vocabulary item embeddings and token embeddings are related. We learn embeddings for each vocabulary item, but the model can treat those independently or use them to represent other features if desired.
A few notes on this setup:
Results
Geometric Structure in Embedding Space
To determine whether MemoryDT has learned to represent the underlying task-relevant concepts, we start by looking at the observation embedding space.
Many embeddings have much larger L2 norms than others.
Channels likely to be activated and likely to be important to the task, such as keys/balls, appeared to have the largest norms, along with "green" and other channels that may encode useful information. Some of the largest embedding vectors corresponded to vocabulary items that were understandable and important, such as Ball (2,6), the ball as seen from the starting position, whilst others were less obvious, such as Ball (0,6), which shouldn't appear unless the agent moves the ball (which it can do). Embedding vectors are initialised with L2 norms of approximately 0.32; these vectors weren't subject to weight decay, and some grew during training.
Figure 5: Strip Plot of L2 Norms of embedding vectors in MemoryDT’s Observation Space.
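A minimal sketch of how such a strip plot might be produced, assuming a flattened embedding table and label arrays (all names here are hypothetical, not the analysis code used for the figure):

```python
import torch
import pandas as pd
import plotly.express as px

def l2_norm_table(vocab_embeddings: torch.Tensor, item_labels, channel_labels) -> pd.DataFrame:
    """Tabulate the L2 norm of every observation-embedding vector.

    vocab_embeddings: (n_vocab_items, d_model) flattened embedding table
    item_labels / channel_labels: human-readable names per row
    """
    norms = vocab_embeddings.norm(dim=-1).detach().cpu().numpy()
    return pd.DataFrame({"item": item_labels,
                         "channel": channel_labels,
                         "l2_norm": norms}).sort_values("l2_norm", ascending=False)

# Usage (names are hypothetical):
# df = l2_norm_table(flat_embeddings, item_labels, channel_labels)
# px.strip(df, x="channel", y="l2_norm", hover_name="item").show()
```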
Cosine Similarity Heatmaps Reveal Geometric Structure
We initially attempted PCA / UMAP for dimensionality reduction, but neither was particularly informative. Instead, we borrowed the concept of a clustergram from systems biology: plot a heatmap of the adjacency matrix, in this case the cosine similarity matrix of the embeddings, and reorder the rows according to a clustering algorithm. The resulting cosine similarity heatmaps (methods) were interesting both with and without the reordering of rows for clustering (Figure 6).
Figure 6: Cosine Similarity Heatmap of Embeddings for Key/Ball Channels. LHS: rows/columns sorted by channel, then y-position, then x-position, so the first row is Key (0,0), the next is Key (0,1), and so forth. RHS: rows/columns ordered by agglomerative clustering. Figure 6 is best understood via interaction (zooming/panning).
There are a number of possible stories which might explain the structural features observed in Figure 6. Many embeddings don't have very high cosine similarity with any others; these embeddings also have low norms, suggesting they weren't updated much during training.
Two effects may be interpreted with respect to correlation or anti-correlation:
To address these ideas more directly, we plotted cosine similarity to determine whether the two vocabulary items shared the same channel (key or ball) or position (Figure 7).
Figure 7: Distribution of Cosine Similarity of pairs of embeddings/vocabulary items (limited to Key/Ball channels), filtered to have an L2 norm above 0.8.
Even though channel/position is not a perfect proxy for correlation beyond the anti-correlation induced by spatial exclusivity, Figure 7 shows some general trends better than Figure 6. Beyond potentially interesting trends (which aren’t trivial to interpret), we can see many outliers whose embedding directions relative to each other can’t easily be interpreted without reference to the training distribution.
This leads us to hypothesise that semantic similarity may also affect geometric structure. By "semantic similarity", we mean that some vocabulary items may be related not just by when they are likely to occur but by the actions that the decision transformer should take having observed them. To provide evidence for such a hypothesis, we focus on groups of vocabulary items with particularly high absolute cosine similarity, and on clusters. For example, we observed clusters corresponding to vocabulary items in a single channel at multiple positions, such as Keys at (0,5), (2,6) and (4,2). Interpreting these clusters was possible with reference to the training distribution, specifically by looking at which positions the agent might be in when those channels are activated (Figure 8).
Figure 8: Reference Observations to assist interpretation of Feature Maps. The agent is always in position (3,6).
By combining the clusters observed in Figure 8 with the distribution of possible observations in the training dataset, it’s possible to see several semantically interpretable groups:
At this point, we hypothesised that each of these vocabulary items may contain underlying linear features corresponding to the semantic interpretation of the group.
Extracting and Interpreting Feature Directions in MemoryDT’s Observation Embeddings
To extract each feature direction, we perform Principal Component Analysis on the subset of relevant embedding vectors. Using PCA, we hope to discard unimportant directions while quantifying the variance explained by the first few. We can attempt to interpret both the resulting geometry and the principal component directions themselves (see methods).
To interpret the principal component directions, we show heatmaps of the dot product between the PC and each embedding vector, arranging these values to match the corresponding positions in the visualisations of gridworld partial observations. These heatmaps, which we call "feature maps", have much in common with heatmaps of convolutional layers in vision models and represent virtual weights between each embedding and the underlying principal component (see methods).
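A minimal sketch of the feature-map computation, assuming we already have the embeddings for one channel arranged by grid position (names and shapes are illustrative):

```python
import numpy as np

def feature_map(embeddings_by_pos: np.ndarray, direction: np.ndarray) -> np.ndarray:
    """Project every positional embedding for one channel onto a feature direction.

    embeddings_by_pos: (7, 7, d_model) array of embeddings for a single channel
                       (e.g. the Key channel), indexed by (x, y) grid position.
    direction:         (d_model,) candidate feature direction (e.g. instruction PC1).
    Returns a (7, 7) grid of dot products that can be displayed as a heatmap
    aligned with the agent's partial observation.
    """
    return embeddings_by_pos @ direction
```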
The Primary Instruction Feature
Figure 9: Exploratory Analysis of the “Instruction” subset of Observation Embeddings.
Left) Cosine Similarity Heatmap of the Instruction Embeddings.
Right) 2D Scatter Plot of the first 2 Principal Components of a PCA generated from the embedding subset.
Previously, we identified keys/balls at positions (4,2), (4,3), (0,5) and (2,6) as clustering together and hypothesised that this may be due to an underlying "instruction feature". The first two principal components of the PCA explain 85.12% of the variance in those embeddings and create a space in which keys/balls appear in antipodal pairs (Figure 9). This projection is reminiscent both of feature splitting/anisotropic superposition (which is thought to occur when highly correlated features have similar output actions) and of the antipodality found in toy models.
PC1 separates keys from balls independently of position, making it a candidate for a linear representation of an instruction feature. One way to interpret this is as a very simple form of equivariance, where the model recognises the key/ball as the instruction regardless of the position at which it appears.
To visualise this instruction feature, we generate a feature map for PC1 (Figure 10), which shows that this feature is present to varying degrees in embeddings for keys/balls at many different positions where the instruction might be seen. We note that the instruction feature tends to be present at similar absolute values but opposite signs between keys and balls, suggesting a broader symmetry in the instruction feature between keys and balls.
Figure 10: Feature Map showing Instruction PC1 Values for all embeddings corresponding to Keys/Ball.
Another Instruction Feature?
PC2 in the Instruction subset PCA is less easy to interpret. In Figure 9, it appears to distinguish whether the instruction is seen from the "look-back" position or the "starting" position. However, the sign of this effect flips between embeddings corresponding to "instruction is key" and "instruction is ball". Moreover, the feature map for PC2 (Figure 11) shows keys and balls at (3,4) as having noticeable cosine similarity with this direction, which doesn't fit that interpretation. Nor does this explanation predict that keys/balls at (4,3), a position similar to the look-back positions, barely project onto PC2.
We suspect that PCA fails to find a second interpretable feature direction because it is constrained to find orthogonal directions; however, it's not obvious that there isn't a meaningful underlying feature. We plan to investigate this further in the future.
Figure 11: Feature Map showing Instruction PC2 Values for all embeddings corresponding to Keys/Ball.
Target Features
Figure 12: Exploratory Analysis of the Target Embeddings. Left) Cosine Similarity Heatmap of the Target Embeddings. Right) 2D Scatter Plot of the first 2 Principal Components.
For the target feature, we identified two separate clusters, each made up of two sets of almost antipodal pairs (Figure 12). The geometry here is much closer to isotropic superposition / toy-model results. The faint checkerboard pattern suggests the slightest hint of a more general target feature, which we suspect may be learned if we trained MemoryDT for long enough.
The first two principal components of the resulting PCA explain 83.69% of the variance in those embeddings and produced interpretable feature maps (Figure 13):
Figure 13: Feature maps showing Target PC1 values (above) and PC2 values (below) for all embeddings corresponding to Keys/Balls.
Using Embedding Arithmetic to Reverse Detected Features
Embedding Arithmetic with the Instruction Feature
We previously observed that the group of vocabulary items associated with the instruction concept were separated cleanly into Keys and Balls by a single principal component explaining 60% of the total variance across the vectors included. From this, we hypothesised that this principal component reflects an underlying "instruction feature". To validate this interpretation, we want to show that we can leverage this prediction in non-trivial ways, such as by generating feature-level adversaries (as previously applied to factors found via dictionary learning in language models and to copy/paste attacks in image models).
Based on the previous result, we predicted that adding two vocabulary items matching the opposite instruction (i.e., if the instruction is a key seen at (2,6), adding a ball at (0,5) and a ball at (4,2)) would induce the model to behave as if the instruction were flipped. I've drawn a diagram below to explain the concept (Figure 14).
Figure 14: Diagram showing the basic inspiration behind “instruction adversaries”.
Effectiveness of Instruction Feature Adversaries
Figure 15: Animation showing the trajectory associated with the Instruction Reversal Experiment.
To test the adversarial features / embedding arithmetic hypothesis, we generated a set of prompts/trajectories ending in a position where the model's action preference is directly determined by whether it observes the instruction to be a key or a ball (Figure 15). For each of the target/instruction configurations in Figure 15, we generate five different edits (Figure 14) to the first frame:
Note that due to the tokenisation of the observation, we can think of adding these vocabulary items to the input as adding adversarial features.
Note: I’m using the word “complement” because if the original instruction was a key, add a ball to reverse it and vice versa.
Figure 16: Adversarial Observation Token Variations. Added objects are shown in red though only the object embedding is added.
Figure 17 shows the restored logit difference for each of the three test cases (Complement (0,5), Complement (4,2) and Complement (0,5), (4,2)), using the original frame as our negative control or "clean" input and Instruction Flipped as our positive control or "corrupt" input.
Figure 17: Restored Logit Difference between left/right for instruction feature adversaries in "scenario 1"(MemoryDT). 8 facet images correspond to each target, instruction and RTG combination. (RTG = 0.892 corresponds to the highest possible reward that an optimal policy would receive. RTG = 0 corresponds to no reward, often achieved by going to the wrong target)
These results are quite exciting! We were able to predict very particular adversaries in the training data that would cause the model to behave (almost) as if it had seen the opposite instruction and did so from the feature map (an interpretability tool).
Let’s break the results in Figure 17 down further:
Proving that Instruction Feature Adversaries operate only via the Instruction Feature.
Whilst the previous results are encouraging, we would like to provide stronger evidence behind the notion that the projection of the embedding space into instruction feature direction is causally responsible for changing the output logits. To show this we provide two lines of evidence:
The adversarial inputs are changing the presence of the instruction feature.
For each of the forward passes in the experiment above, we plot the dot product of the instruction feature with the observation embedding against the difference between the logits for turning left and right (Figure 18). We see that:
Figure 18: Measuring the projection of the first frame observation embeddings into the Instruction PC0 direction (x-axis) and showing the logit difference between left/right (y-axis).
Direct Interventions on the Instruction Feature
We directly intervene on the instruction feature in each scenario tested above, again plotting the final left-minus-right logit difference (Figure 19).
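Concretely, a sketch of what such an intervention might look like (assuming a unit-norm feature direction; this is illustrative rather than the exact code used):

```python
import torch

def intervene_on_feature(obs_token: torch.Tensor,
                         feature_dir: torch.Tensor,
                         new_value: float) -> torch.Tensor:
    """Set the projection of an observation embedding onto a feature direction.

    feature_dir is assumed to be (or is normalised to) a unit vector, e.g. the
    instruction PC1. We remove the existing component along that direction and
    replace it with `new_value`, leaving all orthogonal components untouched.
    Sweeping `new_value` and recording the left/right logit difference traces
    out a curve like the one in Figure 19.
    """
    feature_dir = feature_dir / feature_dir.norm()
    current = obs_token @ feature_dir
    return obs_token + (new_value - current) * feature_dir
```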
The similarity between the functions mapped out by the adversarial inputs (Figure 18) and the direct interventions (Figure 19) is striking! The direct interventions show a similar (and cleaner) functional mapping from the instruction feature's sign/magnitude to the logit difference, suggesting the instruction feature entirely explains our adversarial results.
Figure 19: Intervened Instruction PC0 direction (x-axis) and showing the logit difference between left/right (y-axis).
Do the Instruction Feature Adversaries Transfer?
Finally, since our explanation of the instruction feature suggests that it represents a meaningful property of the data, and since our embedding arithmetic can be interpreted as adversarial inputs, it is reasonable to test whether those adversaries transfer to another model trained on the same data. We find that MemoryDT-GatedMLP, a variant of MemoryDT, is vulnerable to the same adversarial features (Figure 20).
Figure 20: Restored Logit Difference between left/right for instruction feature adversaries. MemoryDT-GatedMLP (RTG = 0.892).
Figure 20 suggests the following:
Since MemoryDT-GatedMLP is a fairly similar model to MemoryDT, it's not surprising that the adversaries transfer; however, it fits with existing theories regarding feature universality and the claim that adversarial attacks are not bugs, they are features.
Discussion
Feature Representations in GridWorld Observation Embeddings
There are several ways to explain our results and connect them to previous work. It’s not surprising to see structure in our embeddings since highly structured embeddings have been previously linked to generalisation and grokking in toy models, and the presence of composable linear features in token embeddings has been known for a long time.
Moreover, a fairly simple story can be told to explain many of our observations:
However, many pertinent questions remain unanswered:
Adversarial Inputs
To validate our understanding of the instruction feature, we used both adversarial inputs and direct intervention on the instruction feature. We could correctly predict which embeddings could be used to trick the model and show that this effect was mediated entirely via the feature we identified.
Adversaries and Interpretability
In general, our results support previous arguments that the study of interpretability and adversaries are inseparable. Various prior results connect adversarial robustness to interpretability, and it’s been claimed that “More interpretable networks are more adversarially robust and more adversarially robust networks are more interpretable”.
Applying the claim here, we could say that MemoryDT is not adversarially robust; therefore, we should not expect it to be interpretable. However, this seems to be false. Rather, MemoryDT uses a coherent, interpretable strategy to detect the instruction from lower-level features, one that works well in-distribution but leaves it vulnerable to feature-level adversarial attacks. Moreover, to be robust to the adversaries we designed while still performing well on the original training distribution, MemoryDT would need to implement more complicated circuits that would be less interpretable.
We’re therefore more inclined to interpret these results as weak support for the claim that interpretability, even once we’ve defined it rigorously, may not have a monotonic relationship with properties like adversarial robustness or generalisation. The implications of this idea for scaling interpretability have been discussed informally here.
Adversaries and Superposition
There are many reasons to think that adversaries are not bugs, they are features. However, it has been suggested that vulnerability to adversarial examples may be explained by superposition. The argument suggests that unrelated features in superposition can be adversarially perturbed, confusing the model, which would fit into the general category of adversaries as bugs.
However, this was suggested in the context of isotropic superposition, not anisotropic superposition. In isotropic superposition, features that don't represent similar underlying objects share dimensions, whilst anisotropic superposition may involve features that "produce similar actions" (or represent related underlying features).
There are three mechanisms through which antipodal or anisotropic superposition might be related to adversaries:
It’s easy to theorise, but we’re excited about testing mechanistic theories of LM jailbreaking techniques. Moreover, we’re also excited to see whether hypotheses developed when studying gridworld models generalise to language models.
Adversaries and Activation Addition
A method was recently proposed for steering language model generation via arithmetic in activation space; similar previous methods found steering vectors via stochastic gradient descent. The use of counterbalanced steering vectors is justified by the need to emphasise the properties in which two prompts or tokens differ. The vector is then further "emphasised" via a scaling factor that can affect steering performance.
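For readers unfamiliar with the technique, here is a minimal sketch of the general idea behind activation addition. The shapes and coefficient are illustrative assumptions; the published method differs in details such as how the prompt pair is aligned and which positions are modified.

```python
import torch

def activation_addition(resid: torch.Tensor,
                        act_plus: torch.Tensor,
                        act_minus: torch.Tensor,
                        coeff: float = 4.0) -> torch.Tensor:
    """Steer a forward pass by adding a counterbalanced activation difference.

    resid:     (seq, d_model) residual stream activations at some layer
    act_plus:  activations recorded on the prompt to steer towards
    act_minus: activations recorded on the counterbalancing prompt
               (both assumed broadcastable against resid)
    coeff:     the scaling factor that 'emphasises' the steering vector
    """
    steering_vector = act_plus - act_minus
    return resid + coeff * steering_vector
```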
We propose that the results in this analysis may be highly relevant to the study of steering vectors in two ways:
We don’t claim these insights are novel, but the connections seem salient to us. Thus, we’re interested in seeing whether further experiments with latent interventions in gridworld models can teach us more about steering vectors.
Conclusion and Future Work
This is a quote that summarises my (Joseph’s) sentiment about this post and working on MemoryDT.
There are limitations to studying a single model and so it’s important to be suspicious of generalising statements. There is still a lot of work to do on MemoryDT so connecting this work to broader claims is possibly pre-emptive.
Despite this, we think the connections between this and other work speak to an increasingly well-defined and better-connected set of investigations into model internals. The collective work of many contributors permits a common set of concepts that relate phenomena across models and justifies a diverse portfolio of projects, applied and theoretical, on small and larger models alike.
We’re excited to continue to analyse MemoryDT and other gridworld models but also to find ways of generating and testing hypotheses which may apply more broadly.
Our primary aims moving forward with this analysis are to:
However, we’re also interested in continuing to explore the following topics:
Glossary
Gratitude
This work was supported by grants from the Long Term Future Fund, as well as the Manifund Regranting program. I’d also like to thank Trajan house for hosting me. I’m thankful to Jay Bailey for joining me on this project and all his contributions.
I'd like to thank all those who gave feedback on the draft including (in no particular order) Oskar Hollinsworth, Curt Tigges, Lucy Farnik, Callum McDougall, David Udell, Bilal Chughtai, Paul Colognese and Rusheb Shah.
Appendix
Methods
Identifying Related Embeddings with Cosine Similarity Heatmaps
Even though we had fairly strong prior expectations about which sets of vocabulary items were likely to be related to each other, we needed a method for pulling out these groups of embeddings in an unbiased fashion. The heatmaps are more useful when clustered, so we use scikit-learn to perform agglomerative clustering with single linkage and Euclidean distance. This is just a fancy method for finding similar groups of tokens.
This works quite effectively for these embeddings but would likely be insufficient for a language model: only the largest underlying feature (if any) would determine the nearest points, so you would struggle to retrieve meaningful clusters. A probing strategy, or using sparse autoencoders to find features and then measuring token similarity with those features, might be better in that case.
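A sketch of the clustergram-style reordering. Here we use scipy's hierarchical clustering to obtain a leaf ordering, whereas the analysis in the post uses scikit-learn, so treat this as an approximation of the method rather than the exact code:

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, leaves_list

def clustered_cosine_heatmap(embeddings: np.ndarray):
    """Cosine-similarity matrix with rows/columns reordered by clustering.

    embeddings: (n_items, d_model). We normalise, compute the cosine similarity
    matrix, then reorder rows/columns using single-linkage hierarchical
    clustering with Euclidean distance, as in a systems-biology clustergram.
    """
    normed = embeddings / np.linalg.norm(embeddings, axis=-1, keepdims=True)
    cos_sim = normed @ normed.T
    order = leaves_list(linkage(embeddings, method="single", metric="euclidean"))
    return cos_sim[np.ix_(order, order)], order
```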
Principal Component Analysis on a Subset of Embeddings for Feature Identification
Clustering heatmaps aren’t useful for understanding geometry unless they have very few vectors, so we make use of Principal Component Analysis for this instead. Principal Component Analysis is a statistical technique that constructs an orthonormal basis from the directions of maximum variance within a vector space and has been applied previously to study word embeddings and latent spaces in conjunction with circuit analysis (in both cases applied to a subset of possible vectors).
It turns out that PCA is very useful for showing feature geometry in this case for the following reasons:
There are two issues with PCA:
To interpret principal components, we project them onto the embedding space for relevant channels (mainly keys/balls) and then show the resulting scores arranged in a grid with the same shape as the observations generated by the MiniGrid Environment. It’s possible to interpret these by referring to the positions where different vocabulary items sit and which concepts they represent.
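A minimal sketch of the subset-PCA step, assuming the relevant embeddings have already been gathered into an array (mean-centring and other details may differ from the analysis in the post):

```python
import numpy as np
from sklearn.decomposition import PCA

def feature_directions_from_subset(embedding_subset: np.ndarray, n_components: int = 2):
    """Fit a PCA on a hand-picked subset of embeddings and return candidate features.

    embedding_subset: (n_items, d_model), e.g. the 'instruction' vocabulary items.
    Returns the principal component directions, the variance explained by each,
    and the projection of each embedding onto the components (used for scatter
    plots like Figure 9).
    """
    pca = PCA(n_components=n_components)
    scores = pca.fit_transform(embedding_subset)   # (n_items, n_components)
    directions = pca.components_                   # (n_components, d_model)
    return directions, pca.explained_variance_ratio_, scores
```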
Interpreting Feature Directions with Feature Maps
Once we have a direction that we believe corresponds to a meaningful feature, we can take the cosine similarity between this direction and every element of the embedding space. Since the embedding space is inherently structured as a 7x7 grid with 20 channels, we can simply look at the embeddings for the relevant channels (keys and balls). This is similar to a convolution over height/width with as many channels as the embedding dimension.
Feature maps are similar to the heatmaps produced by Neel in his investigation into OthelloGPT, using probe directions where we use embeddings and the residual stream where we use our feature direction.
Validating Identified Features by Embedding Arithmetic
To test whether a linear feature representation corresponds to a feature, we could intervene directly on the feature, removing it from or adding it to the observation token, but we can also simply add or subtract vocabulary items that contain that feature.
Our method is similar to the activation addition technique, which operates on the residual stream at a token position, except that ours works at the input level. If we operated directly on the hypothesised linear feature representation direction, this method would be similar to the causal intervention on the world model used on OthelloGPT to test whether a probe vector could be used to intervene on a transformer's world representation.
To evaluate the effect of each possible embedding arithmetic, we take the modal scenario where the model has walked forward four times and is choosing between left/right. We measure the logit difference between left and right in the following contexts:
Then, for each test case, we report the proportion of logit difference restored: (LD(test) - LD(negative control)) / (LD(positive control) - LD(negative control)).
This is identical to the metric we would use if evaluating the effect size of a patching experiment and while it hides some of the variability in the results, it also makes the trends very obvious.
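For concreteness, here is the metric as a small helper function; this is a sketch, and the worked example in the comment uses invented numbers:

```python
def restored_logit_difference(ld_test: float,
                              ld_clean: float,
                              ld_corrupt: float) -> float:
    """Proportion of the logit difference restored by an intervention.

    ld_clean is the left-minus-right logit difference on the original frame
    (negative control), ld_corrupt the same on the instruction-flipped frame
    (positive control), and ld_test the value under the adversarial edit.
    A value of 1.0 means the edit is as effective as actually flipping the
    instruction in the environment; 0.0 means it has no effect.
    """
    return (ld_test - ld_clean) / (ld_corrupt - ld_clean)

# Hypothetical example: if flipping the instruction moves the logit difference
# from +3.0 to -2.5, and an adversarial edit moves it to -2.0, the edit
# restores roughly 0.91 of the effect.
# restored_logit_difference(-2.0, 3.0, -2.5)  # ≈ 0.909
```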
Related Work
Decision Transformers
Decision Transformers are one of several methods developed to apply transformers to RL tasks. These methods are referred to as "offline" since the transformer learns from a corpus of recorded trajectories. Decision Transformers are conditioned to predict actions consistent with a given reward because they are "goal conditioned": they receive a token representing the remaining reward to be achieved at each timestep. The decision transformer architecture is the basis for SOTA models developed by DeepMind, including Gato (a highly generalist agent) and RoboCat (a foundation agent for robotics).
GridWorld Decision Transformers
Earlier this year we studied a small gridworld decision transformer mainly via attribution and ablations. More recently, I posted details about MemoryDT, the model discussed in this post.
Circuit-Style Interpretability
A large body of previous work exists attempting to understand the inner structures of deep neural networks. Focussing on the work most relevant to this investigation, we attempt to find features/linear feature representations, framing our analysis in the style of circuit-based interpretability. We refer to previously documented phenomena such as equivariance, isotropic superposition (previously just "superposition") and the recently documented anisotropic superposition. Our use of PCA was inspired by its application to key/query and value subspaces in the 70B Chinchilla Model analysis, but PCA has a much longer history of application to making sense of neural networks.
Linear Representations
Linear algebraic structure has been previously predicted in word embeddings and found using techniques such as dictionary learning and sparse autoencoders. Such representations can be understood as suggesting that the underlying token embedding is a sum of “word factors” or features.
Taken from Zhang et al 2021
More recently, efforts have been made to find linear feature representations in the residual stream with techniques such as dictionary learning, sparse auto-encoders or sparse linear probing. What started as an attempt to understand how language models deal with polysemy (the property of a word/token having more than one distinct meaning) has continued as a much more ambitious attempt to understand how language models represent information in all layers.
RL Interpretability
A variety of previous investigations have applied interpretability techniques to models solving RL tasks.
Convolutional Neural Networks: This includes analysis of a convolutional neural network solving the Procgen CoinRun task using attribution and model editing. Similarly, a series of investigations into models that solve the Procgen Maze task identified a subset of channels responsible for identifying the target location, which could be retargeted (a limited version of retargeting the search).
Transformers: An investigation by Li et al. found evidence for a non-linear world representation in an offline-RL agent playing Othello using probes. It was later found that this world representation was linear and amenable to causal interventions.
Antipodal Representations
Toy models of superposition were found to use antipodal directions to represent anti-correlated features. There is some evidence that language models also make use of antipodal representations.
Adversarial Inputs
Adversarial examples are important to both interpretability and AI safety. A relevant debate is whether these are bugs or features (with our work suggesting the latter), though the topic should possibly be approached with significant nuance.
Activation Additions/Steering Vectors
We discuss activation addition as equivalent to our embedding arithmetic (due to our observation tokenisation schema). Activation additions attempt to steer language model generation using paired, counterbalanced vectors in activation space. Similar steering approaches have been developed previously, finding directions with stochastic gradient descent. Of particular note, one investigation used an internal direction representing truth to steer model generation.