展开代码实际batch_size = per_device_train_batch_size × gradient_accumulation_steps × 设备数量 您的设置: - per_device_train_batch_size = 12 - gradient_accumulation_steps = 2 - 假设8张卡:实际batch_size = 12 × 2 × 8 = 192
python展开代码# gradient_accumulation_steps = 1 (默认)
for batch in dataloader:
    loss = model(batch)
    loss.backward()          # 每次都通信
    optimizer.step()         # 每次都同步梯度
    
# gradient_accumulation_steps = 4
for i in range(4):
    loss = model(batch[i])
    loss.backward()          # 只累积,不通信
optimizer.step()             # 4次累积后才通信一次
gradient_accumulation_steps 个批次才同步一次

本文作者:Dong
本文链接:
版权声明:本博客所有文章除特别声明外,均采用 CC BY-NC。本作品采用《知识共享署名-非商业性使用 4.0 国际许可协议》进行许可。您可以在非商业用途下自由转载和修改,但必须注明出处并提供原作者链接。 许可协议。转载请注明出处!