tl;dr
I trained OthelloGPT-small, a 600k-param model that predicts legal next moves in 6x6 Othello with 99.3% accuracy
I trained attention SAEs and transcoders on every module output with 94%+ loss recovered
I used the learned features and the static model weights to generate a computational graph of the model based on a linear upstream relevance metric
I analysed the computational graph for u_A1 and found decision trees that built up compound board state patterns as well as an inhibitory feature that seemed to perform a non-linear operation
This provides evidence that patch-free methods can produce interpretable computational graphs at scale
Setup
OthelloGPT-small
OthelloGPT is a language model that predicts the next (legal) move in strings representing Othello games. It was first proposed in 2022 by Kenneth Li et al. Some great posts to read on the topic are:
Non-linear probes can extract the board state at each text position
The probes can be used to instruct interventions on the residual stream that change the model's internal board state representation, evidenced by corresponding changes in the model's predictions
SAEs can be applied to transformer module outputs that rediscover all previous probes in an unsupervised manner
Approximate latent attributions can be used to find modular circuits
I trained my own OthelloGPT model on the smaller 6x6 variant of the game, on the premise that a quadratic reduction in vocab size and context length could reduce inference times by $O(n^6)$ while maintaining sufficient task complexity. I managed to get the model size down to 600k params (3 layers, 128 dims) with 99.3% top-1 accuracy in predicting legal next moves, compared to the original model with 25m params (8 layers, 512 dims) and 99.99% accuracy.
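For concreteness, a matching config can be sketched in TransformerLens roughly as below; the head count, context length and vocab size are illustrative guesses, as only the layer and dimension counts are stated above:

```python
from transformer_lens import HookedTransformer, HookedTransformerConfig

cfg = HookedTransformerConfig(
    n_layers=3,    # 3 layers, 128 dims, as stated above
    d_model=128,
    n_heads=8,     # illustrative: 8 heads of d_head=16
    d_head=16,
    n_ctx=32,      # a 6x6 game has at most 32 moves
    d_vocab=33,    # 32 playable squares + 1 pad/start token (illustrative)
    act_fn="gelu",
)
model = HookedTransformer(cfg)
# Per layer: 4*d_model^2 (attn Q/K/V/O) + 8*d_model^2 (MLP in/out),
# so 3 * 12 * 128^2 ≈ 590k block params, consistent with the ~600k total.
```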
Upstream relevance metric
Next, I trained attention SAEs on each attention module and transcoders on each LayerNorm[1] and MLP module, with 94%+ loss recovered in all cases.
Evaluation metrics for the 12 "SAEs" trained on OthelloGPT-small
The goal of learning features like this was to trace an end-to-end computational graph of the features, working backwards recursively from logits to tokens. Each learned feature f was represented by a 2-d node in the graph, with the x-coordinate indexing the feature groups (0 = embed vectors, 1 = layer 0 LN1, 2 = layer 0 attn_z, 3 = layer 0 LN2, …, 12 = layer 2 MLP, 13 = unembed vectors) and the y-coordinate indexing the (alive) features[2] within each group.
From a node f, a node u was considered upstream if $x_u < x_f$ and either (see the sketch after this list):
f was a LayerNorm node and u was not; or
f was an attn_z/MLP node and u was the immediately preceding LayerNorm node.
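In code, this node scheme and upstream rule look roughly like the following (a sketch with my own naming, not a literal excerpt):

```python
from dataclasses import dataclass

LN_GROUPS = {1, 3, 5, 7, 9, 11}          # LN1/LN2 of layers 0-2
ATTN_MLP_GROUPS = {2, 4, 6, 8, 10, 12}   # attn_z/MLP of layers 0-2

@dataclass(frozen=True)
class Node:
    x: int  # feature group: 0 = embed, 1 = L0 LN1, ..., 12 = L2 MLP, 13 = unembed
    y: int  # index of the (alive) feature within its group

def is_upstream(f: Node, u: Node) -> bool:
    if u.x >= f.x:
        return False
    if f.x in LN_GROUPS:          # LN nodes can read any non-LN node below them
        return u.x not in LN_GROUPS
    if f.x in ATTN_MLP_GROUPS:    # attn_z/MLP nodes read only their own LN output
        return u.x == f.x - 1
    return False
```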
I then defined a relevance metric between a target node f and an upstream node u that approximated the degree to which the upstream node contributed to computing the output of the target node.
For an f corresponding to an MLP node, relevance $R(f,u) = f_e \cdot u_d$, where $f_e$ was the normalised encoder vector for f, and $u_d$ the normalised decoder vector for u. This was equivalent to the pullback value defined in previous work.
For an attn_z node f, relevance $R(f,u) = f_d \cdot v$, where $v = u_d W_V$ was the $W_V$-projection of the upstream decoder vector into v-space. The intuition here was that any z-vector is a linear combination of v-vectors, and if the upstream u features fully decomposed the incoming residual stream, then the projection of these features should fully decompose v-space and, in turn, z-space.
LN nodes were treated similarly to MLP nodes, with the caveat that upstream attn_z nodes u had their z-space decoder vectors $W_O$-projected into o-vectors $o = u_d W_O$ before calculating $R(f,u) = f_e \cdot o$.
This allowed me to attribute a relevance score to each upstream feature with respect to any target feature in the model. If $R(f,u) = 0$, then the upstream feature was irrelevant to the computation of the target feature, and I hypothesised that edges with high absolute relevance represented important, interpretable computations in the model.
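Concretely, the three cases can be sketched as below (TransformerLens-style row-vector convention; normalising the projected vectors in the attention and LN cases mirrors the MLP case and is an assumption of this sketch):

```python
from typing import Optional
import torch
import torch.nn.functional as F

def rel_mlp(f_enc: torch.Tensor, u_dec: torch.Tensor) -> torch.Tensor:
    # R(f, u) = f_e . u_d  (the "pullback" value)
    return F.normalize(f_enc, dim=-1) @ F.normalize(u_dec, dim=-1)

def rel_attn(f_dec: torch.Tensor, u_dec: torch.Tensor, W_V: torch.Tensor) -> torch.Tensor:
    # project the upstream decoder vector into v-space: v = u_d W_V
    v = u_dec @ W_V
    return F.normalize(f_dec, dim=-1) @ F.normalize(v, dim=-1)

def rel_ln(f_enc: torch.Tensor, u_dec: torch.Tensor,
           W_O: Optional[torch.Tensor] = None) -> torch.Tensor:
    # upstream attn_z decoders are first W_O-projected into o-space: o = u_d W_O
    o = u_dec @ W_O if W_O is not None else u_dec
    return F.normalize(f_enc, dim=-1) @ F.normalize(o, dim=-1)
```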
This approach differed from most current techniques for circuit-tracing, which were based on activations and patching, requiring lots of data and compute. By using only the model’s weights and pre-trained feature dictionaries, I could build an instant snapshot of the model's potential capabilities. I hoped that this would provide a cheap, faithful representation of the model's complete computational graph.
Case Study: u_A1
I traced a relevance tree starting from the unembed node for A1, taking the top-1 positive and top-1 negative upstream features by relevance metric at each non-linear node (LN1/LN2/MLP). For attention nodes, I took the top-2 features by relevance, as the attention mechanism doesn’t project negative features. Each node represented a feature, each directed edge represented the relevance of an upstream feature, and the tree was traversed right-to-left, from logits to tokens.
A computational tree for u_A1, traced using the upstream relevance metric (b/u = embed/unembed, LN1/LN2 = LayerNorm, a = attn, m = MLP, f = feature, h = head). Positional embedding nodes are omitted to avoid clutter.
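In pseudocode, the expansion rule looks roughly like this (continuing the Node sketch from earlier; relevance and upstream_nodes stand in for the metric and predicate defined above):

```python
ATTN_GROUPS = {2, 6, 10}  # attn_z groups of layers 0-2

def expand(f: Node) -> list:
    """Choose which upstream nodes to recurse into from node f."""
    scored = sorted(((relevance(f, u), u) for u in upstream_nodes(f)),
                    key=lambda t: t[0])
    if f.x in ATTN_GROUPS:
        return [u for _, u in scored[-2:]]    # top-2: attn projects no negatives
    return [scored[-1][1], scored[0][1]]      # top-1 positive, top-1 negative

def trace_tree(f: Node) -> dict:
    if f.x == 0:                              # stop at the embedding layer
        return {"node": f, "children": []}
    return {"node": f, "children": [trace_tree(u) for u in expand(f)]}
```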
Starting with a depth-first analysis, I traced a path (green, above) following the edge with the highest absolute relevance value, starting from u_A1, and plotted the average board state over which each feature activated on a sample dataset, weighted by activation strength.
Weighted average board states for each feature (l = legal, e = empty, tm = theirs/mine, p = pos). Blue = 1, red = 0.
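The weighting itself is straightforward; a minimal sketch, assuming per-sample feature activations and 0/1 board-state indicator planes:

```python
import numpy as np

def weighted_avg_board(acts: np.ndarray, boards: np.ndarray) -> np.ndarray:
    """acts: (n_samples,) feature activations; boards: (n_samples, 6, 6) 0/1 indicators."""
    w = np.clip(acts, 0.0, None)               # only positive activations contribute
    return (w[:, None, None] * boards).sum(axis=0) / (w.sum() + 1e-8)
```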
This path seemed to represent a case for A1 being "legal" based on A1 being "empty", A2 being "theirs" and A3 being "mine". There was a clear branching point where the nodes up to m1f145 (read: "feature #145 of the transcoder for the MLP in layer #1") built up the A1-A2-A3 board state representation, and then LN2l2f256 (read: "feature #256 of the transcoder for the 2nd LayerNorm in layer #2") constructed a union feature that activated across "A1 legal" boards in general. The penultimate node m2f683 also activated on "A1 legal" boards, except when predicting the final move, but wrote out to the residual stream in the negative u_A1 direction, representing an inhibitory feature!
Union feature
Generally, transformers seem to work by accumulating contextual information at each token position, constructing increasingly rich representations across layers. It follows that OthelloGPT might best exploit its highly parallel structure by building up representations of compound board states, i.e. patterns. This idea is supported by previous work on Othello and Chess. However, to my knowledge, this is the first work that decomposes the layer-by-layer structures responsible for constructing these pattern features.
I saw evidence of pattern construction behaviour in the relevance tree for node LN2l2f256. Taking its top-2 positive and top-2 negative upstream features, I collected co-activation statistics, following Balcells et al., representing conditional activation probabilities between each feature. I displayed these alongside the feature collinearity matrix and activation-weighted average board states.
An interpretability dashboard of the (pruned) decision tree for node LN2l2f256, predicting "A1 legal" using upstream features
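The co-activation statistic can be sketched as a matrix of conditional activation probabilities over binarised activations (my shorthand for the Balcells et al. construction; the threshold is a free parameter):

```python
import numpy as np

def coactivation(acts: np.ndarray, thresh: float = 0.0) -> np.ndarray:
    """acts: (n_samples, n_features). Returns M with M[i, j] = P(i active | j active)."""
    active = (acts > thresh).astype(float)
    counts = active.T @ active                # pairwise co-activation counts
    return counts / (active.sum(axis=0) + 1e-8)  # normalise by the conditioning feature
```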
I saw that the two most relevant upstream features corresponded to the A1-A2-A3 pattern from earlier (m1f145) as well as a symmetrical A1-B1-C1 pattern (m1f154, fittingly!). There were also features which provided evidence against A1 being legal, due to the square being occupied (m1f759 and a0f416h6).
Weighted average board states for node m1f145 and its upstream features
I stepped down another level in the tree into m1f145 and saw how the A1-A2-A3 pattern was computed using 3 pieces of upstream evidence, and this iterative breakdown into simpler features continued until reaching the embedding layer. This represented a probabilistic, evidence-based approach to pattern feature construction, where high-confidence upstream features, such as whether a particular square was empty, were combined to form compound downstream features.
It should be possible to confirm these modular circuits for pattern recognition using activation patching in future work. Having access to the complete computational graph should be helpful in informing more targeted interventions, as previous work has shown that the model seems more resilient to interventions at earlier layers, which would be expected from the hypothesised mechanism.
Inhibitory feature
I also looked into the inhibitory feature at m2f683, which seemed to activate on "A1 legal" boards with the sole purpose of reducing the A1 logit, except when predicting the final move. None of the other high-relevance features for u_A1 (below) exhibited this behaviour, but when I traced graphs for other unembedding nodes, I consistently found similar features.
An interpretability dashboard of the (pruned) decision tree for u_A1
The common characteristics of the inhibitory feature $f_u$ were (expressed as checks in the sketch after this list):
High negative cosine similarity with the unembed vector: $f_u \cdot u < -0.95$
Strong predictive power: $p(\text{legal} \mid f_u\ \text{active}) > 0.5$
No activations when predicting the final move
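Using the thresholds above, the three criteria can be written as checks (helper names and tensor layout are illustrative):

```python
import torch
import torch.nn.functional as F

def looks_inhibitory(f_dec: torch.Tensor, u: torch.Tensor, acts: torch.Tensor,
                     legal: torch.Tensor, final_move: torch.Tensor) -> bool:
    """f_dec: feature decoder vector; u: unembed vector; acts: (n,) activations;
    legal / final_move: (n,) booleans per sample."""
    anti_aligned = F.cosine_similarity(f_dec, u, dim=0) < -0.95   # criterion 1
    active = acts > 0
    predictive = legal[active].float().mean() > 0.5               # criterion 2
    silent_on_final = not bool(active[final_move].any())          # criterion 3
    return bool(anti_aligned) and bool(predictive) and silent_on_final
```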
I hypothesised that the purpose of the feature was to normalise the model's output towards a uniform distribution during the midgame, which typically contains a higher number of legal moves. This could be possible, for example, if there existed a feature that counted the number of legal moves, allowing the MLP to adjust its logit values accordingly. More likely, the model could have learned prior statistics based on the token position. Either way, this would represent a non-linear operation: the u_A1 direction was being used to adjust non-A1, orthogonal token probabilities.
Next Steps
This post was a preliminary report on my findings from applying attention SAEs and transcoders together to create an end-to-end computational graph of a transformer model.
What I've done here is essentially an update on previous patch-free dictionary learning work using more up-to-date SAE variants. I'm excited about the applicability of patch-free techniques due to their scalability: they're cheap enough that some evals could be folded into cross-layer SAE training loops.
For example, relevance can be thought of as a proxy for computational monosemanticity: the absolute relevance values for the inhibitory features were much higher than those of the union features, because they were attuned to a single input direction (legality) rather than multiple evidence features. As a side note, this may have interesting implications for feature geometry, e.g. the relevance values of m1f145 and m1f154 (which represent symmetrical patterns) to LN2l2f256 were roughly the same.
This direction of work shifts focus from evaluating isolated features to evaluating feature circuits. The application to OthelloGPT is interesting because board games are more readily evaluable. Unlike natural language, there are a tractably finite number of circuits in Othello, so it should be possible to evaluate the completeness of a circuit representation of OthelloGPT.
I think it's better to approach mech interp as a "circuit discovery" task rather than a "feature discovery" task. In isolation, it's never clear whether a given feature is at the right level of granularity. But when you can see how a feature fits into a circuit, things get a bit clearer. Along these lines, I'd also like to look into cross-layer analyses such as community detection and feature evolution in future work.
Thanks for reading. I'm new to mech interp and am very keen to get involved with the community - working in isolation sucks! In particular, I'd appreciate advice on research direction, as I don't have a particularly great sense of what would be most useful to work on, as well as on writing style. All opinions welcome.
[1] The LayerNorm transcoders took inputs from hook_resid_pre (LN1) or hook_resid_mid (LN2) and were trained to predict ln1.hook_normalized or ln2.hook_normalized, respectively. I'm not making any strong claims that LN transcoders help or even make sense, but anecdotally they seemed to improve the cosine similarity of features across layers, and they let me hand-wave away the "oh, but LayerNorm isn't linear" caveat.
[2] For attn_z features, the SAEs were trained on the concatenation of the n_head d_head-dim z-vectors. I then reshaped these outputs to form (d_sae x n_head) d_head-dim feature nodes. Arguably this wasn't necessary, and it probably weakened the relevance scores for attn_z features that were computed across several heads.
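For reference, the reshape amounts to something like the following (einops notation; sizes are illustrative and W_dec stands in for the trained SAE decoder):

```python
import torch
from einops import rearrange

d_sae, n_heads, d_head = 1024, 8, 16            # illustrative sizes
W_dec = torch.randn(d_sae, n_heads * d_head)    # stand-in for the SAE decoder matrix

# (d_sae, n_heads*d_head) -> (d_sae*n_heads, d_head): one node per (feature, head) pair
per_head_nodes = rearrange(W_dec, "f (h d) -> (f h) d", h=n_heads)
```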