已经用了两篇文章讲解了PPO的源码解读:
最后我们在来看看模型训练过程的一些细节。
1.PPO训练过程
1.1. 核心源码
PPO训练过程:详见PPOtrainer源码的ppo_train()入口函数。核心代码块如下:
class PPOTrainer(ABC):
################
# 1.loss定义 (Actor模型两个loss, Critic模型一个loss)
################
self.actor_loss_fn = PolicyLoss(eps_clip)
self.critic_loss_fn = ValueLoss(value_clip)
self.ptx_loss_fn = GPTLMLoss()
def ppo_train(self, global_steps=0):
################
# 2. 加载经验数据(Experience)
################
dataloader = DataLoader(...)
for epoch in range(self.max_epochs):
for experience in pbar:
################
# 3. 执行一步训练
################
status = self.training_step(experience, global_steps)
def training_step(self, experience: Experience, global_steps) -> Dict[str, float]:
################
# 3.1. 训练Actor 模型,支持2种任务同时训练(SFT和PPO),对应loss GPTLMLoss, PolicyLoss
################
status = self.training_step_actor(experience)
################
# 3.2. 训练Critic 模型,通过Valueloss计算损失
################
status.update(self.training_step_critic(experience))
上述代码流程描述可知,PPO训练过程,在一个训练步骤中,Actor和Critic模型依次训练更新。在训练Actor模型时,代码实现中加入了一个可配置的SFT任务,所以Actor是可以同时多任务训练的。具体训练如下图所示。
1.2. 模型训练框图

图1、PPO训练框图
Actor 和Critic 模型结构详见:姜富春:OpenRLHF源码解读:1.理解PPO单机训练 第2部分:模型结构部分的网络图。
当前我们基本整理清楚了PPO的完整训练流程。接下来我们进一步看下3个loss函数,理解下模型计算损失的过程。
1.3. loss解读
1.3.1. GPTLMLoss
GPTLMLoss核心代码块,如下:
# 源码:https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/models/loss.py#L11C7-L11C16
class GPTLMLoss(nn.Module):
def __init__(self, ring_attn_group=None):
self.loss = nn.CrossEntropyLoss(ignore_index=self.IGNORE_INDEX)
def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

class PolicyLoss(nn.Module):
def forward(self,
log_probs: torch.Tensor,
old_log_probs: torch.Tensor,
advantages: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
#################
#1. 重要性采样 important-sampling
# 下面公式:(log(p) - log(p')).exp() = log(p/p').exp() = p/p'
# 转换下就两个概率的比,表示重要性采样,保证PPO算法是个off-policy算法,提升训练效率
#################
ratio = (log_probs - old_log_probs).exp()
#################
# 2. clip-PPO 算法,详见下方公式
#################
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
loss = -torch.min(surr1, surr2)
loss = masked_mean(loss, action_mask, dim=-1).mean()
return loss

class ValueLoss(nn.Module):
def forward(
self,
values: torch.Tensor,
old_values: torch.Tensor,
returns: torch.Tensor) -> torch.Tensor:
##############
# 嚯,计算regrenssion loss(MSE)
##############
loss = (values - returns) ** 2
ValueLoss计算就是对比状态预估价值(values)和实际计算的经验价值(returns)的相近程度,典型的回归问题。用MSE(Mean Squared Loss)计算损失。
2.总结
本文对PPO采样后的train过程的源码和Loss函数做了详细的讲解。
至此,通过三篇文档已经描述了PPO单机训练的完整过程。其他两篇详见:
水平有限,欢迎指正~
—文章来源 知乎 链接:https://zhuanlan.zhihu.com/p/14813158239
Comments NOTHING