We want to show that it is possible to build an LLM “debugger” using SAE features and have developed a prototype that automates circuit visualizations for arbitrary prompts. With a few improvements to existing techniques (notably, “cluster resampling”, which is a form of activation patching), we are able to produce visually interpretable results. Here’s an annotated circuit processing the letters in the word “thinks”:
Predicting a space character after the letters in “thinks”
Our main contributions with this work are:
With a “debugger”, we automate sparse circuit visualizations for arbitrary prompts. Our work uses a small 216k-parameter model, but we hope to scale up the technique to larger LLMs.
We develop a novel resampling technique, “cluster resampling”, that infers the value of omitted features by clustering inputs from a model’s training dataset. This allows us to eliminate redundancy and halve the size of our circuits.
We find and explain new high-frequency latents, latents with magnitude-dependent interpretations, and latents with context-dependent interpretations.
Methods
Model & SAE Training
We use a small character-based LLM trained on the Tiny Shakespeare dataset. Our model follows the GPT-2 architecture, with 4 layers and an embedding dimension of 64. We train end-to-end JumpReLU SAEs after every residual stream layer, and on the embeddings. For each layer, we’re able to achieve a downstream cross-entropy loss increase of under 0.05 at a sparsity of L0 ≈ 10. Our code is available here.
Sparse autoencoder locations
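For readers unfamiliar with JumpReLU SAEs, the forward pass can be sketched as below. This is a minimal NumPy illustration, not our training code: the dictionary size, weight initialisation, thresholds, and function names are hypothetical, and the end-to-end (KL-based) training objective is not shown. A JumpReLU keeps a pre-activation only if it exceeds its latent's learned threshold θ; L0 is the average number of nonzero latents per token.

```python
import numpy as np

def jumprelu_sae(x, W_enc, b_enc, theta, W_dec, b_dec):
    """Encode activations x (..., d_model) into sparse latents and reconstruct them."""
    pre = x @ W_enc + b_enc              # pre-activations, shape (..., n_latents)
    latents = pre * (pre > theta)        # JumpReLU: keep a value only if it exceeds its threshold
    recon = latents @ W_dec + b_dec      # linear decoder back to d_model
    return latents, recon

def l0(latents):
    """Average number of active (nonzero) latents per token."""
    return float((latents != 0).sum(axis=-1).mean())

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    d_model, n_latents = 64, 1024        # d_model from the text; n_latents is hypothetical
    x = rng.normal(size=(8, d_model))    # stand-in for residual-stream activations
    W_enc = rng.normal(size=(d_model, n_latents)) / np.sqrt(d_model)
    W_dec = rng.normal(size=(n_latents, d_model)) / np.sqrt(n_latents)
    theta = np.full(n_latents, 2.0)      # hypothetical per-latent thresholds
    latents, recon = jumprelu_sae(x, W_enc, np.zeros(n_latents), theta, W_dec, np.zeros(d_model))
    print(latents.shape, recon.shape, l0(latents))
```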
Circuit Extraction
Overview
With the exception of our variation on resample ablation, we use conventional techniques to identify the SAE latents that need to be included in our circuit. For each layer, we start with a target KL divergence threshold, then iteratively ablate latents as long as the KL divergence stays below our threshold. Once all SAE latents have been selected, we again use resampling ablation to compute the weights of the edges in our graph. Our primary contributions relate to “cluster resampling”, which, for our model, is what enables the extraction of circuits from arbitrary prompts and omission of redundant information.
Circuit Nodes
For any given prompt, we identify the nodes in our circuit using the following steps:
A circuit initially includes all SAE latents. Then, starting from the embedding layer, we iteratively remove latents to find the smallest set of latents that yields a KL divergence below a threshold. We compute the KL divergence of the patched model run (patching the current layer only[1]) with respect to the original outputs.
If the starting KL divergence already exceeds the target threshold, we still iteratively remove latents, looking for nodes whose removal leaves the KL divergence unchanged.
After selecting SAE latents for the embedding layer, we repeat the process for the next layer. We only consider latents at token positions that remain part of the circuit in upstream layers (i.e., once we completely ablate a token position in an upstream layer, we also ablate it in every downstream layer).
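The node-selection loop above can be sketched as a greedy backward elimination. This is a hypothetical illustration rather than our exact procedure: `kl_fn` stands in for a patched model run (ablating every latent outside `kept` at the current layer only), and both the removal order (always drop the latent whose removal costs the least KL) and the way the "neutral removal" rule is encoded are assumptions. Callers would pass as `candidates` only latents at token positions that survived the upstream layers.

```python
import numpy as np

def log_softmax(logits):
    z = logits - logits.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def kl_divergence(orig_logits, patched_logits):
    """Mean over positions of KL(P_orig || P_patched)."""
    lp, lq = log_softmax(orig_logits), log_softmax(patched_logits)
    return float((np.exp(lp) * (lp - lq)).sum(axis=-1).mean())

def select_nodes(candidates, kl_fn, threshold):
    """Greedily drop latents while the patched KL divergence stays within the limit.

    kl_fn(kept) is a hypothetical callback: run the model with every latent outside
    `kept` ablated at the current layer only, and return kl_divergence(orig, patched).
    """
    kept = frozenset(candidates)
    # If even the full set is above the threshold, only allow removals that do not
    # push KL above its starting value (the "neutral" removals described above).
    limit = max(threshold, kl_fn(kept))
    while kept:
        # Try every single removal and take the cheapest one.
        kl, latent = min((kl_fn(kept - {c}), c) for c in kept)
        if kl > limit:
            break
        kept = kept - {latent}
    return kept

if __name__ == "__main__":
    # Toy stand-in for a model: each missing latent adds a fixed amount of KL.
    cost = {0: 1.0, 1: 0.01, 2: 0.02, 3: 0.5}
    toy_kl = lambda kept: sum(v for c, v in cost.items() if c not in kept)
    print(sorted(select_nodes(cost, toy_kl, threshold=0.05)))  # [0, 3]
```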
Within the SAE latents of each layer, we identify some individual latents that, when ablated, actually bring the model outputs closer to the original output than when all latents are present. This effect occurs with both zero ablation and cluster resampling and seems stronger than expected by chance. We plan to investigate this phenomenon in future work. The effect is reminiscent of “negative heads” (Conmy et al. 2023, McDougall et al. 2023), which have been observed for logit difference-based circuits, but is unexpected for KL divergence-based circuits.
Circuit Edges
Once the nodes of a circuit are identified, we assign weights to the edges within our circuit using the following steps:
For each latent in our circuit, we identify all upstream latents that could causally influence its value (i.e. the current and all previous token positions).
We then resample each upstream latent and measure its effect on each downstream latent’s value. We set the edge weight to the mean squared error between the downstream latent’s original value and its resampled value, averaged over multiple samples because the resampling is stochastic.
Note that the lines in our circuit visualizations connect groups of features to upstream “blocks”. Internally, we use edge weights between pairs of features and between pairs of blocks.[2] The edge weights between pairs of features are used to stylize the features within a block whenever a downstream or upstream feature is selected. The edge weights between pairs of blocks are used to inform the default weighting of the lines in our visualization whenever no selection is made.
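A minimal sketch of the feature-level edge-weight computation, assuming a hypothetical callback `downstream_fn(u, rng)` that resamples a single upstream latent `u`, reruns the model, and returns the values of the downstream circuit latents. The caller is assumed to pass only upstream latents that can causally influence the downstream ones (same or earlier token positions). Block-level weights follow the same pattern with groups of latents in place of single ones.

```python
import numpy as np

def edge_weights(upstream, downstream_fn, original, n_samples=16, seed=0):
    """Edge weight (u, d) = mean squared error between downstream latent d's original
    value and its value after resampling upstream latent u, over n_samples draws.

    downstream_fn(u, rng) is a hypothetical callback: resample upstream latent u
    (e.g. with cluster resampling), rerun the model, and return the downstream
    circuit latents as a 1-D array aligned with `original`.
    """
    rng = np.random.default_rng(seed)
    original = np.asarray(original, dtype=float)
    weights = {}
    for u in upstream:
        sq_err = np.zeros_like(original)
        for _ in range(n_samples):              # resampling is stochastic, so average
            sq_err += (downstream_fn(u, rng) - original) ** 2
        sq_err /= n_samples
        for d, w in enumerate(sq_err):
            weights[(u, d)] = float(w)
    return weights

if __name__ == "__main__":
    orig = np.array([0.8, 1.5])
    # Toy callback: resampling latent 7 perturbs only the first downstream latent.
    def fake(u, rng):
        return orig + rng.normal(scale=0.3, size=2) * np.array([u == 7, 0.0])
    print(edge_weights([7, 9], fake, orig))
```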
Cluster Resampling
We introduce a new resampling (ablation) method that we call “cluster resampling”. It addresses an issue present in existing zero-, mean-, and resample-ablation techniques: they often preserve redundant latents. We find that latents are often correlated, and that a small number of latents can be used to infer the remaining important latents. Reducing circuits to this small number of latents creates more interpretable visualizations without omitting relevant information from our cluster visualization.
Our method evaluates whether, given a (small) set of latents and their magnitudes at a specific token position and layer, we can infer the remaining latents accurately enough that patched activations reconstructed from the specified and inferred latents reproduce the model output (as measured by KL divergence).
We start by constructing an n-dimensional space spanned by the specified latents.
Using cached latent values, map all tokens in the training dataset to a point in this high-dimensional space.
Apply k-nearest neighbor clustering to cluster this data, and identify the cluster corresponding to the current input.[3]
Resample the values of all ablated (incl. “inferred”) latents from other data points in this cluster and compute the KL divergence of the outputs.
Average ablation results over multiple samples from the cluster.
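The cluster-resampling steps can be sketched in NumPy as below. This is an illustrative sketch, not our implementation: it treats the `k` nearest cached tokens (by mean squared error over the specified latents) as the current input's cluster, and `k`, `n_samples`, and the function name are hypothetical. Each returned row would then be decoded by the SAE and patched into the model, with the resulting KL divergences averaged across rows. A fuller version would also exclude the current input itself from the cache.

```python
import numpy as np

def cluster_resample(cache, query, specified, k=32, n_samples=8, seed=0):
    """Sketch of cluster resampling at one token position and layer.

    cache:     (n_tokens, n_latents) cached SAE latents for the training set
    query:     (n_latents,) latents for the current input
    specified: indices of the latents kept in the circuit
    Returns (n_samples, n_latents): specified latents keep the query's values,
    every other latent is copied from a random member of the query's cluster.
    """
    rng = np.random.default_rng(seed)
    specified = np.asarray(specified, dtype=int)
    target = query[specified]
    if specified.size == 0:
        cluster = np.arange(len(cache))          # nothing specified: plain resample ablation
    else:
        # Map every cached token into the space spanned by the specified latents.
        dist = ((cache[:, specified] - target) ** 2).mean(axis=1)
        # Treat the k nearest cached tokens as the query's cluster.
        k = min(k, len(cache))
        cluster = np.argpartition(dist, k - 1)[:k]
    # Resample all unspecified latents from random cluster members.
    samples = cache[rng.choice(cluster, size=n_samples)].copy()
    samples[:, specified] = target
    return samples
```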
Resampling Using In-Distribution Values
From the values of latents 1 and 2, we can narrow down the range of possible values for latents 3 and 4.
Two viewpoints help us understand why cluster resampling is effective:
As mentioned earlier, some latents may be correlated. While multiple latents may be necessary for reconstructing the activations, they may encode redundant information. This allows us to reduce the number of latents we need to specify.
Cluster resampling can be seen as a refinement of resample ablation, ensuring that the final activation or set of latents remains in-distribution for the model. We generate a resampling distribution conditioned upon the active latents and their values.
Comparison of Ablation Methods
As a result of using cluster resampling, the circuits we extract include fewer features than those extracted using conventional ablation methods. In the following experiment, we compare the number of features in circuits extracted using zero ablation, resample ablation, and cluster resampling. We use random prompts from the validation dataset and generate circuits with a target KL divergence threshold of 0.25. Because our model activations embed a large amount of positional information (reviewed in the next section), conventional methods tend to produce poor circuits on longer sequences. To avoid unfairly advantaging our method, we limit this experiment to short sequences of exactly three tokens.
Number of Features
The midline in each box plot represents the median number of features in 100 sample circuits.
KL Divergence
Each datapoint represents the average KL divergence per layer after ablating / resampling.
Cluster resampling produces the best results, followed by zero ablation, then distantly by conventional resampling.[4] Beyond being smaller, circuits produced through cluster resampling yield sparser cross-token interactions. In the table below, we link to a web interface displaying the circuits generated by each method for samples representing the start of simple words. We also include an example from a sequence with sixteen tokens demonstrating why our experiment uses shorter sequences.
Example Circuits Extracted Using Each Type of Ablation Method
This cluster resampling circuit contains 20 features and has a mean KL div of 0.20 / layer.
This zero ablation circuit contains 59 features and has a mean KL div of 0.18 / layer.
Positional Information
Because GPT-2 adds learned absolute position embeddings to the residual stream, our SAE latents also encode that information. This results in a large number of latents and interactions apparently dedicated to passing positional information. While this is important for the model, we prefer to focus on non-positional information.
To minimize the occurrence of these features within our circuits, we cluster our samples using an added dimension representing a token’s position in a sequence. This allows resample ablation to recover positional information without needing to include these positional features in the circuit. This added positional dimension is most impactful at the embedding layer but diminishes in importance downstream. We do not use this positional dimension for selecting features in the final layer.
We opted for this clustering modification instead of a strict "resample only from the same token position" rule, as we found that the latter excessively reduces the number of available samples. Many of our SAE latents seem to incorporate some positional information but primarily represent non-positional features.
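One way to realize the added positional dimension is to append a scaled position coordinate to each point before the nearest-neighbor search. The scale `pos_weight` and both function names are hypothetical; one could shrink the weight for deeper layers and set it to zero at the final layer, in line with the observations above. The toy data shows how this soft penalty still admits a well-matching neighbor from an adjacent position, which a strict same-position filter would discard.

```python
import numpy as np

def with_position(latents, positions, pos_weight):
    """Append a scaled token-position coordinate to each latent vector."""
    positions = np.asarray(positions, dtype=float).reshape(-1, 1)
    return np.hstack([latents, pos_weight * positions])

def nearest(points, target, k):
    """Indices of the k rows of `points` closest to `target` (squared Euclidean)."""
    d = ((points - target) ** 2).sum(axis=1)
    return np.argsort(d, kind="stable")[:k]

if __name__ == "__main__":
    latents = np.array([[1.0], [1.0], [5.0]])   # cached values of one specified latent
    positions = [3, 4, 3]                        # token position of each cached row
    pts = with_position(latents, positions, pos_weight=1.0)
    query = with_position(np.array([[1.0]]), [3], pos_weight=1.0)[0]
    # Row 1 sits one position away but matches the latent exactly, so it outranks
    # row 2; a strict same-position rule would only ever allow rows 0 and 2.
    print(nearest(pts, query, k=2))   # [0 1]
```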
To test the effectiveness of incorporating positional information, we repeat our prior experiment and include a variation on cluster resampling that omits use of this added positional dimension. On average, the circuits extracted through use of this added dimension include about 4 fewer features. Qualitatively, we’ve seen that this effect intensifies with sequence length.
Number of features by resampling variation
Although use of positional information improves the sparsity of circuits extracted from GPT-2-based models, we do not expect it to be useful when extracting circuits from models with modern architectures. Most LLMs now use some form of rotary position embeddings (RoPE), which inject position inside attention rather than adding it to the residual stream, so absolute token position should no longer need to be encoded in SAE latent values. However, this technique may still aid in studying modern vision transformers, many of which largely rely on absolute position embeddings.
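The expectation above rests on a property of rotary embeddings: queries and keys are rotated by position-dependent angles, so the attention score between positions m and n depends only on the offset m - n. The sketch below is a minimal NumPy RoPE that pairs adjacent dimensions and uses the commonly cited base of 10000; it is illustrative and not tied to any particular model's implementation.

```python
import numpy as np

def rope(x, pos, base=10000.0):
    """Rotary position embedding: rotate each (even, odd) dimension pair of x by pos * freq."""
    d = x.shape[-1]
    freqs = base ** (-np.arange(0, d, 2) / d)   # one rotation frequency per pair
    cos, sin = np.cos(pos * freqs), np.sin(pos * freqs)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    q, k = rng.normal(size=8), rng.normal(size=8)
    # Same offset (3) at different absolute positions gives the same score.
    print(np.isclose(rope(q, 5) @ rope(k, 2), rope(q, 13) @ rope(k, 10)))  # True
```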
Feature Interpretation
Circuit visualization allows us to study the internal representations used by an LLM and identify patterns in latent activations. We found some patterns in our SAEs that you might find interesting!
High-Frequency Features Represented Using Scalars
Some SAE latents encode continuous features as scalars, and we find examples of these in our circuits. The following latents encode information about the length of a word (first) and a token’s absolute position within a sequence (second).
Layer 2, Latent 243 ≈ 1.3: word length feature
Layer 1, Latent 244 ≈ 2.9: sequence position feature
These screenshots are taken from our application, which displays tokens in sequences sorted by activation similarity. For each sequence, we highlight the closest matching token and show a heatmap below the other tokens to provide additional context.[6]
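A minimal sketch of how sequences might be ranked by activation similarity, assuming squared Euclidean distance between latent vectors (our application may use a different metric). For each sequence, the token closest to the selected token's activations is the one to highlight, and the per-token distances `d` are the kind of quantity a heatmap could display. The function name is hypothetical.

```python
import numpy as np

def rank_sequences(seq_acts, target):
    """Rank sequences by how closely their best-matching token matches `target`.

    seq_acts: list of (seq_len, n_latents) latent activations, one array per sequence
    target:   (n_latents,) latent activations of the selected token
    Returns a list of (sequence index, best token index, distance), closest first.
    """
    ranked = []
    for i, acts in enumerate(seq_acts):
        d = ((acts - target) ** 2).sum(axis=1)   # per-token distance (could drive a heatmap)
        j = int(np.argmin(d))                     # token to highlight in this sequence
        ranked.append((i, j, float(d[j])))
    ranked.sort(key=lambda r: r[2])
    return ranked
```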
Features with Different Interpretations at Different Magnitudes
Most of the latents included within our circuits seem to represent features with different interpretations at different magnitudes. Like many features others have observed, these features initially appear monosemantic under a superficial evaluation; however, closer examination reveals more nuanced behaviors.
Feature 181 in layer 4 appears related to the letters in a character’s name. At its 95th percentile value, this feature usually represents the double-line break before the character’s name. But at its 50th percentile value, this feature fires on letters within a character’s name.
Layer 4, Latent 181 (≈50th percentile): letters & spaces in names
Layer 4, Latent 181 (≈95th percentile): double line-breaks
Feature 621 in layer 3 usually fires whenever the letter “y” is used to spell a word. At its 95th percentile value, this feature represents the word “my”, but at its 50th percentile value, it instead becomes more sensitive to words beginning with a “t” or containing an “h”.
Layer 3, Latent 621 ≈ 0.6: letter “y” in word, usually following “t” / “h”
Layer 3, Latent 621 ≈ 1.2: “My” or “my”
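To inspect a latent at a particular magnitude, such as the 50th or 95th percentile values discussed above, one can pull the examples whose activation is closest to that percentile. The sketch assumes percentiles are taken over the latent's nonzero activations only; that choice and the function name are assumptions.

```python
import numpy as np

def examples_near_percentile(acts, q, n=5):
    """Return the q-th percentile of a latent's nonzero activations and the indices of
    the n tokens whose activation is closest to that value.

    acts: 1-D array with one latent's activation on every token in a dataset.
    """
    active = np.flatnonzero(acts)                    # only tokens where the latent fires
    level = np.percentile(acts[active], q)
    order = np.argsort(np.abs(acts[active] - level), kind="stable")
    return level, active[order[:n]]

if __name__ == "__main__":
    acts = np.array([0.0, 1.0, 0.0, 2.0, 3.0, 0.0, 4.0, 5.0])
    print(examples_near_percentile(acts, 50, n=1))   # 50th percentile is 3.0; closest token is 4
    print(examples_near_percentile(acts, 95, n=1))   # 95th percentile is about 4.8; closest token is 7
```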
Features Interpreted Using Group Context
Some features we identify seem to rely on group context. In isolation, these features appear uninterpretable, but their roles become clearer when analyzed alongside other features. Feature 182 in layer 2 appears polysemantic in isolation. However, its combination with feature 120 (“word contains an l”) responds unambiguously to “l” as the 2nd letter of a word.
Layer 2, Latent 120 ≈ 5.7: word contains an “l”
Layer 2, Latent 182 ≈ 0.1: word starts with “p” or has “l” as 2nd letter
Together: word has “l” as its 2nd letter
Feature 37 in layer 4 seems to predict the letter “S” as the next token in a sequence, but at a magnitude of 0.6 its predictions occasionally include false positives. In contrast, feature 116 in layer 4 is highly polysemantic, so much so that no clear pattern emerges from its activations. We do notice, however, that this feature seems to fire on the “U” at the end of certain names. When combined, their activations form a highly targeted pattern, consistently identifying the “U” in the name “ESCALUS”.
Layer 4, Latent 37 ≈ 0.6: next letter is probably an “S”
Layer 4, Latent 116 ≈ 0.6: triggered by the “U” at the end of certain names
Together: “U” in “ESCALUS”
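Group-context readings like these amount to intersecting the token sets on which each latent sits near its observed magnitude. A minimal sketch, where the relative tolerance `tol` and the function name are hypothetical:

```python
import numpy as np

def joint_matches(acts_a, acts_b, target_a, target_b, tol=0.25):
    """Token indices where two latents are both near their target magnitudes.

    acts_a, acts_b: 1-D arrays of each latent's activation on the same tokens.
    tol: hypothetical relative tolerance around each target value.
    """
    near_a = np.abs(acts_a - target_a) <= tol * abs(target_a)
    near_b = np.abs(acts_b - target_b) <= tol * abs(target_b)
    return np.flatnonzero(near_a & near_b)   # tokens where both conditions hold

if __name__ == "__main__":
    # Toy activations for two latents (e.g. layer 4, latents 37 and 116) on four tokens.
    a = np.array([0.6, 0.6, 0.0, 0.55])
    b = np.array([0.6, 0.0, 0.6, 0.7])
    print(joint_matches(a, b, 0.6, 0.6))     # only tokens 0 and 3 match both
```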
Applications
Debugging
Circuit visualization allows us to discover potential sources of errant behavior by inspecting internal representations. Take, for example, this block in the penultimate layer of a circuit predicting the “ing” in “pleading”. Collectively, the features here identify words ending in “ing”; however, the block’s list of matches contains some errant entries. Words like “unknit” and “phoenix” are included despite being present in the training dataset, suggesting that the model associates them with the same spelling pattern. This led us to predict that our LLM would misspell these words, and, as expected, it completes them with an “-ing” suffix.
Spelling Mistakes Discovered Through Finding an “ing” Representation
For a simple LLM such as the one being studied, there may exist other ways of evaluating its spelling capabilities; however, as we deploy this type of solution for larger models, we hope to identify the sources of error for more complex behaviors.
Model Analysis
In the penultimate layer of this circuit, the LLM seems to rely on a learned collection of words designating royalty to predict that the sequence “The King o” should be completed as “of”. The features in this block identify the words “King” and “Lord”. Top activations for feature 27 in layer 3 include phrases such as “Duke of”, “King of”, and “Prince of”.
Words Relating to Royalty
Latent 27 on layer 3 at a magnitude ≈ 0.855 identifies “o”s at the beginning of words, often after titles designating royalty
It may not surprise us that an LLM trained on the entirety of Shakespeare’s works would group words like “King” and “Duke”; however, as we scale up the size of models being analyzed, we hope to uncover the internal representations that give rise to more complex behaviors and source the data that produces such results.
Roadmap
A common criticism of mechanistic interpretability is that it has yet to deliver approaches that clearly outperform baseline techniques. We hope this proof-of-concept shows that we can now inspect an LLM’s internal representations with enough precision to warrant scaling such solutions. Admittedly, the model that we analyze is very small, and current limitations in SAE training techniques suggest there may be constraints on the model sizes that can be effectively studied. However, we anticipate that ongoing advancements in mechanistic interpretability will help overcome these barriers.
Feature Quality
As we scale our solution, we’ll encounter challenges preserving SAE feature sparsity and faithfulness for larger models. We likely have some headroom to scale before reaching current limits; however, expanding to models with a billion parameters while preserving similar SAE L0 and KL divergence targets will require incorporating new research.
There may also be other ways of decomposing models that allow us to represent activations more faithfully using fewer components. We’re interested in SAE training techniques that build on features extracted from upstream layers. Many circuit features appear to duplicate those found upstream. While it's visually useful to show a repeated reference, recognizing that these features correspond to identical components within a circuit could enhance interpretability (e.g. Crosscoders, Lindsey et al. 2024).
In the absence of continued research progress, we can still scale our solution by training SAEs on highly targeted datasets. If, for example, we wanted to learn more about chain-of-thought reasoning as it relates to math capabilities, we could train SAEs disproportionately on mathematical material.
Cluster Resampling Performance
Because we use such a small model, we can take a “first principles” approach to minimizing circuit sizes. The benefits we’ve discovered are worth preserving; unfortunately, cluster resampling is slow because it requires a mean squared error calculation over cached values representing the entire training dataset. Application to production-scale models will require approximating or speeding up the technique.
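The cost described above comes from comparing the query against every cached token. Below is a sketch of that exhaustive scan that processes the cache in chunks so memory stays bounded (for example when `cache` is a `np.memmap`). Chunk size and names are hypothetical, and the scan is still linear in dataset size, so it does not address the speed problem by itself.

```python
import numpy as np

def knn_chunked(cache, target, k, chunk=65536):
    """Indices of the k cached rows nearest to `target` by mean squared error,
    scanning `cache` in chunks and keeping only the running top-k."""
    best_d = np.empty(0)
    best_i = np.empty(0, dtype=np.int64)
    for start in range(0, len(cache), chunk):
        block = np.asarray(cache[start:start + chunk], dtype=float)
        d = ((block - target) ** 2).mean(axis=1)
        best_d = np.concatenate([best_d, d])
        best_i = np.concatenate([best_i, np.arange(start, start + len(block))])
        if len(best_d) > k:                       # prune back to the k best so far
            keep = np.argpartition(best_d, k - 1)[:k]
            best_d, best_i = best_d[keep], best_i[keep]
    order = np.argsort(best_d, kind="stable")     # closest first
    return best_i[order]
```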
There is some obvious low-hanging fruit. For one, efficient k-nearest neighbor search over sparse matrices is an active area of research; our application implements a simple heuristic and leaves room for improvement. If we cannot meet our performance requirements through code optimizations alone, we may need to approximate cluster resampling, perhaps taking inspiration from techniques used in attribution patching. Finally, we might use cluster resampling as a refinement step after extracting circuits with more efficient methods. We could, for example, use zero ablation to extract an initial circuit, then apply cluster resampling to halve its size and excise redundancy.
Scaling to 1 Billion
We eventually hope to apply our solution to LLMs with useful capabilities. The smallest of DeepSeek’s R1 models contains about 1.5 billion parameters, while Apple’s on-device OpenELM models start at 270 million. If we, through a series of incremental improvements, produce a solution that works for a model with around that many parameters, we’d have a “debugger” that could help tie LLM training to the familiar landscape of conventional software development. While our immediate goals include reattempting our analysis on models using similar architectures and larger configurations, our aspirations reach beyond these intermediate steps.
This project was part of the MARS program and supported by the Cambridge AI Safety Hub, which connects members of the AI safety community and sponsors work that addresses risks from use of advanced AI systems.
[1] I.e., we don’t need “error nodes” (Marks et al. 2024) for a single layer, but we do reset the activations for the next layer to avoid compounding SAE errors.
[2] We calculate weights between blocks using a modified version of edge extraction that ablates groups of upstream latents and measures the MSE on groups of downstream latents.
[4] We speculate that conventional resample ablation performs worse than zero ablation because it results in an above-average number of active latents (as disproportionately many inactive latents are resampled).
[5] With conventional resampling, we were unable to achieve the target KL divergence of 0.25. The examples we have for this ablation method consist of circuits with significantly higher KL divergence.
[6] Our application can actually sort sequences using a variety of heuristics, which include use of clustering and maximally-activating dataset examples; however, all the screenshots shown here represent similarly-activating dataset examples.
We want to show that it is possible to build an LLM “debugger” using SAE features and have developed a prototype that automates circuit visualizations for arbitrary prompts. With a few improvements to existing techniques (notably, “cluster resampling”, which is a form of activation patching), we are able to produce visually interpretable results. Here’s an annotated circuit processing the letters in the word “thinks”:
Our main contributions with this work are:
Methods
Model & SAE Training
We use a small character-based LLM trained on the Tiny Shakespeare dataset. Our model follows the GPT-2 architecture, with 4 layers and an embedding dimension of 64. We train end-to-end JumpReLU SAEs after every residual stream layer, and on the embeddings. For each layer, we’re able to achieve a downstream cross-entropy loss increase of under 0.05 at a sparsity of L0 ≈ 10. Our code is available here.
Circuit Extraction
Overview
With the exception of our variation on resample ablation, we use conventional techniques to identify the SAE latents that need to be included in our circuit. For each layer, we start with a target KL divergence threshold, then iteratively ablate latents as long as the KL divergence stays below our threshold. Once all SAE latents have been selected, we again use resampling ablation to compute the weights of the edges in our graph. Our primary contributions relate to “cluster resampling”, which, for our model, is what enables the extraction of circuits from arbitrary prompts and omission of redundant information.
Circuit Nodes
For any given prompt, we identify the nodes in our circuit using the following steps:
Within the SAE latents of each layer, we identify some individual latents that, when ablated, actually bring the model outputs closer to the original output compared to when all latents are present. This effect occurs using both zero ablation and cluster resampling and seems stronger than expected by chance. We plan on investigating this phenomenon in future work. The effect is reminiscent of “negative heads” (Conmy et al. 2023, McDoughall et al. 2023), which have been observed for logit difference-based circuits, but is unexpected for KL divergence-based circuits.
Circuit Edges
Once the nodes of a circuit are identified, we assign weights to the edges within our circuit using the following steps:
Note that the lines in our circuit visualizations connect groups of features to upstream “blocks”. Internally, we use edge weights between pairs of features and between pairs of blocks.[2] The edge weights between pairs of features are used to stylize the features within a block whenever a downstream or upstream feature is selected. The edge weights between pairs of blocks are used to inform the default weighting of the lines in our visualization whenever no selection is made.
Cluster Resampling
We introduce a new resampling (ablation) method that we call “cluster resampling”. We address an issue present in existing zero- / mean- / resample-ablation techniques: These techniques often preserve redundant latents. We find that latents are often correlated, and that a small number of latents can be used to infer the remaining important latents. Reducing circuits to this small number of latents creates more interpretable visualizations without, in our cluster visualization, omitting relevant information.
Our method evaluates whether, given a (small) set of latents and their magnitudes (at a specific token position and layer), can we infer the remaining latents with sufficient accuracy such that patched activations reconstructed using the specified and inferred latents reproduces the model output (as measured using KL divergence).
Resampling Using In-Distribution Values
Two viewpoints help us understand why cluster resampling is effective:
Comparison of Ablation Methods
As a result of using cluster resampling, the circuits we extract include fewer features than those extracted using conventional ablation methods. In the following experiment, we compare the number of features in circuits extracted using a zero ablation, resample ablation, and cluster resample ablation. We use random prompts from the validation dataset and generate circuits with a target KL divergence threshold of 0.25. Because our model activations embed a large amount of positional information (reviewed in the next section), conventional methods tend to produce poor circuits on lengthier sequences. To not unfairly advantage our method we limit this experiment to short sequences with exactly three tokens.
Number of Features
KL Divergence
Cluster resampling produces the best results, followed by zero ablation, then distantly by conventional resampling.[4] Beyond being smaller, circuits produced through cluster resampling yield sparser cross-token interactions. In the table below, we link to a web interface displaying the circuits generated by each method for samples representing the start of simple words. We also include an example from a sequence with sixteen tokens demonstrating why our experiment uses shorter sequences.
Example Circuits Extracted Using Each Type of Ablation Method
Cluster Resampling vs. Zero Ablation
Positional Information
Because GPT-2 uses positional embeddings in the residual stream, our SAE latents also encode that information. This results in a large number of latents and interactions apparently dedicated to passing of positional information. While this is important for the model, we prefer to focus on non-positional information.
To minimize the occurrence of these features within our circuits, we cluster our samples using an added dimension representing a token’s position in a sequence. This allows resample ablation to recover positional information without needing to include these positional features in the circuit. This added positional dimension is most impactful at the embedding layer but diminishes in importance downstream. We do not use this positional dimension for selecting features in the final layer.
We opted for this clustering modification instead of a strict "resample only from the same token position" rule, as we found that the latter excessively reduces the number of available samples. Many of our SAE latents seem to incorporate some positional information but primarily represent non-positional features.
To test the effectiveness of incorporating positional information, we repeat our prior experiment and include a variation on cluster resampling that omits use of this added positional dimension. On average, the circuits extracted through use of this added dimension include about 4 fewer features. Qualitatively, we’ve seen that this effect intensifies with sequence length.
Although use of positional information improves the sparsity of circuits extracted from GPT-2-based models, we do not expect it to be useful when extracting circuits from models using modern architectures. Most LLMs now use some form of rotational position embeddings, which means that the absolute token position should no longer be encoded using SAE latent values. However, this technique may still aid in studying modern vision transformers, many of which largely rely on absolute position embeddings.
Feature Interpretation
Circuit visualization allows us to study the internal representations used by an LLM and identify patterns in latent activations. We found some patterns in our SAEs that you might find interesting!
High-Frequency Features Represented Using Scalars
Some SAE latents encode continuous features as scalars, and we find examples of these in our circuits. The following latents encode information about the length of a word (first) and a token’s absolute position within a sequence (second).
Word length feature
Sequence position feature
These screenshots are taken from our application, which displays tokens in sequences sorted by activation similarity. For each sequence, we highlight the closest matching token and show a heatmap below the other tokens to provide additional context.[6]
Features with Different Interpretations at Different Magnitudes
Most of the latents included within our circuits seem to represent features with different interpretations at different magnitudes. Like many features others have observed, these features initially appear monosemantic under a superficial evaluation; however, closer examination reveals more nuanced behaviors.
Feature 181 in layer 4 appears related to the letters in a character’s name. At its 95th percentile value, this feature usually represents the double-line break before the character’s name. But at its 50th percentile value, this feature fires on letters within a character’s name.
Letters & spaces in names
Double line-breaks
Feature 621 in layer 3 usually appears whenever the letter “y” is used to spell a word. At its 95th percentile value, this feature represents the word “my”, but at its 50th percentile value, the feature instead becomes more sensitive to words beginning with a “t” or containing an “h”.
Letter “y” in word, usually following “t” / “h”
“My” or “my”
Features Interpreted Using Group Context
Some features we identify seem to rely on group context. In isolation, these features appear uninterpretable, but their roles become more clear when analyzed alongside other features. Feature 182 in layer 2 appears polysemantic in isolation. However, its combination with feature 120 (“word contains an l”) unambiguously responds to “l” as the 2nd letter of a word.
Word contains an “l”
Word starts with “p” or has “l” as 2nd letter
Word has “l” as its 2nd letter
Feature 37 in layer 4 seems to predict the letter “S” as the next token in a sequence, but, at feature magnitude of 0.6, its prediction occasionally includes false positives. In contrast, feature 116 in layer 4 is highly polysemantic – so much so that no clear pattern emerges from in its activations. We do notice, however, that this feature seems to fire on the “U” at the end of certain names. When combined, their activations form a highly targeted pattern, consistently identifying “U” in the name “ESCALUS”.
Next letter is probably an “S”
Triggered by the “U” at the end of certain names.
“U” in “ESCALUS”
Applications
Debugging
Circuit visualization allows us to discover potential sources of errant behavior through inspecting internal representations. Take for example an inspection of this block in the penultimate layer of a circuit predicting the “ing” in “pleading”. Collectively, the features here identify words ending “ing”, however, there exist some errant entries in its list of matches. Words like “unknit” and “phoenix” are included despite being present in the training dataset, suggesting that the model associates them with the same spelling pattern. This led us to predict that our LLM would misspell these words and, as expected, it does complete these words with an “-ing” suffix.
Spelling Mistakes Discovered Through Finding an “ing” Representation
For a simple LLM such as the one being studied, there may exist other ways of evaluating its spelling capabilities; however, as we deploy this type of solution for larger models, we hope to identify the sources of error for more complex behaviors.
Model Analysis
In the penultimate layer of this circuit, the LLM seems to lean upon a learned collection of words designating royalty to predict that the sequence “The King o” should end with “of”. The collection of features in this block identify the words “King” and “Lord”. Top activations for feature 27 in layer 3 include phrases such as “Duke of”, “King of”, and “Prince of”.
Words Relating to Royalty
It may not surprise us that an LLM trained on the entirety of Shakespeare’s works would group words like “King” and “Duke”; however, as we scale up the size of models being analyzed, we hope to uncover the internal representations that give rise to more complex behaviors and source the data that produces such results.
Roadmap
A common criticism of mechanistic interpretability is that it has yet to deliver approaches that clearly outperform baseline techniques. We hope this proof of concept shows that we can now inspect an LLM’s internal representations with enough precision to warrant scaling such solutions. Admittedly, the model that we analyze is very small, and current limitations in SAE training techniques suggest there may be constraints on the model sizes that can be effectively studied. However, we anticipate that ongoing advancements in mechanistic interpretability will help overcome these barriers.
Feature Quality
As we scale our solution, we’ll encounter challenges preserving SAE feature sparsity and faithfulness for larger models. We likely have some headroom to scale before reaching current limits; however, expanding to models with a billion parameters while preserving similar SAE L0 and KL divergence targets will require incorporating new research.
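For reference, the two targets mentioned here can be computed as follows. This is a minimal numpy sketch assuming SAE latents and logits are already available as arrays; the KL direction (clean distribution first) is a common convention we assume here, since the text does not pin it down.

```python
import numpy as np

def l0(latents: np.ndarray) -> float:
    """Average number of active (nonzero) SAE latents per token; rows are tokens."""
    return float((latents != 0).sum(axis=1).mean())

def log_softmax(logits: np.ndarray) -> np.ndarray:
    shifted = logits - logits.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def mean_kl(clean_logits: np.ndarray, patched_logits: np.ndarray) -> float:
    """Mean over positions of KL(P_clean || P_patched) for next-token distributions."""
    log_p = log_softmax(clean_logits)
    log_q = log_softmax(patched_logits)
    return float((np.exp(log_p) * (log_p - log_q)).sum(axis=-1).mean())

latents = np.array([[0.0, 1.2, 0.3], [0.0, 0.0, 2.0]])
print(l0(latents))  # 1.5
```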
There may also be other ways of decomposing models that allow us to represent activations more faithfully using fewer components. We’re interested in SAE training techniques that build on features extracted from upstream layers (e.g., crosscoders; Lindsey et al. 2024), since many circuit features appear to duplicate those found upstream. While it’s visually useful to show a repeated reference, recognizing that these features correspond to the same component within a circuit could enhance interpretability.
In the absence of continued research progress, we can still scale our solution by training SAEs on highly targeted datasets. If, for example, we wanted to learn more about chain-of-thought reasoning as it relates to math capabilities, we could train SAEs disproportionately on mathematical material.
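One simple way to skew SAE training toward a domain is to oversample it when building batches. The sketch below assumes the training data are already split into a general corpus and a targeted (e.g., math) corpus; `mixed_batch`, the placeholder documents, and the 50% mixing ratio are hypothetical, illustrative choices.

```python
import random

def mixed_batch(general, targeted, batch_size, targeted_frac, rng):
    """Sample a training batch in which `targeted_frac` of examples come from `targeted`."""
    n_targeted = round(batch_size * targeted_frac)
    batch = rng.choices(targeted, k=n_targeted) + rng.choices(general, k=batch_size - n_targeted)
    rng.shuffle(batch)  # interleave the two sources
    return batch

rng = random.Random(0)
general = [f"general-{i}" for i in range(1000)]  # stand-in for the full training corpus
math_docs = [f"math-{i}" for i in range(50)]     # stand-in for a small math corpus
batch = mixed_batch(general, math_docs, batch_size=64, targeted_frac=0.5, rng=rng)
print(sum(x.startswith("math-") for x in batch))  # 32
```

Here the math corpus is 5% of the available documents but half of every batch, which is the sense of “disproportionately” intended above.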
Cluster Resampling Performance
Because we use such a small model, we can take a “first principles” approach to minimizing circuit sizes. The benefits we’ve found (smaller circuits with less redundancy) are worth preserving; unfortunately, cluster resampling is slow because it requires a mean squared error calculation over cached values covering the entire training dataset. Applying it to production-scale models will require approximating or speeding up the technique.
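The sketch below shows one way the resampling step could look, using the neighborhood sizes from our experiments (256 nearest neighbors, 64 samples). It is a simplified illustration rather than our implementation: it assumes a dense cache of SAE latent activations for every training position at one layer, measures MSE on the retained circuit latents only, and fills in every omitted latent from a sampled neighbor. The function name and array layout are hypothetical.

```python
import numpy as np

def cluster_resample(cache, acts, keep, n_neighbors=256, n_samples=64, rng=None):
    """
    cache: (N, L) SAE latent activations cached for N training-set positions.
    acts:  (L,) latent activations at the prompt position being analysed.
    keep:  indices of latents retained in the circuit; all others are resampled.
    Returns an (n_samples, L) array of patched latent vectors.
    """
    rng = np.random.default_rng() if rng is None else rng
    keep = np.asarray(keep, dtype=int)
    # Distance to every cached position, measured only on the retained latents.
    if keep.size:
        mse = ((cache[:, keep] - acts[keep]) ** 2).mean(axis=1)
    else:
        mse = np.zeros(len(cache))
    # The cluster: the k cached positions closest to the prompt.
    k = min(n_neighbors, len(cache))
    cluster = np.argpartition(mse, k - 1)[:k]
    # Each sample borrows its omitted-latent values from one randomly chosen neighbor.
    picks = rng.choice(cluster, size=n_samples, replace=True)
    patched = cache[picks].copy()
    patched[:, keep] = acts[keep]  # retained latents keep their prompt values
    return patched
```

Each patched vector would then be decoded by the SAE and run through the downstream layers, with the KL divergence against the clean output averaged over the samples. The first step touches every cached row for every query, which is where the cost described above comes from.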
There is some obvious low-hanging fruit. For one, efficient k-nearest-neighbor search over sparse matrices is an active area of research, and our application implements only a simple heuristic that leaves room for improvement. If code optimizations alone cannot meet our performance requirements, we may need to approximate cluster resampling, perhaps taking inspiration from techniques used in attribution patching. Finally, we might use cluster resampling as a refinement step after extracting circuits with more efficient methods. We could, for example, use zero ablation to extract an initial circuit, then apply cluster resampling to halve its size and remove redundancy.
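As an example of the kind of low-hanging fruit we mean, the sketch below shows a standard trick for exact nearest-neighbor search over a sparse cache; it is not the heuristic our application currently uses. Expanding `||x - q||^2 = ||x||^2 - 2 x·q + ||q||^2` means the squared row norms can be computed once, leaving one sparse matrix-vector product per query, and `np.argpartition` avoids fully sorting the distances. It assumes the cache is a SciPy CSR matrix with one row per training position and a dense query over the same latents; here distances use all latents, whereas cluster resampling may restrict them to a subset of columns.

```python
import numpy as np
import scipy.sparse as sp

def row_sq_norms(cache: sp.csr_matrix) -> np.ndarray:
    """Squared L2 norm of every row; compute once and reuse across queries."""
    return np.asarray(cache.multiply(cache).sum(axis=1)).ravel()

def knn_mse(cache: sp.csr_matrix, query: np.ndarray, k: int, sq_norms: np.ndarray) -> np.ndarray:
    """Indices of the k rows of `cache` with the smallest MSE to `query`, nearest first."""
    # ||x - q||^2 = ||x||^2 - 2 x.q + ||q||^2: one sparse mat-vec per query.
    sq_dist = sq_norms - 2.0 * (cache @ query) + query @ query
    mse = sq_dist / cache.shape[1]
    k = min(k, cache.shape[0])
    idx = np.argpartition(mse, k - 1)[:k]  # the k smallest, in arbitrary order
    return idx[np.argsort(mse[idx])]       # order them by distance

# Random stand-in for a cache of sparse SAE activations.
cache = sp.random(1000, 64, density=0.2, format="csr", random_state=0)
sq_norms = row_sq_norms(cache)
query = cache[42].toarray().ravel()
print(knn_mse(cache, query, k=5, sq_norms=sq_norms))
```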
Scaling to 1 Billion
We eventually hope to apply our solution to LLMs with useful capabilities. The smallest of DeepSeek’s R1 models contains about 1.5 billion parameters, while Apple’s on-device OpenELM models start at 270 million. If a series of incremental improvements produces a solution that works for models of that size, we’d have a “debugger” that could help tie LLM training to the familiar landscape of conventional software development. While our immediate goals include repeating our analysis on models with similar architectures and larger configurations, our aspirations reach beyond these intermediate steps.
This project was part of the MARS program and supported by the Cambridge AI Safety Hub, which connects members of the AI safety community and sponsors work that addresses risks from use of advanced AI systems.
I.e. we don’t need “error nodes” (Marks et al. 2024) for a single layer, but we do reset the activations for the next layer to avoid compounding SAE errors.
We calculate weights between blocks using a modified version of edge extraction that ablates groups of upstream latents and measures the MSE on groups of downstream latents.
For our experiments, we cluster using 256 nearest neighbors, then estimate KL divergence using 64 samples drawn from this cluster.
We speculate that conventional resample ablation performs more poorly than zero ablation because it results in an above-average number of active latents (as disproportionately many inactive latents are resampled).
With conventional resampling, we were unable to achieve the target KL divergence of 0.25. The examples we have for this ablation method consist of circuits with significantly higher KL divergence.
Our application can actually sort sequences using a variety of heuristics, which include use of clustering and maximally-activating dataset examples; however, all the screenshots shown here represent similarly-activating dataset examples.