I think KL/entropy regularization is usually used to prevent mode collapse partly because it has nice theoretical properties. In particular, it is easy to reason about the optimal policy for the regularized objective - see for example the analysis in the paper Equivalence Between Policy Gradients and Soft Q-Learning.
Nevertheless, action-dependent baselines do appear in the literature, although the story is a bit confusing. This is my understanding of it from some old notes:
I like the philosophical and strategic take here: let's avoid wireheading, arbitrary reinforcement strength is risky[1], hopefully we can get some values-caring-about-human-stuff.
ACTDE seems like a potentially nice complement/alternative to entropy[2] regularisation for avoiding mode collapse (I haven't evaluated it deeply). I think you're misdiagnosing a few things, though.
Overall I think the section about oscillating advantage/value estimation is irrelevant (interesting, but unrelated), and I think you should point the finger less at PPO and advantage estimation per se and more at exploration at large. And you might want to flag that too much exploration/randomness can also be an issue!
Though note that ideally, once we actually know with confidence what is best, we should be near-greedy about it, rather than softmaxing! Say it was 'ice cream' vs 'slap in the face'. I would infinitely (linearly in time) regret softmaxing over that for eternity. As it stands I think humanity is very far from being able to safely aggressively greedily optimise really important things, but this is at least a consideration to keep in mind. ↩︎
Incidentally, KL divergence regularisation is not primarily for avoiding mode collapse AFAIK, it's for approximate trust region constraints - which may incidentally help to avoid mode collapse by penalising large jumps away from initially-high-entropy policies. See the TRPO paper. Entropy regularisation directly addresses mode collapse. ↩︎
this kind of failure happens by default in policy gradient methods.
It looks like you're kind of agreeing here that value estimate oscillation isn't the culprit? Again I think this is pretty standard - though the finger is usually not pointed at any particular value estimator or whatnot, but rather at the greediness of updating only on so-far-observed data i.e. the exploration problem. The GLIE conditions[1] - Greedy in the Limit with Infinite Exploration - are a classic result. Hence the plethora of exploration techniques which are researched and employed in RL.
Techniques like confidence bounding[2] based on Hoeffding's inequality and Thompson sampling based on Bayesian uncertainty require more than a simple mean estimate (which is all that a value or advantage is): typically at least also one spread/uncertainty estimate[3]. Entropy regularisation, epsilon 'exploration', intrinsic 'curiosity' rewards, value-of-information estimation and so on are all heuristics for engaging with exploration.
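Concretely, here is a minimal, illustrative UCB1 sketch (my own toy example, not from this thread); the two arms and their reward means are made up, and the point is only that the exploration bonus needs per-action visit counts in addition to a mean estimate:

```python
import math
import random

# Toy 2-armed bandit with UCB1-style action selection (illustrative only).
# The bonus sqrt(2 ln t / n_a) requires per-action counts n_a, i.e. an
# uncertainty proxy beyond the mean reward estimate.
true_means = {"wedding": 1.0, "party": 0.5}   # made-up arm means
counts = {a: 0 for a in true_means}
mean_est = {a: 0.0 for a in true_means}

for t in range(1, 1001):
    untried = [a for a in true_means if counts[a] == 0]
    if untried:
        # Pull each arm once before applying the confidence bound.
        action = untried[0]
    else:
        action = max(
            true_means,
            key=lambda a: mean_est[a] + math.sqrt(2 * math.log(t) / counts[a]),
        )
    reward = true_means[action] + random.gauss(0, 0.1)
    counts[action] += 1
    # Incremental mean update with step size 1/n (a decaying, GLIE-friendly schedule).
    mean_est[action] += (reward - mean_est[action]) / counts[action]

print(counts, mean_est)
```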
I don't know what's a good resource on GLIE, but you can just look up Greedy in the Limit with Infinite Exploration ↩︎
Amazingly there's no Wikipedia entry on UCB?? ↩︎
Epsilon exploration can get away without a spread estimate, but its GLIE guarantees are only provided if there's an epsilon per state, which secretly smuggles in an uncertainty estimate (because you're tracking the progress bar on each state somehow, which means you're tracking how often you've seen it). ↩︎
Though note that ideally, once we actually know with confidence what is best, we should be near-greedy about it, rather than softmaxing!
I disagree. I don't view reward/reinforcement as indicating what is "best" (from our perspective), but as chiseling decision-making circuitry into the AI (which may then decide what is "best" from its perspective). One way of putting a related point: I think that we don't need to infinitely reinforce a line of reasoning in order to train an AI which reasons correctly.
(I want to check -- does this response make sense to you? Happy to try explaining my intuition in another way.)
There's also the issue of non-ergodic/nonstationary environments (if I try out breaking my leg to see what happens, I might not be able to try out other stuff later!) which defeat the GLIE and can cause another kind of collapse. Actually behaving sufficiently entropically is risky in such environments, hence research into safe exploration.
The problem is that this advantage can oscillate forever.
This is a pretty standard point in RL textbooks. But the culprit is the learning rate (which you set to 1 in the example, but you can construct a nonconverging case for any constant α)! The advantage definition itself is correct and non-oscillating; it's the estimation of the expectation using a moving average which is (sometimes) at fault.
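To illustrate the learning-rate point (a toy sketch of my own, reusing the post's bandit rewards): with the constant step size from the example the tabular value estimate never settles, while a decaying 1/n schedule converges to the expected reward under the fixed 50/50 policy.

```python
import random

rewards = {"wedding": 1.0, "party": 0.5}

def run(alpha_fn, episodes=10_000):
    """Estimate v(s) under a fixed 50/50 policy with tabular TD(0) on a one-step bandit."""
    v = 0.0
    for n in range(1, episodes + 1):
        r = rewards[random.choice(["wedding", "party"])]
        v += alpha_fn(n) * (r - v)   # one-step target is just the reward here
    return v

print(run(lambda n: 1.0))      # constant alpha = 1: v keeps snapping to the last reward
print(run(lambda n: 1.0 / n))  # decaying alpha: v converges to ~0.75, the true expectation
```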
Oscillating or nonconvergent value estimation is not the cause of policy mode collapse.
The advantage definition itself is correct and non-oscillating... Oscillating or nonconvergent value estimation is not the cause of policy mode collapse.
The advantage is (IIUC) defined with respect to a given policy, and so the advantage can oscillate and then cause mode collapse. I agree that a constant learning rate schedule is problematic, but note that ACTDE converges even with a constant learning rate schedule. So, I would indeed say that oscillating value estimation caused mode collapse in the toy example I gave?
Would this be equivalent to an RL environment that scales down the per-wedding reward for repeated weddings?
What bothers me about this is suppose we have a different set of 2 RL choices:
Life saved: +10
Murder: -10
In this case we want the agent to choose policies that result in life saved with total mode collapse away from committing a murder. This is also true for less edgy/more practical descriptions, such as:
box shelved correctly: 0.1
human coworker potentially injured: -10
Is this identical to training the next-to-last layer to predict the rewards directly, and then just transforming those predictions to get a sample? Have you considered going whole hog on model-based RL here?
I'd be interested in avoiding mode collapse in cases where that's not practical, like diffusion models. Actually, could you choose a reward that makes diffusion models equivalent to MCMC? Probably no good safety reason to do such a thing though.
Is this identical to training the next-to-last layer to predict the rewards directly, and then just transforming those predictions to get a sample?
In the tabular case, that's equivalent given a uniform π0. Maybe it's also true in the function approximator PG regime, but that's a maybe -- depends on inductive biases. But often we want a pretrained π0 (like when doing RLHF on LLMs), which isn't uniform.
TL;DR: We present an advantage variant which, in certain settings, does not train an optimal policy, but instead updates the policy a fixed amount from initialization for a given reward. Non-tabular empirical results seem mixed: the policy doesn't mode-collapse, but has unclear convergence properties.
Summary: Many policy gradient methods allow a network to extract arbitrarily many policy updates from a single kind of reinforcement event (e.g. for outputting tokens related to weddings). Alex proposes a slight modification to the advantage equation, called "action-conditioned TD error" (ACTDE). ACTDE ensures that the network doesn't converge to an "optimal" policy (these almost always put infinite logits on a single action). Instead, ACTDE updates the network by a fixed number of logits.
For example, suppose R(pizza)=10 and R(cookies)=11. In this case, PPO converges to a policy which puts arbitrarily many logits on cookies, even though the reward difference is small. By contrast, under ACTDE, the network converges to the softmax-over-reward policy {pizza: 27%, cookies: 73%}, which seems more reasonable.
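(As a quick sanity check of those numbers, assuming a uniform initial policy so that the fixed point is just the softmax over the raw rewards:)

```python
import math

rewards = {"pizza": 10.0, "cookies": 11.0}
z = sum(math.exp(r) for r in rewards.values())
print({a: round(math.exp(r) / z, 2) for a, r in rewards.items()})
# {'pizza': 0.27, 'cookies': 0.73}
```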
Then, Michael Einhorn shares initial results which support Alex's theoretical predictions. Using an architecture and Q-head loss function similar to ILQL's, for a small transformer trained on a prisoner's dilemma, Michael Einhorn collected initial data on ACTDE. Unlike PPO, ACTDE-trained policies did not mode collapse onto a single action and instead learned mixed strategies.
We're interested in additional experiments on ACTDE. We hope that, by using ACTDE instead of advantage, we can automatically mitigate "reward specification" issues and maybe even reduce the need for a KL penalty term. That would make it easier to shape policies which do what we want.
The advantage equation implies arbitrary amounts of update on a single experience
In PPO, the optimization objective is proportional to the advantage given a policy π, reward function R, and on-policy value function vπ:[1]
$$A^{\pi}(s,a) := \mathbb{E}_{s' \sim T(s,a)}\left[R(s,a,s') + \gamma v^{\pi}(s')\right] - v^{\pi}(s).$$

Alex thinks this equation is actually pretty messed up, although it looked decent at first. The problem is that this advantage can oscillate forever. To explain, let's consider a simple bandit problem—one state ("We had a") and two actions ("wedding" and "party") with rewards R(“We had a wedding”)=1 and R(“We had a party”)=.5.
The failure which happens is as follows: each time the policy samples "party", the value estimate snaps down to .5, so "wedding" has positive advantage the next time it's sampled and gains logits; each time the policy samples "wedding", the value snaps up to 1, so "party" has negative advantage and loses logits. This continues to happen, which means that "wedding" gets arbitrarily high logits.
This flaw is easiest to see formally. Initialize the t=0 tabular value function vπ0 to 0, and the policy π0 to be 50/50 for “party”/“wedding”. Let γ=1, and update the value function v using tabular TD learning (with learning rate α=1). So, for example, if the system takes the “wedding” action, its new value estimate is vπ1(s)=1. If the system then takes the “party” action, the value snaps back down to vπ2(s)=.5.[2]
The policy update rule is: if the advantage Aπ(s,a)=n, then we add n to π's logit on a, so that a becomes correspondingly more probable under π. So, if π0(s,“ wedding”)=.5 and the advantage Aπ0(s,“ wedding”)=1, then π1(s,“ wedding”)≈.73.
Episode-by-episode:
With probability 1 as t→∞, πt(wedding)→1. You might think this is good, since wedding is in fact “optimal” at that state. This does not seem good. Here are a few kinds of explanations for why:
This doesn’t seem limited to tabular TD-learning, or PPO in more realistic domains. EG vanilla policy gradient will also allow a system to extract an unbounded amount of reinforcement from a single kind of event (e.g. “wedding”). Unless very specific care is taken, Alex thinks this kind of failure happens by default in policy gradient methods.
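To make the failure concrete, here is a minimal simulation of the tabular setup above (my own sketch, not the authors' code); the variable names are mine, and the policy update adds the advantage directly to the sampled action's logit as described.

```python
import math
import random

# Tabular bandit from the example: one state, two actions, alpha = 1 TD updates,
# and "add the advantage to the chosen action's logit" as the policy update.
R = {"wedding": 1.0, "party": 0.5}
logits = {"wedding": 0.0, "party": 0.0}   # 50/50 initial policy
v = 0.0                                   # state-value estimate

def sample(logits):
    z = sum(math.exp(l) for l in logits.values())
    probs = {a: math.exp(l) / z for a, l in logits.items()}
    return random.choices(list(probs), weights=probs.values())[0]

for t in range(10_000):
    a = sample(logits)
    advantage = R[a] - v     # one-step bandit: A(s, a) = R(a) - v(s)
    logits[a] += advantage   # reinforce the sampled action
    v = R[a]                 # tabular TD update with learning rate 1

print(logits)  # the logit gap between "wedding" and "party" keeps growing (slowly) without bound
```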
Action-conditioned TD error avoids arbitrarily high logits
Given the original advantage equation:
$$A^{\pi}(s,a) := \mathbb{E}_{s' \sim T(s,a)}\left[R(s,a,s') + \gamma v^{\pi}(s')\right] - v^{\pi}(s),$$

replace the last term's baseline to account for the taken action:

$$A^{\pi}_{*}(s,a) := \mathbb{E}_{s' \sim T(s,a)}\left[R(s,a,s') + \gamma v^{\pi}(s')\right] - q^{\pi}(s,a).$$

We call this “action-conditioned TD error” (ACTDE).
ACTDE allows the system to account for its decision to go off-policy by selecting a new action a which isn’t the usual recommendation a′∼π(s). Philosophically, Alex wanted to mimic reward prediction error. The network taking a different action is not surprising to the network, so the optimization term should account for the action taken (i.e. by using qπ(s,a)).
Re-analyzing the situation:
The policy quickly converges to the softmax over the rewards for the next completions, where e^1/(e^1+e^.5)≈.62. That is, the learned policy has R(“party”)=.5 logits on “party” and R(“wedding”)=1 logit on “wedding”. Therefore this process does not converge to the optimal policy, even in the limit of infinite exploration. Correspondingly, there is no mode collapse in this situation. Reward logits are “added to” the initialization π0's logits (the prior over what completions to output). RL, in this setting, provides a finite amount of reinforcement for certain kinds of computation/actions.
Furthermore, self-consistent, Bellman-backed-up Q-functions will have zero advantage and zero updates. Networks aren’t penalized for exploring, and there’s a precise and finite amount of reinforcement which can occur given current predictions about future value, as represented by qt. And training should be more stable, with fewer fluctuations in advantage with respect to the policy itself.[3]
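For comparison with the sketch in the previous section (again my own illustration under the same assumptions), swapping the baseline from v(s) to q(s,a) makes the update vanish once the tabular Q-values match the rewards, leaving the softmax-over-reward policy:

```python
import math
import random

R = {"wedding": 1.0, "party": 0.5}
logits = {"wedding": 0.0, "party": 0.0}   # 50/50 initial policy
q = {"wedding": 0.0, "party": 0.0}        # tabular Q estimates, learning rate 1

def sample(logits):
    z = sum(math.exp(l) for l in logits.values())
    probs = {a: math.exp(l) / z for a, l in logits.items()}
    return random.choices(list(probs), weights=probs.values())[0]

for t in range(10_000):
    a = sample(logits)
    actde = R[a] - q[a]   # ACTDE: the baseline is q(s, a), not v(s)
    logits[a] += actde    # after the first visit to each action, this is zero
    q[a] = R[a]           # tabular update with learning rate 1

z = sum(math.exp(l) for l in logits.values())
print({a: round(math.exp(l) / z, 2) for a, l in logits.items()})
# ~{'wedding': 0.62, 'party': 0.38}: softmax over rewards, no mode collapse
```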
ACTDE doesn't mode-collapse onto wireheading
ACTDE doesn't mode collapse on wireheading, even given that the network tries out wireheading! (Which Alex thinks is not that likely for practical RL algorithms.)
Concretely, suppose that reward is 10 if you eat pizza and 100 if you wirehead. You start off with action distribution {pizza: 1%, wirehead: 99%}, and we're doing TD-learning in the tabular setup we just described. If so, then the policy gradients upweight wireheading more and more. This can happen until the network puts arbitrarily many logits on the wireheading action. In this situation, under these exploration assumptions and with probability 1, PPO "selects for" wireheading and the policy ends up {pizza: ϵ, wirehead: 1−ϵ}.
However, ACTDE does not lead to arbitrarily many logits on wireheading. Instead, ACTDE leads to the softmax distribution over actions, with the softmax taken over the reward for each action. With rewards of 10 and 100, the "optimum"/fixed-point policy of tabular ACTDE still puts nearly all of its probability on wireheading, since the reward gap of 90 becomes a 90-logit gap on top of the initialization. That's still mostly wireheading, but there are only finitely many logits on that action.
PPO vs ACTDE on the iterated prisoner's dilemma
In this toy experiment, the model plays prisoner's dilemmas against its past self, similar to the idea by Krueger et al. The model is mingpt with a vocab size of two: one token for "cooperate" and one for "defect". mingpt has 3 layers and an embedding dimension of 12. The model sees the history of cooperations and defections, and outputs the next action.
We are not training via self play against a copy. Instead, the model at time t plays against its action at time t−1. Playing with its past self for a sequence of `ccddc` has 4 games: `cc`, `cd`, `dd`, `dc`, with rewards of 0.5 (for `cc`), 2 (for `cd`), -0.74 (for `dd`), and -1.76 (for `dc`).[4]

Alternating cooperation (`c`) and defection (`d`) is the (bolded) optimal strategy for both start states:

- `cccc...`
- `cddd...`
- **`cdcd...`**
- `dddd...`
- `dccc...`
- **`dcdc...`**
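To spell out the mapping from a history to games and rewards, here is a small helper (hypothetical names, not the authors' code) using the reward values above:

```python
# The model at time t plays against its own action at time t-1,
# so a history of length n yields n-1 games.
GAME_REWARD = {"cc": 0.5, "cd": 2.0, "dd": -0.74, "dc": -1.76}

def games_and_rewards(history: str):
    games = [history[i - 1] + history[i] for i in range(1, len(history))]
    return games, [GAME_REWARD[g] for g in games]

print(games_and_rewards("ccddc"))
# (['cc', 'cd', 'dd', 'dc'], [0.5, 2.0, -0.74, -1.76])
```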
What we're testing: If ACTDE mode collapses when used on function approximators (like mingpt), then the theoretical predictions above are wrong.
PPO results
PPO immediately learns the alternating strategy:
ACTDE results
The model does not collapse onto a pure strategy. Instead, the results are inconsistent across trials. However, ACTDE does reliably:
Here's the first 1K epochs of a training run:
Zooming out to all 10K epochs:
We ran 10 trials and plotted the mean and standard deviation of average returns:
There seems to be very slow convergence,[6] perhaps towards the softmax-over-returns policy (shown by the dotted lines), or towards the uniform policy. We lean towards "convergence to uniform" due to evidence from a trial on a different reward matrix:
Overall, ACTDE's results are sensitive to variations in the algorithm such as whitening advantages, detaching the value and Q-heads, and using the loss function from PPO or ILQL for the value head.
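"Whitening advantages" here refers to the standard per-batch normalization used in many PPO implementations; a generic sketch (assuming PyTorch, not the authors' exact code) looks like this:

```python
import torch

def whiten(advantages: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Standard advantage normalization: zero mean, unit variance per batch."""
    return (advantages - advantages.mean()) / (advantages.std() + eps)
```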
Speculation
This method might not work very well for e.g. RLHF at scale. Deep RL is notoriously finicky. Furthermore, it would be pretty disappointing if ACTDE generally converges on uniform policies, and that seems like a live possibility given the last graph above.
However, Alex has a few intuitions anyways:
Summary
ACTDE seems to avoid mode collapse in simple tabular setups. We showed that ACTDE doesn't mode collapse on a toy prisoner's dilemma learning task, but instead trains a mixed strategy.
We're excited for someone to RLHF a language model using ACTDE. Alex is willing to contribute 30 minutes weekly to giving feedback on such a project, insofar as that would be helpful. If necessary, Alex can also help acquire funding for a prospective researcher who has experience doing deep RL. Email him at alex@turntrout.com if interested. Email Michael at einhorn.michael1@gmail.com for any questions about the code.

Contributions: Michael Einhorn wrote the code for `trl_textworld`[8] and `prisonerUnitTest`.

Thanks to Connor Leahy, Evan Hubinger, Ulisse Mini, Nate Soares, Leo Gao, Garrett Baker, janus, David Krueger and others for thoughts and discussions.
Appendix: Random notes
This advantage equation, as given, can also be called the "TD error."
Alex thinks that using a fixed learning rate 0<α<1 shouldn’t fix PPO's "infinite logit update issue", but a decaying learning rate schedule probably does. This isn't that surprising, and he doesn't think it fixes the deeper potential issue with fluctuating value baselines.
Although Alex hasn't analyzed the sequential tabular setting — possibly infinite logit updating can still happen there?
Note that the `cd` and `dc` always come in pairs, except for at most 1 extra.

return(s) averages strategy s's return over the first state being cooperate (`c`) and being defect (`d`).

In the tabular bandit example above, the convergence was extremely fast due to the learning rate and triviality of the problem.
When Alex wrote this in the fall, he thought that RLHF was responsible for mode collapse behaviors in LMs. However, empirical evidence has since made him think that RLHF is less responsible for these failures. He thinks his theoretical analysis is still correct under the assumptions he made, and he still thinks it's important to investigate empirically.
One of the goals of `trl-textworld` was to evaluate PPO vs ACTDE finetunings on pretrained language models, but the models were not able to learn to play the text adventure, so this project did not get to a point where the algorithm's results could be compared. The implementation may still be useful—it has been tested up to GPT-NeoX 20B on 8 GPUs.