Abstract
In this experiment, I study the behaviour of a Deep Q-Network agent with an attention-based architecture driving a vehicle in a simulated traffic environment. The agent learns to cross an intersection without traffic lights, in dense traffic, while avoiding collisions with other vehicles.
In the first part of this experiment, I train the agent using the attention-based architecture. I then study the behaviour of the agent by applying interpretability techniques to the trained Q-network, and find some evidence that the network comprises three types of layers serving specific functions: sensory (embedding layers), processing (attention layer) and motor (output layers).
The purpose of this experiment is to gain a deeper understanding of the agent from an interpretability perspective, which may help in developing safer agents for real-world applications.
Introduction
With the increasing use and deployment of autonomous vehicles on roads, it is important that the behaviour of these autonomous agents is thoroughly tested, understood and analysed from a safety perspective.
To measure safety, one traditional approach is to run a large number of experiments in a simulated environment and collect statistics on the number of failure cases. While this is certainly useful and gives an overall picture of the agent's failure modes, it says little about the specifics of those failures.
One may prefer to delve deeper, study why a particular agent failed, and understand whether that behaviour was a consequence of the agent's own actions or of the conditions of the environment. This calls for applying interpretability techniques to the agent's behaviour to derive more specific conclusions about which features the agent senses from the environment and which decisions it takes.
For this experiment, I study the behaviour of a trained agent by applying interpretability techniques to the policy network of the model, and share the observations and conclusions derived from the experiment.
The agent under study is only trained and deployed in a simulated environment (with considerable simplifications), which is far from a real-world setting and its complexities. While this does not represent the behaviour of the agent in the real world, I still think a study like this can be worthwhile, providing insights into what kind of decision-making process the agent learns and how that knowledge could be used to make agents safer.
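Now, to give some context on the problem, I first describe the environment. Then I will share my observations and insights.
Environment
The environment used in this experiment is Intersection-env, a customised Gymnasium-style environment that follows the agent-environment loop.
Fig 1: Agent-environment loop in reinforcement learning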
The environment contains a total of N=15 vehicles at any given point in time. The agent in question controls the green vehicle, while the blue vehicles are simulated by a traffic flow model, in this case the Intelligent Driver Model (IDM). The IDM is less nuanced and lacks the complex behaviour of the agent in question. The blue vehicles are initially spawned at random points.
Below is an animation of a trained agent crossing the intersection.
Fig 2: A recorded video from the evaluation stage showing the agent successfully crossing the intersection without collision
State
The joint observation of the road traffic, made up of the agent's vehicle $s_0$ and $N$ other vehicles, is described by a combined list of individual vehicle states:
$$s = (s_i)_{i \in [0, N]}$$
where
$$s_i = \begin{bmatrix} x_i & y_i & v^x_i & v^y_i & \cos\phi_i & \sin\phi_i \end{bmatrix}^T$$
The individual state variables are described as follows:
| Feature | Description |
|---|---|
| presence | Disambiguates agents at 0 offset from non-existent agents. |
| x | World offset of the agent vehicle, or offset to the agent vehicle, on the x axis. |
| y | World offset of the agent vehicle, or offset to the agent vehicle, on the y axis. |
| vx | Velocity of the vehicle on the x axis. |
| vy | Velocity of the vehicle on the y axis. |
| heading | Heading of the vehicle in radians. |
| cos_h | Trigonometric heading of the vehicle (cosine). |
| sin_h | Trigonometric heading of the vehicle (sine). |
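To make the role of the presence feature concrete, here is a minimal NumPy sketch of how such an observation array might be assembled. The dict input format, the helper name `build_observation` and the choice of ego-relative positions for other vehicles are assumptions for illustration, not the environment's actual code. The 7-column feature order (presence, x, y, vx, vy, cos_h, sin_h) is also an assumption, chosen to match the 7x64 input of the embedding layers.

```python
import numpy as np

# Assumed feature order (7 features, matching the 7x64 input of the embedding layers).
FEATURES = ["presence", "x", "y", "vx", "vy", "cos_h", "sin_h"]

def build_observation(ego, others, vehicles_count=15):
    """Return a (vehicles_count, 7) array: row 0 = ego, rows 1.. = other vehicles.

    `ego` and `others` use a hypothetical dict format with keys x, y, vx, vy, heading.
    The ego row holds world coordinates; other rows hold positions relative to the ego.
    Empty rows stay all-zero, so presence = 0 marks them as non-existent vehicles.
    """
    obs = np.zeros((vehicles_count, len(FEATURES)))

    def row(v, ref_x=0.0, ref_y=0.0):
        return [1.0, v["x"] - ref_x, v["y"] - ref_y, v["vx"], v["vy"],
                np.cos(v["heading"]), np.sin(v["heading"])]

    obs[0] = row(ego)
    for i, v in enumerate(others[: vehicles_count - 1], start=1):
        obs[i] = row(v, ego["x"], ego["y"])
    return obs
```

Padding rows are all zeros. Without the presence flag, a real vehicle at exactly zero offset from the ego would therefore be indistinguishable from an empty slot. This is the ambiguity that the first row of the table refers to.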
The vehicle kinematics are described by the Kinematic Bicycle Model. More on this topic can be found here.
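For reference, below is a minimal sketch of one Euler integration step of a kinematic bicycle model. It assumes the centre of mass sits midway between the axles, and that the controls are a longitudinal acceleration and a front-wheel steering angle. The parameter values (`wheelbase=5.0`, `dt=0.1`) are illustrative and not taken from this experiment's configuration.

```python
import math

def bicycle_step(x, y, heading, speed, acceleration, steering, wheelbase=5.0, dt=0.1):
    """Advance a kinematic bicycle model by one Euler step of size dt.

    Assumes l_f = l_r = wheelbase / 2 and that `steering` is the front-wheel angle in radians.
    """
    # Slip angle at the centre of mass: beta = atan(l_r / (l_f + l_r) * tan(delta)).
    beta = math.atan(0.5 * math.tan(steering))
    # Position moves along heading + beta, using the speed at the start of the step.
    x += speed * math.cos(heading + beta) * dt
    y += speed * math.sin(heading + beta) * dt
    # Yaw rate: v / l_r * sin(beta), with l_r = wheelbase / 2.
    heading += speed * math.sin(beta) / (wheelbase / 2) * dt
    speed += acceleration * dt
    return x, y, heading, speed
```

For example, a vehicle at 10 m/s with zero steering moves 1 m along its heading in one 0.1 s step.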
Actions & Rewards:
The agent drives the vehicle by controlling its speed, chosen from a finite set of actions A = {SLOWER, NO-OP, FASTER}.
Rewards:
| Reward | Condition |
|---|---|
| 1 | Agent driving at maximum velocity |
| -5 | On collision |
| 0 | Otherwise |
Agent:
The agent in this experiment uses the DQN algorithm with the attention-based architecture first proposed in the paper Social Attention for Autonomous Decision-Making in Dense Traffic [1].
For this experiment, I focus on the agent's policy network, since that network encodes the agent's decision making.
Here's what the network looks like:
| Layer Name | Dimensions |
|---|---|
| Ego & Others embedding layer - 0 | 7x64 |
| Ego & Others embedding layer - 1 | 64x64 |
| Attention Layer Query | 64x64 |
| Attention Layer Key | 64x64 |
| Attention Layer Value | 64x64 |
| Output Layer - 0 | 64x64 |
| Output Layer - 1 | 64x64 |
| Output Layer - predict | 64x3 |
Embedding layers:
Fig 3: Schematic of Embedding layers of agent's Q-net
The overall architecture is composed of several identical linear encoders, a stack of attention heads, and a linear decoder.
There are two embedding layers:
1. Ego embedding: dedicated to tracking the features of the vehicle driven by the agent itself.
2. Others embedding: dedicated to tracking the features of the vehicles driven by other agents.
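The split between the two encoders can be sketched in NumPy with random weights, using the layer shapes from the table above. The ReLU after each layer, the bias terms and the helper names (`make_encoder`, `encode`, `embed`) are assumptions for illustration, not the agent's actual code.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def make_encoder(rng, in_dim=7, hidden=64):
    # Two linear layers with the shapes from the table: 7x64 and 64x64 (stored as in x out).
    return {"W0": rng.normal(0, 0.1, (in_dim, hidden)), "b0": np.zeros(hidden),
            "W1": rng.normal(0, 0.1, (hidden, hidden)), "b1": np.zeros(hidden)}

def encode(params, rows):
    # Assumed MLP form: ReLU after each linear layer.
    h = relu(rows @ params["W0"] + params["b0"])
    return relu(h @ params["W1"] + params["b1"])

def embed(obs, ego_enc, others_enc):
    """Row 0 (ego) goes through the ego encoder, rows 1.. through the shared others encoder."""
    ego = encode(ego_enc, obs[:1])        # (1, 64)
    others = encode(others_enc, obs[1:])  # (N, 64), same weights for every other vehicle
    return np.concatenate([ego, others], axis=0)  # (N + 1, 64)
```

Every other vehicle goes through the same weights. As a result, two vehicles with identical feature rows always get identical embeddings, which is what "identical encoders" means in practice.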
Attention layers:
Fig 4: Schematic of attention layer of agent's Q-net
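Essentially, a single query $Q = [q_0]$, a set of keys $K = [k_0, \dots, k_N]$ and a set of values $V = [v_0, \dots, v_N]$ are emitted by linear projections of the encoded vehicle states. Here, index 0 is the agent's vehicle and indices 1 to N are the other vehicles. Each head then computes:
$$\text{Output} = \sigma\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
where $\sigma$ is the softmax function and $d_k$ is the dimension of the keys.
The outputs from all heads are finally combined with a linear layer, and the resulting tensor is then added to the ego encoding, as in residual networks.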
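The single-query attention described above can be sketched in NumPy as follows. Several details are assumptions for illustration: the row-vector weight convention, the separate combining matrix `W_O`, the optional mask for absent vehicles, and the plain residual addition. The reference implementation may scale or normalise the residual differently.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def ego_attention(emb, W_Q, W_K, W_V, W_O, n_heads=4, mask=None):
    """Single-query multi-head attention: only the ego row (index 0) emits a query.

    emb: (N + 1, d) embeddings, row 0 = ego. W_*: (d, d) matrices (row-vector convention).
    mask: optional boolean (N + 1,) array, True for absent vehicles; row 0 must stay False.
    Returns the (d,) output after combining heads and adding the ego embedding,
    plus the (n_heads, N + 1) attention weights.
    """
    n, d = emb.shape
    d_k = d // n_heads
    q = (emb[0] @ W_Q).reshape(n_heads, d_k)   # one query per head, from the ego only
    k = (emb @ W_K).reshape(n, n_heads, d_k)    # one key per vehicle per head
    v = (emb @ W_V).reshape(n, n_heads, d_k)    # one value per vehicle per head
    scores = np.einsum("hd,nhd->hn", q, k) / np.sqrt(d_k)
    if mask is not None:
        scores = np.where(mask[None, :], -np.inf, scores)
    attn = softmax(scores, axis=-1)             # (n_heads, N + 1)
    heads = np.einsum("hn,nhd->hd", attn, v)    # attention-weighted sum of values per head
    out = heads.reshape(d) @ W_O                # concatenate heads, combine linearly
    return out + emb[0], attn                   # residual connection to the ego embedding
```

Masked vehicles receive exactly zero attention, and each head's weights over the remaining vehicles sum to one.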
The authors of the paper report that an agent with the proposed attention architecture achieves performance gains in autonomous decision making in dense traffic. Their study compared the agent's performance against common architectures such as FCN and CNN, and the social interaction patterns with other vehicles were visualised and studied qualitatively.
Methods
In this experiment, I first replicate the behaviour of the agent as described by the authors of the paper. Then I study the agent's behaviour by borrowing well-recognised interpretability techniques from the literature, such as Understanding RL Vision [2] and A Mathematical Framework for Transformer Circuits [3]. After collecting the observations, I derive some key insights on the agent's behaviour and mention some interesting directions for future work.
The techniques applied in this experiment are as follows:
Replication of agent behaviour
Analysis on Embedding matrices
QK & OV circuit analysis of Transformer block
Study on activation patterns of an episode run
Analysis on feature importance in reference to output layer
Source code for training the agent, details on the choice of hyper-parameters and model architecture parameters, and extended results are in the Appendix section.
Replication of agent behaviour
The first step is to replicate the paper's results by training and evaluating the agent in a dense traffic setting with a single intersection and no traffic lights. In my runs, the agent learns to cross the intersection while avoiding collisions with other vehicles in most scenarios.
The following animation shows the trained agent navigating the intersection, along with its attention patterns at the time step when the agent decides to slow down after noticing another vehicle in its way.
Notes:
The results above confirm that the agent learns to navigate the crossing while avoiding collisions in most scenarios, by paying attention to the other vehicles at the crossing. The attention paid to other vehicles at each time step of the episode is shown by thick coloured lines from the green vehicle to the blue vehicles.
Analysis on Embedding matrices:
First, I look for insights in the weights of the embedding matrices. Each embedding consists of 2 hidden layers: $W_0$, the first layer, with dimension 7x64, and $W_1$, the second layer, with dimension 64x64. The combined embedding matrix is computed as:
$$W_E = W_0 W_1$$
so $W_E$ has dimension 7x64.
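A sketch of this computation, together with an untrained-vs-trained comparison, might look like the code below. The per-feature summary, which sums the squared difference over the 64 hidden units, is a hypothetical way of ranking features and not necessarily what the heat maps show. Note also that $W_0 W_1$ treats the two layers as a single linear map. If an activation such as ReLU sits between them, $W_E$ is only a linear approximation of the actual embedding.

```python
import numpy as np

def effective_embedding(W0, W1):
    """W_E = W0 @ W1 with shapes (7, 64) @ (64, 64) -> (7, 64).

    PyTorch nn.Linear stores weights as (out_features, in_features),
    so W0 here corresponds to the first layer's weight.T.
    """
    return W0 @ W1

def feature_change(WE_untrained, WE_trained):
    # Element-wise squared difference, then summed over the 64 hidden units -> one score per feature.
    diff = (WE_trained - WE_untrained) ** 2
    return diff, diff.sum(axis=1)
```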
Inspecting the Ego and Others embedding matrices individually reveals the following:
Fig 5: Comparison of Ego embedding matrices between untrained and trained agents. The difference highlights that the y coordinate feature is assigned large weights.
Fig 6: Comparison of Others embedding matrices between untrained and trained agents. The difference highlights that the x, y coordinates, vy and sin_h features are assigned large weights.
Notes:
Ego embedding matrix: higher weights are assigned to features y and cos_h.
Others embedding matrix: higher weights are assigned to features x, y, vy and sin_h.
The heat maps show that the other features are assigned relatively smaller weights.
This suggests that the sensory function of the model picks up some meaningful signals from the environment.
Next I analyse how these features interact with the attention layers of the network.
QK & OV circuit analysis of Transformer block
Research by the Anthropic team outlines a mathematical framework for understanding attention layers [3].
Quoting from the paper:
Attention heads can be understood as having two largely independent computations: a QK (“query-key”) circuit which computes the attention pattern, and an OV (“output-value”) circuit which computes how each token affects the output if attended to.
Here, I look for emerging QK and OV circuit patterns in the attention layer. To detect learned structure, I compare the circuits of the untrained and trained networks and compute their element-wise squared difference.
As seen from the code, the agent in this experiment has only 1 attention layer with 4 heads: the query, key and value vectors are first sliced into 4 parts, each head is computed, and the results are then concatenated. Hence, for simplicity, I computed the QK and OV circuits for the combined matrices instead of slicing them into 4 parts.
$$QK = W_E^T W_{QK} W_E$$
where $W_E \in \{\text{Ego}, \text{Others}\}$
$$OV = W_U W_{OV} W_E$$
where $W_E \in \{\text{Ego}, \text{Others}\}$
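One way to compute these circuits with the 7x64 embedding matrices defined above is sketched below. Because $W_E$ is stored as features × hidden (row-vector convention), the formulas appear transposed relative to the notation above. The QK circuit becomes $W_E^{\text{ego}} W_Q W_K^T (W_E^{\text{others}})^T$ (7x7), and the OV circuit becomes $W_E^{\text{others}} W_V W_O W_U$ (7x3, features × actions). Two names are assumptions: `W_O` stands for the head-combining layer, and `W_U` stands in for the output layers treated as one linear map, ignoring their nonlinearities.

The sketch also checks a property that supports the "combined matrix" simplification. Because the heads use disjoint column slices, $W_Q W_K^T = \sum_h W_Q^h (W_K^h)^T$. The combined QK circuit therefore equals the sum of the 4 per-head circuits, although it does not capture each head's separate softmax.

```python
import numpy as np

def qk_circuit(WE_ego, WE_others, W_Q, W_K):
    # (7, 64) @ (64, 64) @ (64, 64) @ (64, 7) -> (7, 7): ego query features x other-vehicle key features.
    return WE_ego @ W_Q @ W_K.T @ WE_others.T

def ov_circuit(WE_others, W_V, W_O, W_U):
    # (7, 64) @ (64, 64) @ (64, 64) @ (64, 3) -> (7, 3): attended feature x action.
    return WE_others @ W_V @ W_O @ W_U

def per_head_qk(WE_ego, WE_others, W_Q, W_K, n_heads=4):
    # The same circuit computed head by head on column slices of W_Q and W_K.
    d_k = W_Q.shape[1] // n_heads
    return [WE_ego @ W_Q[:, h * d_k:(h + 1) * d_k] @ W_K[:, h * d_k:(h + 1) * d_k].T @ WE_others.T
            for h in range(n_heads)]
```

Computing these circuits for both the untrained and the trained weights, then taking their element-wise squared difference, yields comparison maps like the ones below.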
QK and OV circuits:
Fig 7: QK circuit computed using the Ego embedding layer. Left: no learned attention patterns. Right: the y-coordinate and velocities are heavily attended to, in the 3rd row from the bottom.
Fig 8: QK circuit computed using the Others embedding layer. Left: no learned attention patterns. Right: features (x, y, vy & sin_h) are heavily attended to, in the 3rd row from the bottom.
Fig 9: OV circuits between different embedding layers. Left: no learned patterns. Right: represents which action the agent takes when that feature is attended to.
Furthermore, the attention score and output value matrices shown in Fig 9a and 9b (in the appendix section) show some interesting learned structural patterns.
The figures above show the differences between the two scenarios, one where the agent is untrained and one where it is trained. The QK and OV circuits show high activations for certain lines and areas in the heat maps, indicating a learned structure from the interplay between the agent, the other vehicles and the environment.
Notes:
Fig 7 -
The agent attends to its own y-coordinate with a strong positive correlation.
In contrast, vy is attended to with a negative correlation.
This suggests the agent uses its own velocity and heading to assess risk when making decisions.
Fig 8 -
Strong positive attention between the Ego embedding's y coordinate and the other vehicles' x, y coordinates and vy velocity, hinting at the possibility that the model computes a distance metric.
Strong negative attention between the Ego embedding's y coordinate feature and the Others embedding's presence feature.
These observations suggest that the model has learned to focus on other vehicles' movement to adjust its own strategy.
Fig 9 -
Strong activations for velocity (vx, vy) and heading (cos_h, sin_h), suggesting the model values its own speed and direction when choosing actions.
The presence and position (x, y) of other vehicles become crucial: the agent learns to consider the positions of other vehicles when deciding whether to slow down, idle, or speed up.
Study on activation patterns of an episode run:
Let's study the agent's activations per time step and intermediate attention matrices collected over the full episode run.
Here, I extract the environment frames and the attention-head-vs-vehicle activations at each time step of the evaluation, which show the following:
Fig 10: Frame rendered from the scene at the time step when the agent decides to slow down near the intersection.
Fig 11: Heat map of attention head vs vehicle number. Vertical lines of activations show that a particular vehicle is being attended to by all 4 attention heads in varying degrees.
Notes:
Vertical lines of activations show that a particular vehicle is being attended to by all 4 attention heads in varying degrees.
Vehicle_0 is the green vehicle controlled by the agent, and hence it is attended to by all 4 attention heads at all time steps.
For other vehicles, the activations increase as they get closer to the intersection.
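The note above, together with the open question of whether the model computes a distance metric, suggests a simple quantitative check. The idea is to correlate each head's attention weight with the other vehicle's distance to the ego vehicle across the episode. Distance to the ego is used here as a proxy for distance to the intersection. The sketch assumes the attention weights and observations have been collected into NumPy arrays with the shapes given in the docstring, and that other vehicles' x, y are ego-relative offsets. `attention_distance_correlation` is a hypothetical helper.

```python
import numpy as np

def attention_distance_correlation(attn, obs, present_only=True):
    """Pearson correlation between attention weight and distance to the ego vehicle.

    attn: (T, n_heads, N + 1) attention weights collected over an episode.
    obs:  (T, N + 1, 7) observations, features [presence, x, y, vx, vy, cos_h, sin_h],
          with x, y of rows 1.. assumed to be offsets relative to the ego.
    Returns one correlation per head, using only other vehicles (index >= 1) that are present.
    """
    dist = np.hypot(obs[:, 1:, 1], obs[:, 1:, 2])  # (T, N)
    keep = obs[:, 1:, 0] > 0 if present_only else np.ones_like(dist, dtype=bool)
    corrs = []
    for h in range(attn.shape[1]):
        a = attn[:, h, 1:][keep]
        corrs.append(np.corrcoef(a, dist[keep])[0, 1])
    return np.array(corrs)
```

A strongly negative correlation would be consistent with distance-based attention, but it would not prove that the network computes distance explicitly. Intervening on the x, y inputs and observing the change in attention would be a stronger test.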
Next, I study the output layers further to see what insights I can draw from them.
Feature importance
In this section, I try to understand which features the model learns to extract from the environment state. A common technique for measuring feature importance is Integrated Gradients. It attributes a model output to each input feature by accumulating gradients along a straight path from a baseline input to the actual input. The resulting attributions give an understanding of the overall importance of each feature.
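For reference, a minimal Integrated Gradients sketch is shown below. It uses a midpoint Riemann sum and takes a user-supplied gradient function. In practice, the gradient of one action's Q-value with respect to the observation would come from autograd, and libraries such as Captum provide a ready-made `IntegratedGradients`. The zero baseline and the step count are assumptions. How attributions are aggregated over episodes (for example, averaging absolute attributions over all visited states) is a separate choice not specified here.

```python
import numpy as np

def integrated_gradients(grad_fn, x, baseline=None, steps=64):
    """Integrated Gradients with a midpoint Riemann sum.

    grad_fn(z) must return dF/dz for the scalar output being explained (e.g. one action's Q-value).
    IG_i = (x_i - x'_i) * mean over alpha of dF/dz_i evaluated at z = x' + alpha * (x - x').
    """
    x = np.asarray(x, dtype=float)
    baseline = np.zeros_like(x) if baseline is None else np.asarray(baseline, dtype=float)
    alphas = (np.arange(steps) + 0.5) / steps
    grads = np.stack([grad_fn(baseline + a * (x - baseline)) for a in alphas])
    return (x - baseline) * grads.mean(axis=0)
```

As a sanity check, take $F(z) = \sum_i w_i z_i^2$ with a zero baseline. The attributions are exactly $w_i x_i^2$, and they sum to $F(x) - F(0)$. This is the completeness property that makes Integrated Gradients attributions add up to the change in output.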
For this scenario, I compare the integrated gradients of the untrained and trained networks, averaged over 30 episodes, and find the following.
Fig 14: Computed integrated gradients for untrained and trained agents. Left: for the untrained agent, the attributions are negligible. Right: for the trained agent, presence, the x, y coordinates and the angles are important.
Notes:
The Q-net's output layer makes its final decision by looking at the presence, x-y coordinates and sin_h features. All the graphs above add more evidence to the observations I gathered earlier.
In the following section, I make some speculative interpretations about the agent. I would like to validate these interpretations through more thorough experiments in the future. Claims I am not confident about are marked inline.
Interpretation:
The following is a walkthrough of the agent in action, along with interpretations of the key observations collected so far.
Step 1: Input Features
The agent observes its environment using the following features:
Ego Features
vx, vy: Ego vehicle’s velocity (speed in x and y directions).
cos_h, sin_h: Ego vehicle’s heading direction.
x, y: Ego vehicle’s position.
Others Features
presence: Indicates whether another vehicle is nearby.
vx, vy: Other vehicle’s velocity.
x, y: Other vehicle’s position.
Step 2: Attention Mechanism (QK Circuits)
The Query-Key (QK) circuits determine which features should be attended to when making a decision.
Ego Embedding (Self-awareness)
The agent queries its own vx, vy, cos_h, and sin_h to understand its motion.
It attends to its y coordinate to determine its position in the intersection. (Why does the x coordinate not have high activations? This is something I would like to investigate later.)
Others Embedding (Awareness of other vehicles)
The agent queries presence to check if another vehicle is nearby.
It attends to vx, vy, x, y of other vehicles to predict their movement.
If a vehicle is approaching, the attention on presence and vx increases. (Unverified claim: is the model computing some distance metric? Can we verify this?)
When does the agent choose "Slow"?
If the presence of other vehicles (presence feature) is high.
If relative velocity (vx, vy) suggests a collision risk.
Strong correlations in the Others embedding QK circuit show that the agent reacts to nearby vehicles. (Unverified claim, for the same reason as above.)
When does the agent choose "Idle"?
Likely in neutral situations where no immediate action is needed.
OV circuits show that cos_h and sin_h influence the decision, meaning the agent aligns with road orientation.
When does the agent choose "Fast"?
If there is clear space ahead (low presence activation in Others embedding).
Trained OV circuits show positive activations for vx and sin_h, meaning the model prefers accelerating when aligned with the road.
Step 3: Q-Value Calculation (OV Circuits)
The Output-Value (OV) circuits determine how much each feature contributes to the Q-values for each action.
Feature Contributions to Actions:
If presence is high, the Slow action gets higher weight.
If the other car's vx is high, meaning it is moving fast toward the intersection, the agent reduces the Q-value for the Fast action.
If the agent’s own vx is high but a collision is possible, the Q-value for Idle is also reduced.
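The feature-to-action reasoning above can be made concrete with a linearised sketch. Scale each observation row by its attention weight, push it through an OV circuit of shape features × actions, and read off per-feature contributions to each action's Q-value. The agent then picks the action with the highest Q-value. The sketch makes several simplifications: it ignores biases, the nonlinear output layers and the residual path, and it applies one OV circuit to every row. It also assumes the output order SLOWER, NO-OP, FASTER. All helper names are hypothetical.

```python
import numpy as np

ACTIONS = ["SLOWER", "NO-OP", "FASTER"]  # assumed index order of the Q-net's 3 outputs

def feature_contributions(ov, attn_weights, obs):
    """Linearised contribution of each (vehicle, feature) pair to each action's Q-value.

    ov: (7, 3) OV circuit (features x actions), attn_weights: (N + 1,) ego attention weights,
    obs: (N + 1, 7) observation rows.
    """
    weighted = attn_weights[:, None] * obs          # (N + 1, 7): features scaled by attention
    return weighted[:, :, None] * ov[None, :, :]    # (N + 1, 7, 3)

def greedy_action(q_values):
    # DQN typically acts greedily at evaluation time: pick the highest Q-value.
    return ACTIONS[int(np.argmax(q_values))]
```

Summing the contributions over vehicles and features recovers the linearised Q-values, so large entries point at the features driving a given action.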
Discussion
This experiment shows that a DQN agent with an attention-based mechanism can learn to cross a road intersection in dense traffic with reasonable levels of safety.
Additionally, analysis of the attention layers of the agent's Q-network provides some evidence that these layers learn high-level policies that drive the agent's decision making. Although it was possible to find some high-level policies, more work is needed to understand how different policies combine to form a concrete algorithm.
It was shown qualitatively, with some level of confidence, that the agent learns to delegate different types of functions to its embedding, attention and output layers, which serve sensory, processing and motor functions respectively.
Future work:
This experiment was limited in scope and time (up to 4 weeks). For this reason, I chose to focus on replicating the agent's behaviour and running various interpretability techniques, to narrow down a promising approach for finding the agent's exact behaviour in further research.
Following are some of the areas that can be explored in future:
Does the agent compute a distance metric from the (x, y) coordinates of the other vehicles?
Does changing the y coordinate of the intersection in the environment break the agent's decision making? Has the agent really generalised, or has it simply memorised?
List the different policies the agent learns for the same action. For example, slowing down may be driven by high activations of the presence feature or by high activations of the vx feature of other vehicles. Do these two policies correlate highly for the agent, or do they largely stay independent?
Train the model with more attention heads and layers and for more episodes, then repeat the experiments. Do we get any new insights?
Acknowledgements:
My sincere thanks to this amazing community, who have made interpretability research easily accessible to the general public. I hope my experiments bring some value to others and to this community. I look forward to delving deeper into this topic; any support and guidance is highly appreciated.
I would also like to thank BlueDot Impact for running a 12-week online course on AI Safety Fundamentals. I conducted this experiment as part of the project submission phase of this course, and I am grateful to the course facilitators and their team for running amazing sessions and providing a comprehensive list of resources on the key topics.
Appendix
Running evaluation over 10 episodes with display enabled shows mostly high scores and successful navigation through the intersection; episodes 5 and 7 end with a score of -1.0.
[INFO] Episode 0 score: 8.6
[INFO] Episode 1 score: 5.5
[INFO] Episode 2 score: 3.0
[INFO] Episode 3 score: 9.6
[INFO] Episode 4 score: 8.5
[INFO] Episode 5 score: -1.0
[INFO] Episode 6 score: 9.0
[INFO] Episode 7 score: -1.0
[INFO] Episode 8 score: 6.5
[INFO] Episode 9 score: 7.6
Learned Attention scores:
Fig 7a: QK scores comparison between untrained and trained agents. The difference is computed as the element-wise squared difference of the 2-dim tensors.
Fig 7b: OV scores comparison between untrained and trained agents. The difference is computed as the element-wise squared difference of the 2-dim tensors.
Abstract
In this experiment, I study the behaviour of a Deep Q-Network agent with attention based architecture driving a vehicle in a simulated traffic environment. The agent learns to cross an intersection without traffic lights and under a dense traffic setting while avoiding collisions with other vehicles.
As first part of this experiment, I first train the agent using the attention based architecture. Later, I study the behaviour of the agent by applying some interpretability techniques on the trained Q-network and find that there is some evidence to show that network comprises of 3 different layers serving specific functions, namely - sensory (embedding layers), processing (attention layer) & motor (output layers).
The purpose of this experiment is to gain deeper understanding of the agent from interpretability perspective which may be used for developing safer agents in real world applications.
Introduction
With increasing usage and deployment of autonomous driving vehicles on roads, it is important that the behaviour of these autonomous agents is thoroughly tested, understood and analysed from a safety perspective.
To measure safety, one of the traditional approaches involves running large number of experiments in a simulated environment and collecting statistics on the number of failures cases. While this is certainly useful and gives an overall perspective on the failures modes of the agent, however it does not say anything about the specificity of those failures.
One may prefer to delve deeper and study why a particular agent failed and understand if that behaviour was a consequence of agent's actions or a conditioning of that environment. This calls upon applying some of the interpretability techniques on the agent's behaviour and derive more specific conclusions on what features the agent senses from the environment and what decisions it takes.
For this experiment, I study the behaviour of a trained agent by applying some interpretability techniques on the policy network of the model and share my observations and conclusions derived from the experiment.
The agent under study is only trained and deployed in a simulated environment (with enough simplifications) and is far from a real world setting and it's complexities. While this does not really represent the behaviour of the agent in the real world, I still think a study like this can be worthwhile in providing some insights on what kind of decision making process is learned by the agent and how it can be used to make agents more safer.
Now, to give some context on the problem, let's understand how
Then I will share my observations and insights.
Environment
The environment used in this experiment is an Intersection-env which is a customised gymnasium type that has agent-environment loop.
The environment setting contains total
N=15
vehicles at any given point in time.The agent in question is controlling the green vehicle while the blue vehicles are simulated by traffic flow model which in this case is controlled by intelligent driver model. The intelligent driver model is less nuanced and lacks any complex behaviour in comparison to the agent in question. The blue vehicles are spawned at random points initially.
And below is the animated image of a trained agent crossing the intersection.
State
The joint observation of a road traffic with one agent denoted -
s0
and other vehicles -N
is described by a combined list of individual vehicle states:s=(si)i∈[0,N]
where si=[xiyivxivyicos(ϕ)isin(ϕ)i]T
Individual values of each state variables are described as follows:
The vehicle kinematics are described by Kinematic Bicycle Model. More on this topic can be found here
Actions & Rewards:
The agent drives the vehicle by controlling its speed chosen from a finite set of actions
A = {SLOWER, NO-OP, FASTER}
.Rewards:
Agent:
The agent used in this experiment uses a DQN algorithm with attention based architecture which was first proposed in the paper on - Social Attention for Autonomous Decision-Making in Dense Traffic [1].
For this experiment, I delve deeper on the agent's policy network since that network encodes the decision making of the agent.
Here's how the network looks like:
Embedding layers:
It is composed of several linear identical encoders, a stack of attention heads, and a linear decoder.
There are two embedding layers:
1. Ego embedding - dedicated to tracking features for vehicle driven by the agent itself.
2. Others embedding - dedicated to tracking features for vehicles driven by other agents.
Attention layers:
Essentially, a single query Q = [q0] and a set of keys K = [k0,...,kN] are emitted by doing linear projections on the state of the environment. Here, N is the number of vehicles including the agent's vehicle.
Output=σ(QKT/√dk)V
The outputs from all heads are finally combined with a linear layer, and the resulting tensor is then added to the residual networks.
Authors of the paper claim that an agent with the proposed attention architecture shows increased performance gains in autonomous decision making under a dense traffic setting. Their study involved comparing the performance of the agent against common architectures like FCN and CNN. The social interaction patterns with other vehicles were visualised and studied qualitatively.
Methods
As part of this experimentation, I first replicate the behaviour of the agent as described by the authors of the paper. Then, I proceed to study the agent's behaviour by borrowing some well recognised interpretability techniques in the literature like Understanding RL Vision[2] and A Mathematical framework for Transformer Circuits[3].
After collecting the observations, I derive some key insights on the behaviour of the agent and mention some interesting directions for work in future.
List of techniques applied in this experiment are as follows -
Analysis on feature importance in reference to output layer
Source code for training the agent and details on choice of hyper-parameters, model architecture parameters and extended results are in the Appendix section.
Replication of agent behaviour
First step is to replicate the studies of the paper by training and evaluating the agent in dense traffic setting having single intersection on the road without any traffic lights.
According to my observations, I confirm that the agent successfully learns to cross the intersection while avoiding collisions with other vehicles in most scenarios.
Following animation shows the trained agent navigating through the intersection along with it's attention patterns for the time step when the agent decides to slow down noticing another vehicle in the way.
Notes:
Above results confirm that the agent learns to navigates through the crossing avoiding collisions in most scenarios, by paying attention to the other vehicles in on the crossing. Attention to other vehicles at every time step of the episode are highlighted by thick coloured lines from green to blue vehicles.
Analysis on Embedding matrices:
First I analyse what insights I can find in the weights of embedding matrices. Since the embedding layers consist of 2 hidden layers, namely, W0 being the first layer with dimension 7x64 and W1 being the second layer with dimension 64x64.
If embedding matrix weights are computed as follows:
WE=W0W1
then, WE dimension is 7x64
Checking the embedding matrix for both layers individually reveal the following:
Difference highlights that y coordinate feature is assigned large weights
Difference highlights that x,y coordinates, vy and sinh features are assigned large weights
Notes:
y
andcosh
x
,y
,vy
andsinh
This shows us that sensory function of the model is picking up some interesting signals from the environment.
Next I analyse how these features interact with the attention layers of the network.
QK & OV circuit analysis of Transformer block
According to the research done by Anthropic team, they outline a mathematical framework of understanding attention layers.
Quoting from the paper:
Here, I study any emerging QK and OV circuit patterns in the attention layers. To study the emergence of any learned structure, I compare it with the Untrained vs Trained network and find their squared difference measures.
As seen from the code, The agent in this experiment has only 1 layer with 4 attention heads whose vectors are first slice in 4 parts, computed and then later concatenated. Hence, for simplicity, I computed QK and OV circuits for the combined matrix instead of slicing it in 4 parts.
QK=WTEWQKWE
where, WE∈{Ego,Others}
OV=WUWOVWE
where, WE∈{Ego,Others}
QK and OV circuits:
Left: represents no learned attention patterns
Right: shows y-coordinate and velocities heavily being attended to, in 3rd row from bottom.
Left: shows no learned attention patterns
Right: shows features (x, y, vy & sinh) being heavily attended to, in 3rd row from bottom.
Left: shows no learned patterns
Right: represents what action will be taken by the agent when that feature is attended to.
Furthermore, attention scores and output value matrices are shown in Fig 9a and 9b (in appendix section), show some interesting learned structural patterns.
Above figures show distinctions between the two scenarios, one where the agent is untrained and the other where the agent is trained. The QK and OV Circuits show high activations for for certain lines and areas in the heat map indicating a learned structure/pattern from the interplay between the agent, other vehicles and the environment.
Notes:
y-coordinate
with a strong positive correlation.vy
is being attended to with a negative correlation.y
coordinate and Other vehiclesx
,y
coordinates andvy
velocity. Hinting at the possibility of computing a distance metric.y
coordinate feature and Others embeddingpresence
feature.vx
,vy
) and heading (cos_h
,sin_h
), meaning the model values its own speed and direction when choosing actions.x
,y
) of other vehicles become crucial. The agent learns to consider the positions of other vehicles when deciding whether to slow down, idle, or speed up.Study on activation patterns of an episode run:
Let's study the agent's activations per time step and intermediate attention matrices collected over the full episode run.
Here, I extract environment frames and the activations of Attention head vs Vehicle for each time step of the evaluation shows the following:
Frame rendered from the scene at the time step when agent decides to slow down near the intersection.
Vertical lines of activations show that a particular vehicle is being attended to by all 4 attention heads in varying degrees.
Notes:
Vehicle_0
is the green vehicle which is controlled by the agent and hence is always being attended to by all 4 attention heads at all time steps.Next, I study output layers further to understand what insights I can draw from there.
Feature importance
In this section, I try to understand which features the model learns to extract from the environment state. A common technique for estimating feature importance is Integrated Gradients, which attributes a model output to each input feature by accumulating gradients along a straight-line path from a baseline input to the actual input. The resulting attributions give an overall picture of how important each feature is.
For this scenario, I compare the integrated gradients of the untrained and trained networks, averaged over 30 episodes, and find the following.
Notes:
The Q-network's output layer makes its final decision by looking at the presence, x-y coordinate and sin_h features. All the graphs above add further evidence to the observations I gathered earlier.
In the following section, I make some speculative interpretations about the agent. I would love to validate some of these interpretations by conducting more thorough experiments in the future. Claims that I am not confident about are marked inline.
Interpretation:
The following is a walkthrough of the agent in action, along with my interpretation of the key observations collected so far.
Step 1: Input Features
The agent observes its environment using the following features:
vx, vy: Ego vehicle's velocity (speed in the x and y directions).
cos_h, sin_h: Ego vehicle's heading direction.
x, y: Ego vehicle's position.
presence: Indicates whether another vehicle is nearby.
vx, vy: Other vehicle's velocity.
x, y: Other vehicle's position.
Step 2: Attention Mechanism (QK Circuits)
The Query-Key (QK) circuits determine which features should be attended to when making a decision.
… vx, vy, cos_h and sin_h to understand its motion.
… (y) to determine its position in the intersection. (Why does the x coordinate not have high activations? This is something I would like to investigate later.)
… presence to check whether another vehicle is nearby.
… vx, vy, x, y of other vehicles to predict their movement.
… presence and vx increases. (Unverified claim: is the model computing some distance metric? Can we verify this?)
… (presence feature) is high.
… (vx, vy) suggests a collision risk.
… cos_h and sin_h influence the decision, meaning the agent aligns with the road orientation.
… vx and sin_h, meaning the model prefers accelerating when aligned with the road.
Step 3: Q-Value Calculation (OV Circuits)
The Output-Value (OV) circuits determine how much each feature contributes to the Q-values for each action.
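The idea that OV circuits determine how much each feature contributes to each action's Q-value can be made concrete with the sketch below. It treats the value and output projections of one head, and the output layers, as if they were linear (the real output layers are an MLP, so this is an approximation) and uses hypothetical random weights. The resulting action-by-feature matrix indicates how strongly an attended vehicle's feature pushes each of SLOWER, NO-OP and FASTER up or down.

```python
import numpy as np

FEATURES = ["presence", "x", "y", "vx", "vy", "cos_h", "sin_h"]
ACTIONS = ["SLOWER", "NO-OP", "FASTER"]


def ov_to_actions(W_out, W_O, W_V, W_E_other):
    """(n_actions, n_features) map from an attended vehicle's features to Q-values.

    Shapes: W_out (A, d_model), W_O (d_model, d_head), W_V (d_head, d_model),
    W_E_other (d_model, n_features). Linear approximation of the real network.
    """
    return W_out @ W_O @ W_V @ W_E_other


# Hypothetical sizes and random weights, for illustration only.
rng = np.random.default_rng(1)
d_feat, d_model, d_head = len(FEATURES), 64, 16
W_E_other = rng.normal(size=(d_model, d_feat))
W_V = rng.normal(size=(d_head, d_model))
W_O = rng.normal(size=(d_model, d_head))
W_out = rng.normal(size=(len(ACTIONS), d_model))

C = ov_to_actions(W_out, W_O, W_V, W_E_other)

# If the ego head puts weight w on one other vehicle with features s, that vehicle
# adds w * C @ s to the Q-values under this linear approximation.
w, s = 0.6, rng.normal(size=d_feat)
direct = W_out @ (W_O @ (W_V @ (w * (W_E_other @ s))))
print(np.allclose(w * C @ s, direct))

# Which action does each feature push down the most?
for j, name in enumerate(FEATURES):
    print(f"{name:>8}: lowers {ACTIONS[int(np.argmin(C[:, j]))]} most")
```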
Feature Contributions to Actions:
… presence is high, the Slow action gets a higher weight.
… vx of the other car is high, meaning it is moving fast toward the intersection, the agent reduces the Q-value of the Fast action.
… vx is high but a collision is possible, the Q-value of the Idle action is also reduced.
Discussion
This experiment shows that an attention-based DQN agent can learn to cross a road intersection in dense traffic with a reasonable level of safety.
Additionally, analysis of the attention layers of the agent's Q-network provides some evidence that these layers learn high-level Q-policies that drive the agent's decision making. Although it was possible to find some high-level policies, more work is needed to understand how different policies combine to form a concrete algorithm.
I showed qualitatively, with some level of confidence, that the agent learns to delegate different types of functions to its embedding, attention and output layers, which serve sensory, processing and motor functions respectively.
Future work:
This experiment was limited in scope and time (up to 4 weeks). For this reason, I chose to focus on replicating the agent's behaviour and running various interpretability techniques, in order to narrow down a promising approach for pinning down the agent's exact behaviour in further research.
The following are some areas that can be explored in the future:
… (x, y) coordinates of the other vehicles?
… presence feature and high activations of the vx feature of other vehicles. Do these two policies correlate highly for the agent, or do they largely stay independent?
Acknowledgements:
My sincere thanks to this amazing community, who have made interpretability research easily accessible to the general public. I hope that my experiments bring some value to others and to this community. I look forward to delving deeper into this topic; any support and guidance is highly appreciated.
I would also like to thank BlueDot Impact for running a 12-week online course on AI Safety Fundamentals. I conducted this experiment as part of the project submission phase of this course, and I am grateful to the course facilitators and their team for running excellent sessions and providing a comprehensive list of resources on the key topics.
I'm looking forward to collaborating. Reach out to me on
My portfolio
LinkedIn
Appendix:
Github source code
Glossary:
DQN: Deep Q-Network
FCN: Fully Convolutional Net
CNN: Convolutional Neural Net
QK Circuit: Query-Key Circuit
OV Circuit: Output-Value Circuit
Model architecture:
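The architecture figure is not reproduced here, so the sketch below outlines a forward pass matching the description in this post: per-vehicle embedding layers (sensory), an ego-attention layer with 4 heads (processing) and output layers producing one Q-value per action (motor), in the spirit of the Social Attention paper referenced at the end. Layer sizes, the ReLU MLPs, masking absent vehicles via the presence flag and the random weights are all assumptions; the actual implementation may differ, for example in residual connections or normalisation.

```python
import numpy as np


def linear(rng, n_in, n_out):
    """Randomly initialised (W, b) pair; stands in for trained weights."""
    return rng.normal(scale=1.0 / np.sqrt(n_in), size=(n_out, n_in)), np.zeros(n_out)


def mlp(x, layers):
    """Apply (W, b) layers with ReLU between them and no activation after the last."""
    for i, (W, b) in enumerate(layers):
        x = x @ W.T + b
        if i < len(layers) - 1:
            x = np.maximum(x, 0.0)
    return x


def init_params(rng, d_feat=7, d=64, n_actions=3):
    return {
        "ego_embed": [linear(rng, d_feat, d), linear(rng, d, d)],
        "others_embed": [linear(rng, d_feat, d), linear(rng, d, d)],
        "W_Q": linear(rng, d, d)[0],
        "W_K": linear(rng, d, d)[0],
        "W_V": linear(rng, d, d)[0],
        "W_O": linear(rng, d, d)[0],
        "out": [linear(rng, d, d), linear(rng, d, n_actions)],
    }


def ego_attention_q(obs, params, n_heads=4):
    """obs: (V, 7) Kinematics rows, row 0 = ego, column 0 = presence.

    Returns (q_values of shape (n_actions,), attention of shape (n_heads, V)).
    """
    # Sensory: embed the ego vehicle and the other vehicles separately.
    ego = mlp(obs[:1], params["ego_embed"])            # (1, d)
    others = mlp(obs[1:], params["others_embed"])      # (V-1, d)
    tokens = np.concatenate([ego, others], axis=0)     # (V, d)
    V, d = tokens.shape
    dh = d // n_heads

    # Processing: the ego query attends over all vehicles, per head.
    q = (ego @ params["W_Q"].T).reshape(n_heads, dh)
    k = (tokens @ params["W_K"].T).reshape(V, n_heads, dh).transpose(1, 0, 2)
    v = (tokens @ params["W_V"].T).reshape(V, n_heads, dh).transpose(1, 0, 2)
    logits = np.einsum("hd,hvd->hv", q, k) / np.sqrt(dh)
    logits[:, obs[:, 0] < 0.5] = -np.inf              # ignore absent (padding) vehicles
    logits -= logits.max(axis=1, keepdims=True)
    attn = np.exp(logits)
    attn /= attn.sum(axis=1, keepdims=True)
    mixed = np.einsum("hv,hvd->hd", attn, v).reshape(d) @ params["W_O"].T

    # Motor: map the attended representation to one Q-value per action.
    return mlp(mixed, params["out"]), attn


rng = np.random.default_rng(0)
params = init_params(rng)
obs = rng.normal(size=(5, 7))
obs[:, 0] = 1.0  # all vehicles present ...
obs[4, 0] = 0.0  # ... except one padding row
q_values, attn = ego_attention_q(obs, params)
print(q_values.shape, attn.shape)
```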
Environment configuration:
N: number of vehicles
Observations type: Kinematics
Observation space: 7
$s = (s_i)_{i \in [0, N]}$, where $s_i = [x_i,\ y_i,\ v^x_i,\ v^y_i,\ \cos\phi_i,\ \sin\phi_i]^T$
Action space: 3 discrete actions, {SLOWER, NO-OP, FASTER}
Hyper-parameters:
Gamma: 0.95
Replay buffer size: 15000
Batch size: 64
Exploration strategy: Epsilon greedy
Tau: 15000
Initial temperature: 1.0
Final temperature: 0.05
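To show how these hyper-parameters would typically enter training, here is a minimal sketch. It assumes that "Tau" is the time constant (in steps) of an exponentially decaying epsilon that moves from the initial temperature 1.0 towards the final temperature 0.05. The naming matches the epsilon-greedy configuration of the rl-agents library, but this is an assumption about the setup used here; if Tau instead denotes the target-network update period, only the epsilon function changes. Gamma enters the standard DQN target.

```python
import numpy as np

GAMMA = 0.95          # discount factor
TAU = 15_000          # assumed: epsilon decay time constant, in environment steps
EPS_INITIAL = 1.0     # "Initial temperature"
EPS_FINAL = 0.05      # "Final temperature"


def epsilon(step):
    """Assumed exponential decay from EPS_INITIAL towards EPS_FINAL."""
    return EPS_FINAL + (EPS_INITIAL - EPS_FINAL) * np.exp(-step / TAU)


def select_action(q_values, step, rng):
    """Epsilon-greedy: random action with probability epsilon(step), else greedy."""
    if rng.random() < epsilon(step):
        return int(rng.integers(len(q_values)))
    return int(np.argmax(q_values))


def dqn_targets(rewards, next_q_target, dones, gamma=GAMMA):
    """y = r + gamma * max_a' Q_target(s', a'), with no bootstrap on terminal steps."""
    return rewards + gamma * next_q_target.max(axis=1) * (1.0 - dones)


rng = np.random.default_rng(0)
print([round(float(epsilon(t)), 3) for t in (0, 15_000, 60_000)])
action = select_action(np.array([0.1, 0.5, 0.2]), step=100_000, rng=rng)
targets = dqn_targets(
    rewards=np.array([1.0, 0.0]),
    next_q_target=np.array([[0.0, 2.0, 1.0], [3.0, 1.0, 0.0]]),
    dones=np.array([0.0, 1.0]),
)
print(action, targets)
```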
Evaluation:
Running evaluation over 10 episodes with display enabled shows high scores and successful navigation through the intersection.
Learned Attention scores:
Difference is computed by doing a square difference per term of the 2-dim tensor.
Difference is computed by doing a square difference per term of the 2-dim tensor.
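For clarity, the difference described in these captions amounts to the element-wise operation below, assuming the two matrices are the trained and untrained networks' versions of the same heat map (the function name is hypothetical). Squaring makes every cell non-negative, so the resulting heat map highlights where training changed the matrix, not in which direction.

```python
import numpy as np


def squared_difference(trained, untrained):
    """Element-wise squared difference of two 2-D matrices with the same shape."""
    trained = np.asarray(trained, dtype=float)
    untrained = np.asarray(untrained, dtype=float)
    if trained.ndim != 2 or trained.shape != untrained.shape:
        raise ValueError("expected two 2-D arrays of the same shape")
    return (trained - untrained) ** 2


diff = squared_difference([[1.0, 2.0], [3.0, 4.0]], [[0.0, 2.0], [5.0, 1.0]])
print(diff)
```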
Social Attention for Autonomous Decision-Making in Dense Traffic
https://distill.pub/2020/understanding-rl-vision/
https://transformer-circuits.pub/2021/framework/index.html