LLaMA-Factory implements three main reinforcement-learning-style training methods: PPO, DPO, and KTO. Each of them takes a different strategy for fine-tuning large language models.
PPO (Proximal Policy Optimization) is a policy-gradient reinforcement learning algorithm and the most traditional RLHF (Reinforcement Learning from Human Feedback) approach.
The PPO objective function:

$$
L_{\mathrm{PPO}}(\theta) = \mathbb{E}_{\pi_{\mathrm{old}}}\left[\min\left(r(\theta)\,A,\ \operatorname{clip}\!\left(r(\theta),\, 1-\epsilon,\, 1+\epsilon\right)A\right)\right]
$$

where:

- $r(\theta) = \pi_\theta(a \mid s) / \pi_{\mathrm{old}}(a \mid s)$ is the probability ratio between the current and old policies
- $A$ is the advantage estimate
- $\epsilon$ is the clipping range
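As a quick illustration of the clipped objective (a minimal PyTorch sketch, not code from LLaMA-Factory; the tensor names are made up):

```python
import torch

def clipped_ppo_objective(new_logprobs: torch.Tensor,
                          old_logprobs: torch.Tensor,
                          advantages: torch.Tensor,
                          clip_eps: float = 0.2) -> torch.Tensor:
    """Minimal sketch of the clipped PPO surrogate (to be maximized)."""
    ratio = torch.exp(new_logprobs - old_logprobs)                  # r(θ)
    unclipped = ratio * advantages                                  # r(θ) * A
    clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
    return torch.min(unclipped, clipped).mean()                     # E[min(...)]

# Toy usage: three tokens with slightly shifted log-probs and advantages.
new_lp = torch.tensor([-1.0, -0.8, -1.2])
old_lp = torch.tensor([-1.1, -0.9, -1.0])
adv = torch.tensor([0.5, -0.2, 1.0])
print(clipped_ppo_objective(new_lp, old_lp, adv))
```

The actual PPO training loop in LLaMA-Factory (`src/llamafactory/train/ppo/trainer.py`) looks like this: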
```python
# src/llamafactory/train/ppo/trainer.py
def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
    # ...
    for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
        # fetch a batch of data
        batch = next(dataiter)
        
        # get inputs: generate responses and score them with the reward model
        self.model.eval()
        queries, responses, rewards = [], [], []
        for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
            mini_batch = {
                "input_ids": batch["input_ids"][idx : idx + self.config.mini_batch_size],
                "attention_mask": batch["attention_mask"][idx : idx + self.config.mini_batch_size],
            }
            mini_batch_queries, mini_batch_responses = self.get_inputs(mini_batch)
            mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses)
            queries.extend(mini_batch_queries)
            responses.extend(mini_batch_responses)
            rewards.extend(mini_batch_rewards)
        # run a PPO optimization step
        self.model.train()
        stats = self.step(queries, responses, rewards)
    # ...
```
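Within this loop, `get_inputs` generates responses for the queries and `get_rewards` scores each query/response pair with the reward model. As a rough, generic illustration of the scoring idea only (not LLaMA-Factory's implementation; the model name and helper below are placeholders):

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Hypothetical standalone reward scorer; "my-reward-model" is a placeholder.
tokenizer = AutoTokenizer.from_pretrained("my-reward-model")
reward_model = AutoModelForSequenceClassification.from_pretrained("my-reward-model", num_labels=1)

@torch.no_grad()
def score_pairs(queries: list[str], responses: list[str]) -> list[float]:
    """Score each (query, response) pair with a sequence-classification reward model."""
    inputs = tokenizer([q + r for q, r in zip(queries, responses)],
                       return_tensors="pt", padding=True, truncation=True)
    logits = reward_model(**inputs).logits  # shape: (batch, 1)
    return logits.squeeze(-1).tolist()
```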
DPO (Direct Preference Optimization) learns directly from human preference data, avoiding an explicit reward model.
The DPO loss function:

$$
L_{\mathrm{DPO}}(\theta) = -\,\mathbb{E}_{(x,\, y_w,\, y_l)}\left[\log \sigma\!\left(\beta\left(\log\frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \log\frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\right)\right)\right]
$$

where:

- $y_w$ is the chosen (preferred) response and $y_l$ the rejected response
- $\pi_{\mathrm{ref}}$ is a frozen reference model
- $\beta$ is a temperature-like scaling coefficient and $\sigma$ the sigmoid function
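As a minimal, standalone sketch of the standard sigmoid DPO loss (illustrative names, not the project's exact `dpo_loss` implementation):

```python
import torch
import torch.nn.functional as F

def dpo_sigmoid_loss(policy_chosen_logps: torch.Tensor,
                     policy_rejected_logps: torch.Tensor,
                     ref_chosen_logps: torch.Tensor,
                     ref_rejected_logps: torch.Tensor,
                     beta: float = 0.1):
    """Minimal sketch of the standard (sigmoid) DPO loss."""
    pi_logratios = policy_chosen_logps - policy_rejected_logps   # log π_θ(y_w|x) - log π_θ(y_l|x)
    ref_logratios = ref_chosen_logps - ref_rejected_logps        # log π_ref(y_w|x) - log π_ref(y_l|x)
    logits = pi_logratios - ref_logratios
    losses = -F.logsigmoid(beta * logits)
    # Implicit rewards implied by the DPO objective.
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps).detach()
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps).detach()
    return losses, chosen_rewards, rejected_rewards
```

LLaMA-Factory's `compute_preference_loss` wraps this standard case together with reference-model-free variants: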
```python
# src/llamafactory/train/dpo/trainer.py
def compute_preference_loss(
    self,
    policy_chosen_logps: "torch.Tensor",
    policy_rejected_logps: "torch.Tensor",
    reference_chosen_logps: Optional["torch.Tensor"],
    reference_rejected_logps: Optional["torch.Tensor"],
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
    """Compute loss for preference learning."""
    if not self.finetuning_args.use_ref_model:
        if self.loss_type == "orpo":
            losses = self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)
        elif self.loss_type == "simpo":
            losses = self.simpo_loss(policy_chosen_logps, policy_rejected_logps)
        else:
            raise NotImplementedError(f"Unknown loss type: {self.loss_type}.")
        
        chosen_rewards = self.beta * policy_chosen_logps.to(self.accelerator.device).detach()
        rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach()
    else:
        losses, chosen_rewards, rejected_rewards = self.dpo_loss(
            policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
        )
    
    return losses, chosen_rewards, rejected_rewards
```
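The returned `chosen_rewards` / `rejected_rewards` are implicit per-example rewards scaled by β. A common way to monitor them during training (a generic sketch; the metric names here are illustrative, not LLaMA-Factory identifiers):

```python
import torch

def preference_metrics(chosen_rewards: torch.Tensor, rejected_rewards: torch.Tensor):
    """Toy sketch: summarize implicit rewards as a margin and a ranking accuracy."""
    margin = (chosen_rewards - rejected_rewards).mean()              # average reward gap
    accuracy = (chosen_rewards > rejected_rewards).float().mean()    # fraction of correctly ordered pairs
    return margin, accuracy
```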
The project also implements several DPO variants:
```python
def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
    """Compute ORPO's odds ratio (OR) loss for batched log probabilities of the policy model."""
    log_odds = (chosen_logps - rejected_logps) - (
        torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
    )
    sft_loss = -chosen_logps
    odds_ratio_loss = -F.logsigmoid(log_odds)
    orpo_loss = sft_loss + self.beta * odds_ratio_loss
    return orpo_loss
```
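ORPO adds the odds-ratio penalty on top of an SFT term (`-chosen_logps`), which is why it needs no reference model. The `log1p(-exp(·))` terms compute `log(1 - p)` from a log-probability in a numerically stable way, so `log_odds` is exactly the difference of log-odds `log(p/(1-p))` between the chosen and rejected responses. A purely illustrative check:

```python
import torch

p_chosen, p_rejected = torch.tensor(0.7), torch.tensor(0.4)
chosen_logps, rejected_logps = p_chosen.log(), p_rejected.log()

# Direct computation of the difference of log-odds.
direct = torch.log(p_chosen / (1 - p_chosen)) - torch.log(p_rejected / (1 - p_rejected))

# Same quantity computed the way odds_ratio_loss does it, from log-probabilities.
stable = (chosen_logps - rejected_logps) - (
    torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
)
print(torch.allclose(direct, stable))  # True
```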
```python
def simpo_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
    """Compute SimPO loss for batched log probabilities of the policy model."""
    pi_logratios = chosen_logps - rejected_logps
    gamma_logratios = self.simpo_gamma / self.beta
    logits = pi_logratios - gamma_logratios
    simpo_loss = -F.logsigmoid(self.beta * logits)
    return simpo_loss
```

Here `simpo_gamma` plays the role of SimPO's target reward margin γ: dividing it by β before the comparison and multiplying by β afterwards recovers `-logsigmoid(β · logratios - γ)`, again without any reference model.
KTO (Kahneman-Tversky Optimization) is an alignment method built on a Kahneman-Tversky, prospect-theoretic model of human utility. Unlike DPO it does not require paired preferences; it only needs binary labels marking each response as desirable or undesirable.
KTO training is mainly controlled by the following parameters:

- `kto_chosen_weight`: weight factor applied to desirable (chosen) responses
- `kto_rejected_weight`: weight factor applied to undesirable (rejected) responses

```python
# src/llamafactory/train/kto/trainer.py
def get_batch_loss_metrics(
    self,
    model: "PreTrainedModel",
    batch: dict[str, "torch.Tensor"],
) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]:
    # ...
    (
        policy_chosen_logps,
        policy_rejected_logps,
        policy_chosen_logits,
        policy_rejected_logits,
        policy_kl_logps,
        policy_chosen_logps_avg,
    ) = self.concatenated_forward(model, batch)
    reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs(
        model, batch
    )
    losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
        policy_chosen_logps,
        policy_rejected_logps,
        policy_kl_logps,
        reference_chosen_logps,
        reference_rejected_logps,
        reference_kl_logps,
    )
    losses = losses.nanmean()
    # ...
```
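For reference, the KTO loss itself weights desirable and undesirable examples separately around a reference-KL baseline. The following is a condensed sketch following the KTO paper's (TRL-style) formulation, not LLaMA-Factory's exact `kto_loss` signature:

```python
import torch

def kto_loss_sketch(policy_chosen_logps, policy_rejected_logps, policy_kl_logps,
                    ref_chosen_logps, ref_rejected_logps, ref_kl_logps,
                    beta: float = 0.1, chosen_weight: float = 1.0, rejected_weight: float = 1.0):
    """Condensed sketch of the KTO loss (assumes the TRL-style formulation)."""
    # Reference point: clamped KL estimate between policy and reference model.
    kl = (policy_kl_logps - ref_kl_logps).mean().clamp(min=0).detach()

    chosen_logratios = policy_chosen_logps - ref_chosen_logps
    rejected_logratios = policy_rejected_logps - ref_rejected_logps

    # Desirable examples are pushed above the KL baseline, undesirable ones below it.
    chosen_losses = 1 - torch.sigmoid(beta * (chosen_logratios - kl))
    rejected_losses = 1 - torch.sigmoid(beta * (kl - rejected_logratios))

    losses = torch.cat((chosen_weight * chosen_losses, rejected_weight * rejected_losses), dim=0)
    chosen_rewards = beta * chosen_logratios.detach()
    rejected_rewards = beta * rejected_logratios.detach()
    return losses, chosen_rewards, rejected_rewards, kl
```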
| Method | Characteristics | Strengths | Limitations |
|---|---|---|---|
| PPO | Standard RLHF with an explicit reward model | Mature and stable, solid theoretical grounding | Complex training pipeline, requires an extra reward model |
| DPO | Learns directly from preference pairs, no explicit reward model | Simple to implement, efficient to train | Depends on high-quality paired preference data |
| KTO | Kahneman-Tversky prospect-theory objective over unpaired desirable/undesirable labels | More stable training; does not require paired preference data | Hyperparameter tuning (e.g. the chosen/rejected weights) is relatively involved |
All of these methods are fully implemented in LLaMA-Factory, so an appropriate training method can be chosen for different scenarios and requirements.

