直接偏好优化(Direct Preference Optimization, DPO)是一种用于语言模型对齐的算法,由Rafailov等人在2023年提出,作为强化学习人类反馈(RLHF)的替代方案。DPO的目标与RLHF相同:使语言模型的输出更好地符合人类偏好,但DPO通过简化流程,直接从人类偏好数据中优化模型,无需单独的奖励模型和复杂的强化学习过程。
为什么需要DPO?
核心区别:
适用场景:
DPO的理论基础来自于Bradley-Terry模型,该模型用于描述对两个选项的偏好概率。
当我们有两个回答y₁和y₂时,人类偏好y₁而非y₂的概率可以建模为:
p*(y₁ ≻ y₂ | x) = σ(r*(x, y₁) - r*(x, y₂))
其中:
DPO的关键洞见是,我们可以将奖励函数r表示为最优策略π和参考策略πref的函数:
r*(x, y) = β log(π*(y|x)/πref(y|x)) + β log Z(x)
其中:
将这个公式代入Bradley-Terry模型并化简,我们得到:
p*(y₁ ≻ y₂ | x) = σ(β log(π*(y₁|x)/πref(y₁|x)) - β log(π*(y₂|x)/πref(y₂|x)))
这个公式中的Z(x)项被消去了,使计算变得可行。
基于上述模型,DPO的损失函数为:
L_DPO(πθ; πref) = -E_(yw,yl,x)~D[log(σ(β(log(πθ(yw|x)/πθ(yl|x)) - log(πref(yw|x)/πref(yl|x)))))]
其中:
补充解释:
以下是DPO损失函数的PyTorch实现:
python展开代码import torch
import torch.nn.functional as F
def dpo_loss(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    reference_chosen_logps: torch.Tensor,
    reference_rejected_logps: torch.Tensor,
    beta: float = 0.1,
) -> torch.Tensor:
    """
    计算DPO损失
    
    参数:
        policy_chosen_logps: 策略模型对偏好回答的对数概率
        policy_rejected_logps: 策略模型对非偏好回答的对数概率
        reference_chosen_logps: 参考模型对偏好回答的对数概率
        reference_rejected_logps: 参考模型对非偏好回答的对数概率
        beta: 正则化参数,控制KL散度的强度
        
    返回:
        dpo_loss: DPO损失值
    """
    # 计算策略模型和参考模型之间的对数概率比率
    policy_chosen_logps_ratio = policy_chosen_logps - reference_chosen_logps
    policy_rejected_logps_ratio = policy_rejected_logps - reference_rejected_logps
    
    # 计算DPO损失
    logits = beta * (policy_chosen_logps_ratio - policy_rejected_logps_ratio)
    losses = -F.logsigmoid(logits)
    
    return losses.mean()
# 在训练过程中的使用
def train_step(
    policy_model,
    reference_model,
    batch,
    optimizer,
    beta=0.1
):
    # 冻结参考模型
    for param in reference_model.parameters():
        param.requires_grad = False
    
    # 获取输入和输出
    prompts = batch["prompts"]
    chosen_responses = batch["chosen_responses"]
    rejected_responses = batch["rejected_responses"]
    
    # 计算策略模型的对数概率
    policy_chosen_logps = compute_logprobs(policy_model, prompts, chosen_responses)
    policy_rejected_logps = compute_logprobs(policy_model, prompts, rejected_responses)
    
    # 计算参考模型的对数概率
    with torch.no_grad():
        reference_chosen_logps = compute_logprobs(reference_model, prompts, chosen_responses)
        reference_rejected_logps = compute_logprobs(reference_model, prompts, rejected_responses)
    
    # 计算DPO损失
    loss = dpo_loss(
        policy_chosen_logps,
        policy_rejected_logps,
        reference_chosen_logps,
        reference_rejected_logps,
        beta
    )
    
    # 反向传播和优化
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    return loss.item()
def compute_logprobs(model, prompts, responses):
    """计算给定提示和回复的对数概率"""
    # 这里的具体实现取决于模型架构
    # 对于自回归语言模型,通常需要计算每个token的对数概率并求和
    # ...
    
    return logprobs
代码解读:
compute_logprobs函数的具体实现需要根据模型架构调整,例如对于GPT类模型,可以通过softmax和交叉熵计算每个token的对数概率。DPO算法已被用于多个开源语言模型的训练中,包括:
实际案例:
DPO作为RLHF的替代方案,通过直接从人类偏好数据中学习,避免了奖励建模和强化学习的复杂性。它的主要优势包括:
未来展望:
DPO的出现标志着语言模型对齐技术的重要发展,为更高效地训练符合人类偏好的语言模型提供了新的方向。


本文作者:Dong
本文链接:
版权声明:本博客所有文章除特别声明外,均采用 CC BY-NC。本作品采用《知识共享署名-非商业性使用 4.0 国际许可协议》进行许可。您可以在非商业用途下自由转载和修改,但必须注明出处并提供原作者链接。 许可协议。转载请注明出处!