
2024

  1. Aviral Kumar, Vincent Zhuang, Rishabh Agarwal, and 15 more authors
    Sep 2024

    Paper Abstract

    Self-correction is a highly desirable capability of large language models (LLMs), yet it has consistently been found to be largely ineffective in modern LLMs. Current methods for training self-correction typically depend on either multiple models, a more advanced model, or additional forms of supervision. To address these shortcomings, we develop a multi-turn online reinforcement learning (RL) approach, SCoRe, that significantly improves an LLM’s self-correction ability using entirely self-generated data. To build SCoRe, we first show that variants of supervised fine-tuning (SFT) on offline model-generated correction traces are often insufficient for instilling self-correction behavior. In particular, we observe that training via SFT falls prey to either a distribution mismatch between mistakes made by the data-collection policy and the model’s own responses, or to behavior collapse, where learning implicitly prefers only a certain mode of correction behavior that is often not effective at self-correction on test problems. SCoRe addresses these challenges by training under the model’s own distribution of self-generated correction traces and using appropriate regularization to steer the learning process into learning a self-correction behavior that is effective at test time as opposed to fitting high-reward responses for a given prompt. This regularization process includes an initial phase of multi-turn RL on a base model to generate a policy initialization that is less susceptible to collapse, followed by using a reward bonus to amplify self-correction. With Gemini 1.0 Pro and 1.5 Flash models, we find that SCoRe achieves state-of-the-art self-correction performance, improving the base models’ self-correction by 15.6% and 9.1% respectively on MATH and HumanEval.

Three Important Things

1. Self-Correction is Hard

It is desirable for LLMs to be able to recognize errors in their own outputs and to correct them. This is called intrinsic self-correction. In principle this should be possible, because LLMs often already possess the knowledge needed to solve a problem but fail to elicit it and make the necessary inferences. For instance, a model may be able to complete each sub-part of a proof when given the rest of the proof, yet fail to produce the entire proof on its own.

People have attempted to address this with prompting and supervised fine-tuning (SFT). Prompting alone rarely works, and SFT suffers from two major problems: distribution shift and behavior collapse. Distribution shift arises because the SFT data consists of corrections of mistakes made by the data-collection policy, which need not resemble the mistakes the model itself will make. Behavior collapse is when the model only mimics the surface structure of a correction: it gives its best response on the first attempt and then makes only a superficial change (or none at all) on the second, which is not in the spirit of self-correction.

The authors also found that running SFT on the model’s own (on-policy) correction traces still resulted in behavior collapse.
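To make behavior collapse concrete, one can track how first- and second-attempt correctness change across a set of problems. The sketch below is a hypothetical helper (`self_correction_stats` and its inputs are my own illustration, not code from the paper) that computes accuracy at each turn, the net change, and the fractions of problems that flip from incorrect to correct and from correct to incorrect. As far as I can tell the paper reports metrics in this spirit; a collapsed model shows near-zero flips in both directions.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass
class SelfCorrectionStats:
    acc_t1: float       # accuracy of the first attempt
    acc_t2: float       # accuracy of the second attempt
    delta: float        # acc_t2 - acc_t1, the net self-correction gain
    frac_i_to_c: float  # fraction fixed: wrong on turn 1, right on turn 2
    frac_c_to_i: float  # fraction broken: right on turn 1, wrong on turn 2


def self_correction_stats(
    correct_t1: Sequence[bool], correct_t2: Sequence[bool]
) -> SelfCorrectionStats:
    """Summarize per-problem correctness of two attempts (hypothetical helper)."""
    if len(correct_t1) != len(correct_t2) or not correct_t1:
        raise ValueError("need two equal-length, non-empty correctness lists")
    n = len(correct_t1)
    acc_t1 = sum(correct_t1) / n
    acc_t2 = sum(correct_t2) / n
    i_to_c = sum((not a) and b for a, b in zip(correct_t1, correct_t2)) / n
    c_to_i = sum(a and (not b) for a, b in zip(correct_t1, correct_t2)) / n
    return SelfCorrectionStats(acc_t1, acc_t2, acc_t2 - acc_t1, i_to_c, c_to_i)


# A "collapsed" model that repeats its first answer never flips in either direction.
collapsed = self_correction_stats([True, False, True, False], [True, False, True, False])
print(collapsed.delta, collapsed.frac_i_to_c)  # 0.0 0.0
```

Note that the net gain always equals the fixed fraction minus the broken fraction, so a small positive gain can hide a lot of churn; looking at both flip rates separates genuine correction from noise.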

This motivates using RL to instill this behavior, which the authors do with their approach, SCoRe.

Figure: Samples of SCoRe’s two-turn self-correcting behavior.

2. Self-Correction via Reinforcement Learning (SCoRe)

Their goal is to train a model that provides an answer on the first turn and then improves on it in the second turn. In principle, this can be extended to an arbitrary number of corrections, but they only explored two turns due to compute limitations.

Using on-policy RL in SCoRe addresses distribution shift directly, since the model is now trained on its own attempts and its own mistakes.

To fix behavior collapse, they used a 2-stage training process:

  1. In the first stage of training, the model is constrained to keep its first-turn answers close to the base model’s, while maximizing the reward for answering correctly on the second turn, i.e., for correcting its mistakes.
  2. In the second stage of training, the model is allowed to maximize reward across both turns, with an extra reward-shaping term that rewards progress from the first turn to the second.

3. Stage I: Training an Initialization that Decouples Attempts

The goal of Stage I is to keep the model from falling into the behavior-collapse trap of simply learning to output a good answer on its first try and repeating it on the second. This is why the first turn is KL-constrained to stay close to the base model: the reward can then only be increased by improving the second turn.

Formally, let \(\boldsymbol{x}_1\) be the input, \(\boldsymbol{y}_1, \boldsymbol{y}_2\) the first- and second-turn outputs, \(\boldsymbol{y}^*\) the correct answer, \(p_1\) the auxiliary instruction to find a mistake and improve the response, \(\hat{r}\) the reward measuring the correctness of a response against \(\boldsymbol{y}^*\), and \(\pi_{\mathrm{ref}}\) the base (reference) model. The training objective is then:

\[\max _\theta \mathbb{E}_{\boldsymbol{x}_1,\, \boldsymbol{y}_1 \sim \pi_\theta(\cdot \mid \boldsymbol{x}_1),\, \boldsymbol{y}_2 \sim \pi_\theta\left(\cdot \mid\left[\boldsymbol{x}_1, p_1\right]\right)}\left[\hat{r}\left(\boldsymbol{y}_2, \boldsymbol{y}^*\right)-\beta_2 D_{K L}\left(\pi_\theta\left(\cdot \mid \boldsymbol{x}_1\right) \,\|\, \pi_{\mathrm{ref}}\left(\cdot \mid \boldsymbol{x}_1\right)\right)\right]\]

Note that I believe the second-turn sampling should actually be \(\boldsymbol{y}_2 \sim \pi_\theta\left(\cdot \mid\left[\boldsymbol{x}_1, \boldsymbol{y}_1, p_1\right]\right)\), since the second turn should depend on the output of the first turn.

In words: we first sample a first-turn response, then append the auxiliary instruction to encourage self-correction and sample a second-turn response. We maximize the reward for getting the right answer on the second turn, subject to a KL penalty between the policy’s first-turn distribution and the base model’s.
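As a rough illustration, the Stage I objective for a single sampled trajectory is the second-attempt reward minus \(\beta_2\) times a KL estimate on the first turn. The sketch below assumes the KL is approximated, as is common in RLHF-style training, by the summed per-token log-probability ratio of the sampled first-turn tokens under the policy and the reference model; the function names, inputs, and the value \(\beta_2 = 0.5\) are hypothetical. In a real trainer this scalar would feed a policy-gradient update rather than being maximized directly.

```python
import math
from typing import Sequence


def kl_estimate(policy_logprobs: Sequence[float], ref_logprobs: Sequence[float]) -> float:
    """Single-sample KL estimate: sum_t [log pi_theta(y_t | .) - log pi_ref(y_t | .)].

    Unbiased for KL(pi_theta || pi_ref) when the tokens were sampled from pi_theta.
    """
    if len(policy_logprobs) != len(ref_logprobs):
        raise ValueError("log-prob sequences must align token by token")
    return sum(p - q for p, q in zip(policy_logprobs, ref_logprobs))


def stage1_objective(
    reward_turn2: float,
    policy_logprobs_turn1: Sequence[float],
    ref_logprobs_turn1: Sequence[float],
    beta2: float,
) -> float:
    """Per-trajectory Stage I objective: r(y2, y*) - beta2 * KL on the first turn only."""
    return reward_turn2 - beta2 * kl_estimate(policy_logprobs_turn1, ref_logprobs_turn1)


# Second attempt is correct (reward 1.0); one first-turn token drifted from the base model.
obj = stage1_objective(
    1.0,
    policy_logprobs_turn1=[math.log(0.5), math.log(0.4)],
    ref_logprobs_turn1=[math.log(0.5), math.log(0.2)],
    beta2=0.5,
)
print(round(obj, 4))  # 0.6534, i.e. 1 - 0.5 * log(2)
```

Only the first turn is penalized here, which is the point of Stage I: the second turn is free to move away from the base model as long as it gets the answer right.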

They used the following prompt for \(p_1\):

Self-correction instruction. There might be an error in the solution above because of lack of understanding of the question. Please correct the error, if any, and rewrite the solution.
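A minimal sketch of how a two-turn episode might be assembled, following the \(\left[\boldsymbol{x}_1, \boldsymbol{y}_1, p_1\right]\) ordering, so the second attempt sees the first. The `generate` callable and the plain newline-joined prompt format are assumptions for illustration; a real setup would use the model’s chat template. The `num_turns` parameter shows how the same loop would extend to more correction rounds, which the authors did not explore.

```python
from typing import Callable, List

# Instruction text quoted from the post; the formatting around it is an assumption.
SELF_CORRECTION_INSTRUCTION = (
    "There might be an error in the solution above because of lack of "
    "understanding of the question. Please correct the error, if any, "
    "and rewrite the solution."
)


def rollout(
    problem: str,
    generate: Callable[[str], str],  # hypothetical: prompt -> sampled response
    num_turns: int = 2,
) -> List[str]:
    """Sample num_turns attempts; each later attempt conditions on all earlier ones."""
    context = problem
    attempts: List[str] = []
    for turn in range(num_turns):
        response = generate(context)
        attempts.append(response)
        if turn < num_turns - 1:
            # Append the attempt, then the instruction: [x1, y1, p1, ...]
            context = "\n\n".join([context, response, SELF_CORRECTION_INSTRUCTION])
    return attempts


def toy_generate(prompt: str) -> str:
    """Stand-in for a sampled LLM response; reports how much context it saw."""
    return f"answer given {len(prompt)} chars of context"


for i, attempt in enumerate(rollout("What is 2 + 2?", toy_generate), start=1):
    print(f"turn {i}: {attempt}")
```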

4. Stage II: Multi-Turn RL with Reward Shaping

In the second stage, we augment the objective with the reward for the first attempt, as well as a bonus term to encourage correction:

\[\max _\theta \mathbb{E}_{\boldsymbol{x}_1,\, \boldsymbol{y}_1 \sim \pi_\theta(\cdot \mid \boldsymbol{x}_1),\, \boldsymbol{y}_2 \sim \pi_\theta\left(\cdot \mid\left[\boldsymbol{x}_1, p_1\right]\right)}\left[ \hat{r}\left(\boldsymbol{y}_1, \boldsymbol{y}^*\right) + \hat{r}\left(\boldsymbol{y}_2, \boldsymbol{y}^*\right) + \hat{b}\left(\boldsymbol{y}_2 \mid \boldsymbol{y}_1, \boldsymbol{y}^*\right) - \beta_1 \sum_{i=1}^{2} D_{K L}\left(\pi_\theta\left(\cdot \mid \boldsymbol{x}_i\right) \,\|\, \pi_{\mathrm{ref}}\left(\cdot \mid \boldsymbol{x}_i\right)\right)\right],\]

where \(\boldsymbol{x}_2\) denotes the second-turn context that \(\boldsymbol{y}_2\) is conditioned on, and the bonus term is the improvement from the first to the second attempt, scaled by \(\alpha\), i.e.,

\[\hat{b}\left(\boldsymbol{y}_2 \mid \boldsymbol{y}_1, \boldsymbol{y}^*\right) = \alpha \cdot\left(\hat{r}\left(\boldsymbol{y}_2, \boldsymbol{y}^*\right)-\hat{r}\left(\boldsymbol{y}_1, \boldsymbol{y}^*\right)\right).\]

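The bonus term is meant to discourage behavior collapse, where the model just learns to output its best answer on the first try and leaves it unchanged. Such a policy earns no bonus, whereas switching from an incorrect to a correct answer is rewarded and switching from a correct to an incorrect answer is penalized.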
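To see why the bonus helps, here is a small worked comparison under made-up numbers, assuming binary (0/1) correctness rewards and omitting the KL term. Two hypothetical policies have the same expected \(\hat{r}(\boldsymbol{y}_1, \boldsymbol{y}^*) + \hat{r}(\boldsymbol{y}_2, \boldsymbol{y}^*)\): one always repeats its first answer (60% correct), the other starts weaker (50%) but fixes 40% of its mistakes. The value \(\alpha = 10\) and all probabilities are illustrative assumptions, not the paper’s settings.

```python
def bonus(r1: float, r2: float, alpha: float) -> float:
    """b(y2 | y1, y*) = alpha * (r(y2, y*) - r(y1, y*))."""
    return alpha * (r2 - r1)


def stage2_reward(r1: float, r2: float, alpha: float) -> float:
    """Stage II reward for one trajectory, KL penalty omitted: r1 + r2 + bonus."""
    return r1 + r2 + bonus(r1, r2, alpha)


def expected_reward(p_correct_t1: float, p_fix: float, p_break: float, alpha: float) -> float:
    """Expected Stage II reward under binary rewards (hypothetical toy model).

    p_fix:   probability an incorrect first attempt becomes correct on turn 2
    p_break: probability a correct first attempt becomes incorrect on turn 2
    """
    outcomes = [  # (probability, r1, r2)
        (p_correct_t1 * (1 - p_break), 1.0, 1.0),
        (p_correct_t1 * p_break, 1.0, 0.0),
        ((1 - p_correct_t1) * p_fix, 0.0, 1.0),
        ((1 - p_correct_t1) * (1 - p_fix), 0.0, 0.0),
    ]
    return sum(p * stage2_reward(r1, r2, alpha) for p, r1, r2 in outcomes)


ALPHA = 10.0  # illustrative value only
collapsed = expected_reward(p_correct_t1=0.6, p_fix=0.0, p_break=0.0, alpha=ALPHA)
correcting = expected_reward(p_correct_t1=0.5, p_fix=0.4, p_break=0.0, alpha=ALPHA)
print(round(collapsed, 6), round(correcting, 6))  # 1.2 3.2
```

With \(\alpha = 0\) both policies score 1.2, so the unshaped objective cannot tell them apart; the bonus is what makes genuine correction the preferred behavior.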

5. Results

Figure: SCoRe’s state-of-the-art results.

I think a missing data point for comparison is the base model trained with RL to maximize only the turn-1 or turn-2 reward. It feels a bit unfair to use the base model as a baseline, since it misses out on the additional juicy reasoning training that all the other approaches received.

Most Glaring Deficiency

The tension between maximizing correctness on turn 2 and rewarding improvement between turns felt a bit awkward, especially given the KL constraint tying the first turn to the base model. This may limit the applicability of the approach to larger-scale RL for improving reasoning, if we do want the model to learn and diverge from the limited capabilities of the base model.

It is also unclear how this approach extends to more than two turns. It also relies on an auxiliary instruction between turns, which makes the self-correction not “truly” intrinsic.

Conclusions for Future Work

Reward shaping can be used to guide the model toward behaviors we’d like it to fundamentally learn, rather than superficially imitating a format.