Checking out Chris Olah's group's stuff from the last few years is probably a good place to start.
Some links:
https://transformer-circuits.pub/2022/toy_model/index.html
https://transformer-circuits.pub/2023/privileged-basis/index.html
https://transformer-circuits.pub/2023/monosemantic-features/index.html
While I think about interpretability a lot, it's not my day job! Let me dive down a rabbit hole, and please tell me where I am wrong.
Intro
As I see it, the first step toward interpretability would be isolating which neurons perform which role. For example, which neurons are representing patterns/features/entities in the input vs. which neurons are transforming data into higher dimensions so that it becomes more easily separable.
Of course, distinguishing neurons by role opens its own can of worms (what are the other roles? What about neurons with dual roles? etc.), and we have the usual problems: what if a neuron represents only part of a feature? What if a neuron represents some entity only in concert with other neurons?
Regardless, I would love to know about methods for isolating the neurons that are specifically recognizing and representing features vs. playing any other role. Am I missing something in the literature?
My Hypothesis - near-linearity is an identifiable quality of neurons in the representational role.
Just as the structure of a heart, or a leaf, or an engine reveals its purpose, I believe there ought to be some identifiable indication that a neuron represents some external entity rather than acting in some other role.
My initial guess is that neurons that primarily serve a representational role tend to exhibit a more linear relationship between their inputs and outputs. This fits the intuition that neurons specializing in representing entities focus on directly identifying or classifying features present in the input data; their output, I surmise, is often a near-linear combination of those features.
On the other hand, neurons that transform data are more concerned with altering the topological or dimensional properties of the input space. These neurons often introduce nonlinearities as they project data into higher-dimensional or more abstract spaces. This transformation helps subsequent layers in the network, but it makes the relationship between their inputs and outputs less straightforward and more nonlinear.
I would also guess that there are neurons whose role is simply to pass information along to later layers. This complicates things, because I would guess that this pass-along role also produces near-linear combinations of input features. Still, I suspect we can isolate these neurons by noting that they tend to focus on one input dimension and ignore the rest.
So, I hypothesize that we could start by measuring how linear a neuron's output is in relation to its inputs, and I would hope that this gives us a (messy) indication of which neurons are representing features rather than performing some other role.
My guess at how to measure which neurons recognize features.
I would consider using some combination of the following measurements to gauge the near-linearity of neurons; rough code sketches for each follow the list.
1. Linearity Score: One likely initial screening method is calculating a linearity score for each neuron using statistical techniques like linear regression. A high R² value indicates that a neuron's output is (nearly) linearly dependent on its inputs, making the neuron a candidate for a representational role within the network.
2. Feature Importance via Saliency Maps: For a more nuanced analysis, especially in convolutional networks handling image data, saliency maps could be employed. These maps visually represent which parts of the input a neuron focuses on. Neurons that focus on specific, easily interpretable features can be classified as recognition neurons and used for validation of other methods.
3. Mutual Information: This measure can be highly informative and is applicable to both classification and regression tasks. It quantitatively captures the relationship between a neuron's output and the original input or the output label, helping to distinguish between transformation and recognition roles.
4. Jacobian Matrix: For multi-dimensional inputs, the Jacobian offers a rigorous way to examine how small changes in each input dimension affect the neuron's output. This method can generalize well but might be computationally expensive for high-dimensional data.
5. Eigenvalue Spectrum: Analyzing the eigenvalue (or singular value) spectrum of the weight matrix of the layer a neuron sits in might offer clues about its role. However, this might be computationally intensive for large networks.
6. Entropy Measures: The entropy of a neuron's output can be a robust way to gauge its complexity. Higher entropy might indicate that the neuron is adding complexity to the data, suggesting a transformation role; low entropy would point toward a representational role.
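For item 1, here is a minimal sketch of what I mean by a linearity score, assuming you have already collected activations (e.g. with forward hooks). The names `layer_inputs` and `neuron_outputs` are placeholders, and the two toy neurons are synthetic rather than taken from a real model:

```python
# Hypothetical linearity score: fit a linear model from a layer's inputs to one
# neuron's post-activation output and use R^2 as the score.
import numpy as np
from sklearn.linear_model import LinearRegression

def linearity_score(layer_inputs: np.ndarray, neuron_outputs: np.ndarray) -> float:
    """R^2 of a linear fit from the layer's inputs to one neuron's output.

    layer_inputs:   (n_samples, d_in) activations feeding the layer
    neuron_outputs: (n_samples,) values of a single neuron
    """
    reg = LinearRegression().fit(layer_inputs, neuron_outputs)
    return reg.score(layer_inputs, neuron_outputs)  # R^2

# Toy demo: a "representational" neuron (near-linear readout of the inputs)
# vs. a "transforming" neuron (strongly nonlinear in its inputs).
rng = np.random.default_rng(0)
X = rng.normal(size=(2000, 16))
w = rng.normal(size=16)
linear_neuron = X @ w + 0.1 * rng.normal(size=2000)
nonlinear_neuron = np.sin(3 * X[:, 0]) * X[:, 1]

print(linearity_score(X, linear_neuron))     # close to 1
print(linearity_score(X, nonlinear_neuron))  # close to 0
```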
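For item 2, a bare-bones saliency sketch. The model, the choice of "neuron" (mean activation of one channel in the first conv layer), and all the shapes are made-up assumptions, and gradient-of-activation-with-respect-to-input is only one of several saliency variants:

```python
# Crude saliency map: gradient of a chosen neuron's response w.r.t. the input.
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, padding=1),
)

def neuron_saliency(model, image, layer_idx=0, ch=0):
    image = image.clone().requires_grad_(True)
    acts = image
    for i, layer in enumerate(model):
        acts = layer(acts)
        if i == layer_idx:
            break
    score = acts[0, ch].mean()             # scalar "neuron" response
    score.backward()
    return image.grad[0].abs().sum(dim=0)  # (H, W) saliency map

img = torch.randn(1, 3, 32, 32)
sal = neuron_saliency(model, img, layer_idx=0, ch=3)
print(sal.shape)  # torch.Size([32, 32])
```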
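For item 3, a sketch using scikit-learn's k-NN based mutual information estimator between each input dimension and a neuron's output. The synthetic "pass-along" and "mixing" neurons are illustrative assumptions, not real activations; the same call would work on activations collected from an actual model:

```python
# Mutual information between each input dimension and one neuron's output.
# MI concentrated on a single input dimension might flag the "pass-along"
# neurons mentioned above; broad, moderate MI suggests a mixing role.
import numpy as np
from sklearn.feature_selection import mutual_info_regression

rng = np.random.default_rng(0)
X = rng.normal(size=(2000, 8))
pass_along = X[:, 2] + 0.05 * rng.normal(size=2000)   # mostly copies one input
mixing = np.tanh(X @ rng.normal(size=8))              # blends all inputs

print(mutual_info_regression(X, pass_along).round(2))  # one dominant entry
print(mutual_info_regression(X, mixing).round(2))      # spread across entries
```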
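For item 4, a sketch that computes the Jacobian row of one neuron with respect to the input at several points and looks at how much that row varies; `net` and `neuron_idx` are placeholders. A row that barely changes across inputs is consistent with a near-linear (on my hypothesis, representational) neuron:

```python
# Per-sample Jacobian rows d(neuron)/d(input) for one neuron of a small MLP.
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(10, 32), nn.Tanh(), nn.Linear(32, 32), nn.Tanh())

def neuron_jacobian_rows(net, xs, neuron_idx):
    rows = []
    for x in xs:
        f = lambda inp: net(inp)[neuron_idx]  # scalar output for this neuron
        rows.append(torch.autograd.functional.jacobian(f, x))
    return torch.stack(rows)  # (n_points, d_in)

xs = torch.randn(64, 10)
rows = neuron_jacobian_rows(net, xs, neuron_idx=5)
# Variability of the Jacobian row across inputs: near zero => nearly linear.
print(rows.std(dim=0).mean().item())
```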
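For item 5, a caveat plus a sketch: an individual neuron only has a weight vector, so I read this as the singular value spectrum of the weight matrix of the layer containing the neuron. `W` here is a random stand-in for real layer weights:

```python
# Singular value spectrum of a layer's weight matrix. A spectrum dominated by
# a few large singular values suggests the layer mostly preserves a low-rank
# subspace; a flatter spectrum suggests heavier mixing.
import numpy as np

W = np.random.default_rng(0).normal(size=(64, 128))  # stand-in layer weights
singular_values = np.linalg.svd(W, compute_uv=False)
explained = np.cumsum(singular_values**2) / np.sum(singular_values**2)
print(explained[:5].round(2))  # share of the spectrum in the top directions
```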
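For item 6, a rough histogram-based entropy estimate of a neuron's activation distribution; the bin count and the two synthetic activation patterns are assumptions chosen for illustration:

```python
# Histogram-based entropy of a neuron's activations. Low entropy ~ the neuron
# mostly asserts presence/absence of something; high entropy ~ it carries a
# more spread-out, complex signal.
import numpy as np
from scipy.stats import entropy

def activation_entropy(acts: np.ndarray, bins: int = 64) -> float:
    counts, _ = np.histogram(acts, bins=bins)
    return float(entropy(counts))  # Shannon entropy of the binned distribution

rng = np.random.default_rng(0)
sparse_detector = np.maximum(rng.normal(-2.0, 1.0, size=5000), 0)  # mostly zero
broad_mixer = rng.normal(size=5000)

print(activation_entropy(sparse_detector))  # low
print(activation_entropy(broad_mixer))      # higher
```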
Any thoughts on this? Are people working on something similar? Can you point me to some papers? While I think about interpretability a lot, it's not my day job! What am I missing?