Menu

  • Home
  • Work
    • AI
    • Cloud
      • Virtualization
      • IaaS
      • PaaS
    • Architecture
    • BigData
    • Python
    • Java
    • Go
    • C
    • C++
    • JavaScript
    • PHP
    • Others
      • Assembly
      • Ruby
      • Perl
      • Lua
      • Rust
      • XML
      • Network
      • IoT
      • GIS
      • Algorithm
      • Math
      • RE
      • Graphic
    • OS
      • Linux
      • Windows
      • Mac OS X
    • Database
      • MySQL
      • Oracle
    • Mobile
      • Android
      • IOS
    • Web
      • HTML
      • CSS
  • Life
    • Cooking
    • Travel
    • Gardening
  • Gallery
  • Video
  • Music
  • Essay
  • Home
  • Work
    • AI
    • Cloud
      • Virtualization
      • IaaS
      • PaaS
    • Architecture
    • BigData
    • Python
    • Java
    • Go
    • C
    • C++
    • JavaScript
    • PHP
    • Others
      • Assembly
      • Ruby
      • Perl
      • Lua
      • Rust
      • XML
      • Network
      • IoT
      • GIS
      • Algorithm
      • Math
      • RE
      • Graphic
    • OS
      • Linux
      • Windows
      • Mac OS X
    • Database
      • MySQL
      • Oracle
    • Mobile
      • Android
      • IOS
    • Web
      • HTML
      • CSS
  • Life
    • Cooking
    • Travel
    • Gardening
  • Gallery
  • Video
  • Music
  • Essay

人工智能知识 - 编程(二)

18
Apr
2026

人工智能知识 - 编程(二)

By Alex
/ in AI
0 Comments

这一篇承接人工智能知识 - 编程(一)。前一篇已经梳理 AI 训练与推理编程的横向工程栈;本篇进入重点框架详解与代码精读,集中处理 PyTorch、Transformers、PEFT、语言模型强化学习、OpenRLHF、verl、DeepSpeed、vLLM,以及典型开源代码的逐行解读。

PyTorch 详解

PyTorch 是训练与推理编程栈的“最底层可控面”:Tensor、device、autograd、 nn.Module、数据加载、序列化、编译与分布式训练都在这一层完成。上层框架可以隐藏细节,但当你需要排查显存、吞吐、梯度同步、checkpoint 恢复、算子不确定性时,最终必须回到 PyTorch 的对象模型与 API 语义。

安装矩阵与快速验证

PyTorch 的安装需要同时匹配三件事:Python 版本、操作系统、计算平台(CPU/CUDA/ROCm/MPS)。实际工程里最稳妥的策略是:用官方安装页的选择器生成命令,然后将该命令固化到你的环境脚本或镜像构建中。

目标平台 常用安装方式(示例) 工程备注
CPU(Linux/macOS/Windows)
Shell
1
pip install -U torch torchvision
CPU-only 适合开发与单测;性能调优与显存问题需要在目标 GPU 上复现。
CUDA(NVIDIA GPU)
Shell
1
2
# 以官方安装页生成的命令为准;典型形态如下
pip install -U torch torchvision --index-url https://download.pytorch.org/whl/cu126
CUDA wheel 与机器驱动/运行时要匹配;多机训练应当在镜像层固定 CUDA 与 PyTorch 组合。
ROCm(AMD GPU)
Shell
1
2
# 以官方安装页生成的命令为准(ROCm 版本需匹配系统栈)
pip install -U torch torchvision --index-url https://download.pytorch.org/whl/rocm
ROCm 生态对内核/驱动版本更敏感,建议使用官方/社区维护的容器基镜像。
MPS(Apple Silicon)
Shell
1
pip install -U torch torchvision
设备为 mps;算子覆盖度与性能特征与 CUDA 不同。

最小验证覆盖三个断言:版本可读、Tensor 可算、目标加速器可见。

verify_torch.py
Python
1
2
3
4
5
6
7
8
9
10
11
12
import torch
 
# 先确认当前导入到的 PyTorch 版本是否符合预期。
print("torch:", torch.__version__)
# 加速器可见性决定后续模型应落到 CUDA、MPS 还是 CPU。
print("cuda_available:", torch.cuda.is_available())
print("mps_available:", hasattr(torch.backends, "mps") and torch.backends.mps.is_available())
 
# 最后做一次最小矩阵运算,确认基础张量路径可以正常执行。
x = torch.randn(2, 3)
y = x @ x.T
print("ok:", y.shape)
Tensor、dtype 与 device

Tensor 的关键元信息是:形状(shape)、数值类型(dtype)、设备(device)、以及是否参与梯度(requires_grad)。训练代码里常见 bug 本质都是“不匹配”:输入和参数不在同一 device、label dtype 错、view/reshape 造成非 contiguous 导致算子退化,或无意间把需要梯度的张量带入无梯度区间。

创建、迁移与布局

命令/API/函数
torch.tensor

说明
从 Python 对象创建张量(会拷贝)

示例

Python
1
x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)

命令/API/函数
torch.as_tensor

说明
尽量不拷贝地包装已有数据

示例

Python
1
2
3
import numpy as np
arr = np.zeros((2, 3), dtype=np.float32)
x = torch.as_tensor(arr)  # 可能与 numpy 共享内存

命令/API/函数
torch.from_numpy

说明
从 numpy 创建(共享内存)

示例

Python
1
x = torch.from_numpy(arr)  # 修改一侧会影响另一侧

命令/API/函数
Tensor.to

说明
迁移 device/dtype(训练最常用)

示例

Python
1
2
device = torch.device("cuda", 0)  # or "cpu"/"mps"
x = x.to(device=device, dtype=torch.bfloat16)

命令/API/函数
Tensor.permute / Tensor.transpose

说明
重排维度顺序;图像常用于 NCHW/NHWC 互换,序列常用于 batch-first/seq-first 互换。

示例

Python
1
x = x.permute(0, 3, 1, 2)  # NHWC -> NCHW

命令/API/函数
Tensor.view / Tensor.reshape / Tensor.flatten

说明
改张量形状; view 更强调复用现有内存布局, reshape 则在必要时自动 materialize。

示例

Python
1
y = x.reshape(x.size(0), -1)

命令/API/函数
Tensor.unsqueeze / Tensor.squeeze

说明
插入或删除长度为 1 的维度,常见于 batch 维、head 维和 channel 维补齐。

示例

Python
1
2
x = x.unsqueeze(0)
x = x.squeeze(0)

命令/API/函数
Tensor.expand / Tensor.repeat

说明
expand 走广播语义,尽量不复制; repeat 会真实复制数据。

示例

Python
1
mask = mask.unsqueeze(0).expand(batch_size, -1)

命令/API/函数
Tensor.contiguous

说明
把非连续内存布局变成连续

示例

Python
1
2
x = x.permute(0, 2, 1)     # 可能变成非 contiguous
x = x.contiguous()        # 需要时显式转回

命令/API/函数
Tensor.is_contiguous / channels_last

说明
检测当前内存布局,或显式切到 channels_last memory format;视觉模型优化时很常见。

示例

Python
1
2
x = x.contiguous(memory_format=torch.channels_last)
print(x.is_contiguous(memory_format=torch.channels_last))
布局缩写、memory format 与 NumPy 互操作

PyTorch 里最容易被误解的是“轴顺序”和“memory format”并非同一个概念。 NCHW、 NHWC 描述语义上的维度顺序; contiguous、 channels_last 描述底层 stride 是否符合某种内存访问模式。

视觉模型里常见的缩写与 NumPy 一致:

  • CHW / HWC:单张图像。
  • NCHW / NHWC:批量图像。
  • channels_first / channels_last:通道维放前面还是后面。

当 NumPy 数组通过 torch.from_numpy 或 torch.as_tensor 进入 PyTorch 时,除了 shape/dtype,还要警惕底层 stride。最典型的坑是:

  • NumPy 的 np.flip、 [::-1] 之类操作,可能产生负 stride 视图。
  • 这类数组在某些 PyTorch 路径里不能直接接收,或会触发额外 materialize。
  • 跨框架前,通常用 np.ascontiguousarray 或显式 copy 把布局整理干净。
Python
1
2
3
arr = np.flip(arr, axis=1)            # 可能变成负 stride 视图
arr = np.ascontiguousarray(arr)       # 跨到 PyTorch/ORT 之前先整理成稳定布局
x = torch.from_numpy(arr)
多设备下的 device 选择

单机多卡训练通常按“每进程绑定一张卡”的方式组织。绑定的核心动作是:在进程启动后立即 torch.cuda.set_device(local_rank),并确保模型与 batch 都迁移到 cuda:local_rank。

Python
1
2
3
4
5
6
7
8
import os
import torch
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
# 当前进程只绑定一张卡;后续模型和 batch 都必须迁到同一个 device。
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
model = model.to(device)
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
autograd:梯度模式与反向图

PyTorch 的 autograd 是胶带式自动求导:前向执行时记录算子与中间结果,反向时从标量 loss 回传梯度到叶子张量(leaf tensors)。工程上需要明确三类“梯度模式”:训练(需要梯度)、评估(无梯度)、推理(更强的 inference mode)。

训练:backward 与清梯度

一个稳定的训练 step 通常遵循固定模板:清梯度、前向、算 loss、反向、(可选)裁剪、优化器 step。清梯度推荐使用 set_to_none=True,这会让 PyTorch 用 None 表示“没有梯度”,减少写零开销。

Python
1
2
3
4
5
optimizer.zero_grad(set_to_none=True)
loss = model(**batch).loss
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
评估与推理:no_grad 与 inference_mode

model.eval() 只切换模块行为(例如 dropout/batchnorm),不影响 autograd。关闭梯度需要显式进入无梯度上下文:

  • torch.no_grad():关闭反向图记录,适用于评估。
  • torch.inference_mode():更激进的推理模式,额外禁用若干 autograd 相关开销;它不会自动调用 model.eval()。
Python
1
2
3
model.eval()
with torch.inference_mode():
    logits = model(x)
autograd.grad:只要梯度、不做参数更新

torch.autograd.grad 在实现自定义优化、梯度惩罚、或需要显式控制梯度张量生命周期时更直接。

Python
1
2
3
4
5
6
import torch
x = torch.randn(4, requires_grad=True)
y = (x ** 2).sum()
# autograd.grad 直接返回梯度张量,不会像 backward 那样把结果顺手写进 x.grad。
gx, = torch.autograd.grad(y, x, create_graph=False)
print(gx)
nn.Module、参数注册与 state_dict

nn.Module 提供两类关键能力:组织子模块并注册参数/缓冲区;提供可序列化的 state_dict,用于保存/恢复训练状态和做 warmstart。

参数(Parameter)与缓冲区(Buffer)

可训练权重应当是 nn.Parameter 或由标准层(Linear/Conv/Embedding 等)创建;非训练但需要随模型保存的状态(例如 batchnorm 的 running_mean)应注册为 buffer。

Python
1
2
3
4
5
6
7
8
9
10
11
import torch
import torch.nn as nn
 
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)
        # buffer 会进入 state_dict,但不会被优化器更新。
        self.register_buffer("scale", torch.tensor(1.0), persistent=True)
    def forward(self, x):
        return self.proj(x) * self.scale

buffer 是否进入 state_dict 由 persistent 决定:非持久 buffer 不会被保存,这常用于缓存中间结果或仅运行期有效的状态。

state_dict:保存/加载的工程契约

state_dict() 返回一个 Python dict,包含参数与持久化 buffer。加载时常用两种策略:

  • 严格恢复:结构完全一致,使用默认 strict=True。
  • warmstart:允许缺键/多键,使用 strict=False,并显式检查 missing/unexpected keys。
Python
1
2
3
4
state = torch.load("model.pt", map_location="cpu", weights_only=True)
missing, unexpected = model.load_state_dict(state, strict=False)
print("missing:", missing)
print("unexpected:", unexpected)

如果 checkpoint 来自 DDP/FSDP 包装后的模型,键名前缀经常会多出 module.。不要手工重写整个 dict;PyTorch 已提供了前缀清理工具。

Python
1
2
3
4
5
6
7
8
9
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
 
# 先在 CPU 上加载,避免设备不匹配把问题复杂化
state = torch.load("model.pt", map_location="cpu", weights_only=True)
consume_prefix_in_state_dict_if_present(
    state,
    prefix="module.",  # DDP 最常见的键名前缀;清掉后就能按“裸模型”的参数名恢复
)
missing, unexpected = model.load_state_dict(state, strict=False)
常用API

命令/API/函数
model.train()

说明
切到训练模式,启用 dropout、BatchNorm 更新等训练态行为。

示例

Python
1
model.train()

命令/API/函数
model.eval()

说明
切到评估模式,冻结 dropout/BatchNorm 的训练态分支;它不等价于关闭梯度。

示例

Python
1
model.eval()

命令/API/函数
model.parameters()

说明
返回优化器应更新的参数迭代器,多参数组通常从这里拆分 weight decay 或 learning rate。

示例

Python
1
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

命令/API/函数
model.buffers()

说明
遍历 buffer,适合检查 BatchNorm 统计量、EMA 阴影权重或运行期缓存状态。

示例

Python
1
2
for buf in model.buffers():
    print(buf.shape)

命令/API/函数
model.state_dict()

说明
导出参数与持久化 buffer 的字典,是保存 checkpoint 和部署权重包的标准契约。

示例

Python
1
state = model.state_dict()

命令/API/函数
model.load_state_dict(...)

说明
把 checkpoint 恢复到当前模块,warmstart 时应检查 missing/unexpected keys。

示例

Python
1
missing, unexpected = model.load_state_dict(state, strict=False)
数据加载:Dataset / DataLoader

训练吞吐的瓶颈经常不在 GPU,而在数据管线:解码、tokenize、增强、CPU 到 GPU 拷贝、以及 DataLoader 的多进程调度。DataLoader 的可调参数很多,但最关键的是:Dataset 类型(map-style/iterable-style)、worker 并发、pin memory、以及 batch 组装(collate)。

Map-style vs Iterable-style
map_style_dataset.py
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from torch.utils.data import Dataset
 
class MyDataset(Dataset):
    def __init__(self, items):
        # map-style dataset 的核心是“可随机索引”,适合本地文件、表格或已离线切好的样本集。
        self.items = items
 
    def __len__(self):
        # DataLoader 会用长度信息估算 epoch 步数、进度条和 DistributedSampler 切分范围。
        return len(self.items)
 
    def __getitem__(self, idx):
        # __getitem__ 只负责取单条样本,不在这里做 batch 逻辑。
        x, y = self.items[idx]
        return {"x": x, "y": y}

iterable_dataset.py
Python
1
2
3
4
5
6
7
8
from torch.utils.data import IterableDataset
 
class StreamDataset(IterableDataset):
    def __iter__(self):
        # iterable-style dataset 不要求随机索引,适合消息队列、对象存储分片或数据库游标。
        for i in range(1000000):
            # 每次 yield 一条样本,DataLoader 会继续负责多 worker 和 batch 拼接。
            yield {"x": i}
DataLoader 参数(工程常用)

DataLoader 的构造参数是你调吞吐的第一现场: num_workers 决定 CPU 并发、 pin_memory + non_blocking=True 影响 H2D 拷贝、 prefetch_factor / persistent_workers 影响 worker 生命周期与预取深度。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
from torch.utils.data import DataLoader
 
# 这些参数共同决定 CPU 侧吞吐、预取深度和 H2D 拷贝效率。
loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    prefetch_factor=2,
    persistent_workers=True,
    drop_last=True,
)

将 batch 移动到 GPU 时,pin memory 结合 non_blocking=True 才能发挥异步拷贝效果。

Python
1
2
def to_device(batch, device):
    return {k: v.to(device, non_blocking=True) for k, v in batch.items()}
Sampler / BatchSampler:把“取样顺序”和“组 batch 规则”拆开

shuffle=True 只覆盖了最简单的“随机读样本”场景。真实训练脚本经常需要显式控制样本顺序、顺序是否可复现、以及 batch 是否按某种结构聚合。此时更稳的做法是把“采样器(Sampler)”与“批采样器(BatchSampler)”分开表达。

命令/API/函数
RandomSampler

说明
按随机顺序产出单条样本索引,适合单机训练、可复现实验或需要自定义重采样逻辑的场景。它表达的是“索引顺序”,而非 batch 结构。

示例

Python
1
2
3
4
5
6
7
8
9
10
from torch.utils.data import DataLoader, RandomSampler
 
# 把“随机读索引”显式化,后续更容易替换成带权重或分布式 sampler
sampler = RandomSampler(dataset)
loader = DataLoader(
    dataset,
    batch_size=32,   # 这里仍由 DataLoader 负责每 32 个索引拼成一个 batch
    sampler=sampler, # 一旦显式传 sampler,就不要再同时写 shuffle=True,避免语义冲突
    num_workers=8,
)

命令/API/函数
SequentialSampler

说明
按数据集原始顺序读取样本,适合验证集、离线导出、对齐原始文件顺序的 debug、以及需要稳定回放某段数据的问题排查。

示例

Python
1
2
3
4
5
6
7
8
9
10
from torch.utils.data import DataLoader, SequentialSampler
 
# 验证或导出阶段通常更看重顺序稳定,而非随机打散
sampler = SequentialSampler(eval_dataset)
loader = DataLoader(
    eval_dataset,
    batch_size=64,
    sampler=sampler,
    num_workers=4,
)

命令/API/函数
BatchSampler

说明
先决定“这一批该由哪些索引组成”,再把整批索引交给 DataLoader。它适合长度分桶、按图像尺寸分组、或任何“batch 规则比单样本顺序更重要”的任务。

示例

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from torch.utils.data import BatchSampler, DataLoader, RandomSampler
 
base_sampler = RandomSampler(dataset)  # 单样本顺序先交给基础 sampler 决定
batch_sampler = BatchSampler(
    base_sampler,
    batch_size=32,    # 这里定义“每次吐出一组 32 个索引”
    drop_last=True,   # 训练时常丢掉尾部不满批,避免 BN/张量并行尺寸抖动
)
loader = DataLoader(
    dataset,
    # 使用 batch_sampler 后,不再传 batch_size/shuffle/sampler
    batch_sampler=batch_sampler,
    num_workers=8,
)

当样本长度或图像尺寸差异很大时,很多开源训练仓库会在 BatchSampler 上再包一层“分桶/分组”策略,让同一个 batch 内的样本更相似,从而减少 padding 浪费与动态 shape 带来的 kernel 抖动。

分布式训练的数据切分(DistributedSampler)

DDP 下每个进程应读取不同数据子集。map-style 数据集通常配合 DistributedSampler,并在每个 epoch 调用 set_epoch 让 shuffle 可复现。

Python
1
2
3
4
5
6
7
8
9
10
11
12
from torch.utils.data.distributed import DistributedSampler
 
# sampler 负责给每个 rank 分到互不重叠的数据子集。
sampler = DistributedSampler(dataset, shuffle=True, drop_last=True)
# DataLoader 不再自己 shuffle,改为交给 sampler 控制全局顺序。
loader = DataLoader(dataset, batch_size=32, sampler=sampler, num_workers=8, pin_memory=True)
 
for epoch in range(num_epochs):
    # 每个 epoch 更新随机种子,确保所有 rank 对同一轮 shuffle 的理解一致。
    sampler.set_epoch(epoch)
    for batch in loader:
        ...
AMP:混合精度的工程写法

混合精度训练通常用 torch.amp.autocast 与 torch.amp.GradScaler 组合。旧的 torch.cuda.amp.autocast / torch.cpu.amp.autocast 已逐步迁移到统一入口。

amp_step.py
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
 
# 显式绑定到 CUDA 设备;多卡时这里通常来自 LOCAL_RANK。
device = torch.device("cuda", 0)
# GradScaler 负责处理缩放、反缩放和溢出检测。
scaler = torch.amp.GradScaler("cuda")
 
# 模型和优化器在 AMP 外层初始化,避免每个 step 重复创建对象。
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
 
for batch in loader:
    # 先把 batch 移到目标设备,配合 pin_memory + non_blocking 提高拷贝效率。
    batch = to_device(batch, device)
 
    # set_to_none=True 可以减少写零开销,并让未参与反向的参数保留为 None。
    optimizer.zero_grad(set_to_none=True)
    with torch.amp.autocast("cuda", dtype=torch.bfloat16):
        # 前向和 loss 在 autocast 内执行,算子会按数值安全规则自动选精度。
        loss = model(**batch).loss
 
    # 先对缩放后的 loss 反向,再在 step 前完成 unscale 与裁剪。
    scaler.scale(loss).backward()
    # 反缩放后再做梯度裁剪,否则阈值会失真。
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    # step/update 组合会自动跳过溢出 step,并更新下一轮缩放因子。
    scaler.step(optimizer)
    scaler.update()
Checkpoint:保存、恢复与安全加载

训练脚本里 checkpoint 的工程目标是两件事:可恢复(resume 时学习率/AMP/随机性都对齐),以及可复用(用于推理或 warmstart)。推荐把 checkpoint 组织成一个 dict:模型 state、优化器 state、调度器 state、AMP scaler state、当前步数/epoch、以及必要的 RNG 状态。

通用 checkpoint 结构
checkpoint_io.py
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import os
import torch
 
def save_checkpoint(path, *, model, optimizer, scheduler=None, scaler=None, step=0, epoch=0):
    # 把“恢复训练所需的全部运行状态”一次性固化到一个 dict。
    ckpt = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict() if scheduler else None,
        "scaler": scaler.state_dict() if scaler else None,
        "step": int(step),
        "epoch": int(epoch),
        "rng_state": torch.get_rng_state(),
        "cuda_rng_state": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
    }
 
    # 先写临时文件,再原子替换正式文件,避免崩溃时留下半截 checkpoint。
    tmp = path + ".tmp"
    torch.save(ckpt, tmp)
    os.replace(tmp, path)  # 原子替换,避免写到一半崩溃留下坏文件
 
def load_checkpoint(path, *, model, optimizer, scheduler=None, scaler=None, map_location="cpu"):
    # weights_only=True 把反序列化收窄到更安全的状态字典类型集合。
    ckpt = torch.load(path, map_location=map_location, weights_only=True)
    # 先恢复模型和优化器,再恢复可选组件。
    model.load_state_dict(ckpt["model"], strict=True)
    optimizer.load_state_dict(ckpt["optimizer"])
    if scheduler and ckpt.get("scheduler"):
        scheduler.load_state_dict(ckpt["scheduler"])
    if scaler and ckpt.get("scaler"):
        scaler.load_state_dict(ckpt["scaler"])
 
    # 把步数和 epoch 返回给外层训练循环,继续从正确位置接着跑。
    step = int(ckpt.get("step", 0))
    epoch = int(ckpt.get("epoch", 0))
    return step, epoch
torch.load 的安全边界(weights_only)

torch.load 基于 pickle 反序列化,不能加载不可信来源的文件。加载权重/状态字典时优先使用 weights_only=True,把反序列化限定在 state_dict 等常见安全类型集合内。

大 checkpoint 在 CPU 内存紧张的环境里还常配合 mmap=True 使用。它的工程意义是尽量避免一次性把整个文件完整拷进用户态内存,从而降低加载峰值。

Python
1
2
3
4
5
6
state = torch.load(
    "model.pt",
    map_location="cpu", # 先在 CPU 侧安全落稳,再把真正需要的张量搬到目标设备
    weights_only=True,  # 把反序列化边界收窄到常见权重/状态类型
    mmap=True,          # 以内存映射方式读取大文件,常用于超大 checkpoint 的低峰值加载
)
Distributed Checkpointing(DCP)

当模型和优化器状态已经是分布式形态时,把所有 rank 的状态先 gather 成单文件再保存,I/O 与 CPU 峰值都会迅速变差。PyTorch 的 torch.distributed.checkpoint 提供了面向分布式训练的 checkpoint 读写接口:每个 rank 写自己那一份分片,恢复时再按当前并行拓扑装回去。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from torch.distributed.checkpoint import (
    DefaultLoadPlanner,
    FileSystemReader,
    FileSystemWriter,
    load,
    save,
)
 
state = {
    "model": model.state_dict(),         # 当前 rank 负责把自己持有的模型状态写进 DCP 目录
    # 优化器状态一起保存,断点续训时动量与学习率轨迹才能对齐
    "optimizer": optimizer.state_dict(),
}
 
save(
    state,
    # DCP 会写元数据和分片文件目录,而非单个 .pt 文件
    storage_writer=FileSystemWriter("ckpt_dcp"),
)
 
state = {
    # 恢复前先准备“接收容器”;load 会把 checkpoint 内容写回这些对象
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
}
load(
    state,
    storage_reader=FileSystemReader("ckpt_dcp"),
    planner=DefaultLoadPlanner(
        # 默认要求 checkpoint 与当前状态完整匹配;warmstart 才考虑放宽
        allow_partial_load=False,
    ),
)

如果 checkpoint 写盘本身会卡住训练步,DCP 还提供 async_save。工程上通常要配合“前一次异步保存未完成前不再发起下一次保存”的节流策略,避免后台 I/O 线程堆积。

统一 state_dict API:先抽象“该保存什么”

更复杂的 DDP/FSDP 脚本里,推荐先用统一 state_dict API 把“该保存什么状态”整理出来,再交给 DCP 写盘。这样做的好处是:checkpoint 语义先被建模清楚,再决定底层是单文件还是分布式目录。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    get_optimizer_state_dict,
    set_model_state_dict,
    set_optimizer_state_dict,
)
 
state = {
    # 不直接假设当前 model.state_dict() 的形态,改为用统一 API 取“可保存状态”。
    "model": get_model_state_dict(model),
    "optimizer": get_optimizer_state_dict(model, optimizer),
}
 
# ... 这里可以继续交给 DCP save / async_save ...
 
# 恢复时反向写回,避免调用方自己手拼不同并行策略下的状态结构。
set_model_state_dict(model, state["model"])
set_optimizer_state_dict(model, optimizer, state["optimizer"])

这层抽象尤其适合 FSDP、DTensor 或未来并行拓扑会变化的项目,因为“状态怎么表示”与“状态怎么写盘”被拆成了两步。

StateDictOptions:跨拓扑恢复前先定义状态形态

统一 state_dict API 解决了“该保存什么”,但跨拓扑恢复还要回答另一个问题:这些状态应当以什么形态被抽取出来。是完整的 full state,还是保持分片形态;是先搬到 CPU,还是直接留在设备上;是否由 rank0 先拿到完整权重,再广播给其它 rank。 StateDictOptions 就是用来定义这层恢复契约的。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    set_model_state_dict,
)
 
options = StateDictOptions(
    # 先抽成完整模型状态,适合跨并行拓扑 warmstart 或导出给别的系统消费。
    full_state_dict=True,
    # 先把状态落到 CPU,可降低恢复时的 GPU 峰值。
    cpu_offload=True,
    # 由 rank0 持有完整状态后再广播,常用于“单份 checkpoint 恢复到新拓扑”。
    broadcast_from_rank0=True,
)
 
state = get_model_state_dict(model, options=options)
# ... 可继续交给 DCP save / load,或做格式转换 ...
set_model_state_dict(model, state, options=options)

这类选项的意义不在于“多几个参数”,而在于把恢复语义写明白。只要训练产物可能在单卡评估、不同 GPU 数、不同 rank 布局之间流动,就应该先定义状态形态,再谈底层 checkpoint 文件怎么组织。

async_save:把写盘挪到训练步之外
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from torch.distributed.checkpoint import async_save
 
pending = None
 
if should_save and pending is None:
    pending = async_save(
        state,
        storage_writer=FileSystemWriter("ckpt_async"),
    )
 
# 下一次发起保存前,先确认前一次后台写盘已经结束。
if pending is not None and pending.done():
    pending.result()   # 主动抛出后台保存过程中的异常,避免默默失败
    pending = None

异步保存解决的是“step time 不想被 I/O 卡住”。它不会帮你解决磁盘空间、网络文件系统抖动或 checkpoint 保留策略,所以节流与清理机制仍然要自己设计。

torch.compile:编译加速与排错入口

torch.compile 会追踪(trace)你的 Python 代码中的张量计算并生成可优化的图。工程上它的常见收益来自两类:更少的 Python 开销、以及 Inductor 等后端生成的融合 kernel。无法追踪的代码会产生 graph break,这通常是性能损失,而非静默错误。

compile_minimal.py
Python
1
2
3
4
5
6
7
8
9
import torch
 
# 先把模型放到目标设备,再调用 compile,让捕获到的图直接面向目标后端。
model = model.to("cuda")
model = torch.compile(model)  # 最小改动:只包一次
 
# 编译和 AMP 可以叠加使用;compile 负责图级优化,autocast 负责精度路径。
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
    out = model(x)

当你需要确认 compile 到底 trace 了什么,可以打开日志来观察 traced graph(用于定位 graph break 与非预期的 Python 分支)。

Python
1
2
import torch
torch._logging.set_logs(graph_code=True)
compile + DDP 的包裹顺序

DDP 与 torch.compile 同时使用时,默认整模型路线更适合按 PyTorch 的 DDP note 来写:先包 DDP,再对 DDP 模型做 compile。这样 TorchDynamo 可以利用 DDP bucket 信息做 DDPOptimizer 相关优化,保留更好的通信-计算重叠机会。

Python
1
2
3
4
5
6
7
8
9
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
 
# 先把真实模型放到当前 rank 对应设备。
model = MyModel().to(device)
# 先包 DDP,让编译器能感知到梯度 bucket 与分布式外壳。
model = DDP(model, device_ids=[local_rank], output_device=local_rank)
# 再 compile 整个 DDP 模型,走官方 DDP note 更偏向的优化路径。
model = torch.compile(model)

如果你核心是只想对某个稳定子模块做区域化编译,那么“先 compile 子模块,再进入 DDP/FSDP 体系”也可能成立。关键不在背口诀,而在先明确你追求的是:默认整模型吞吐,还是对子模块做精细化图控制。

graph break:哪些代码会让 compile 失去收益

graph break 的含义是:编译器在某一段 Python 代码处无法继续追踪张量图,只能先把当前已捕获部分编译掉,回到普通 Python 执行,再从后面重新开始追踪。结果通常核心是“少编了一大段”,表现为速度没有明显提升甚至更慢。

最常见的 graph break 来源包括:

  • 前向里混入与张量无关但频繁执行的 Python 控制流,例如复杂字典操作、字符串拼接、调试打印。
  • 根据张量值做 Python 分支,而非继续保留在张量图里。
  • 每个 step 都改变形状或结构,导致已经编译过的图难以复用。
Python
1
2
3
4
5
6
def forward(self, x):
    x = self.proj(x)
    # 这类 Python 打印本身不一定报错,但会让热点路径更难形成稳定图。
    if self.debug:
        print(x.shape)
    return self.head(x)

排障顺序通常是:

  1. 先关闭 torch.compile,确认 eager 路径本身正确。
  2. 打开图日志或 profiler,确认 break 集中在什么位置。
  3. 把热点前向中的 Python 杂质移出,或把不稳定的小段单独保留为 eager。

对于结构很稳定但某几段特别复杂的模型,可以只编译热点子模块,而非整模型“一把包住”。这类做法本质上属于区域化编译(regional compilation):把最值得优化的几段先稳定下来,再决定是否继续扩大编译范围。

分布式训练:torchrun + DDP 最小可用形态

DDP 的基本形态是“每进程一份模型副本 + 反向时梯度同步”。启动建议使用 torchrun,它会为每个进程设置 RANK/ LOCAL_RANK/ WORLD_SIZE 等环境变量,并负责 rendezvous。

启动命令(单机多卡)
Shell
1
torchrun --standalone --nproc_per_node=8 train_ddp.py --config config.yaml
DDP 训练脚本骨架(可直接复用)
train_ddp.py
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
 
def ddp_setup():
    # 由 torchrun 提供 RANK / LOCAL_RANK / WORLD_SIZE,NCCL 负责 GPU 间通信。
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    # 每个进程只绑定一张卡,避免误把多个进程都放到 cuda:0。
    torch.cuda.set_device(local_rank)
    return local_rank
 
def is_rank0():
    # 只有 rank0 负责写 checkpoint / 打印主日志,避免多进程重复写文件。
    return int(os.environ.get("RANK", "0")) == 0
 
def main():
    # 完成分布式初始化,并构造当前进程对应的 device。
    local_rank = ddp_setup()
    device = torch.device("cuda", local_rank)
 
    # 每个进程各自持有一份模型副本;DDP 会在反向阶段同步梯度。
    model = MyModel(...).to(device)
    model = DDP(model, device_ids=[local_rank], output_device=local_rank, broadcast_buffers=True)
 
    # sampler 负责按 rank 切数据,否则多个进程会重复读到同一批样本。
    dataset = MyDataset(...)
    sampler = DistributedSampler(dataset, shuffle=True, drop_last=True)
    loader = DataLoader(dataset, batch_size=32, sampler=sampler, num_workers=8, pin_memory=True)
 
    # 优化器和 AMP scaler 都在 DDP 包装后初始化,确保参数引用一致。
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scaler = torch.amp.GradScaler("cuda")
 
    for epoch in range(10):
        # 每轮都更新 sampler 随机种子,确保所有 rank 的 shuffle 同步。
        sampler.set_epoch(epoch)
        model.train()
        for batch in loader:
            # batch 在进入前向前搬到本进程绑定的 GPU。
            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
 
            optimizer.zero_grad(set_to_none=True)
            with torch.amp.autocast("cuda", dtype=torch.bfloat16):
                loss = model(**batch).loss
 
            # backward 时 DDP 会自动做梯度 all-reduce。
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
 
        if is_rank0():
            # 保存时取出原始 module,避免把 DDP 包装层一并写进权重结构。
            torch.save(model.module.state_dict(), f"model-ep{epoch}.pt")
 
    # 训练结束后主动销毁进程组,释放 NCCL 资源。
    dist.destroy_process_group()
 
if __name__ == "__main__":
    main()
DDP 构造参数:真正高频的几个开关

DDP 很少只靠 DDP(model, device_ids=[...]) 就结束。真实工程里最常被反复调整的是下面几项,它们直接关系到“会不会多做通信”“是否能兼容动态分支”“显存里梯度长什么样”。

命令/API/函数
broadcast_buffers

说明
控制 rank0 的 buffer 是否在前向时广播到其它 rank。典型 buffer 包括 BatchNorm 的 running mean/var 这类不参与梯度更新、但会影响推理语义的状态。

示例

Python
1
2
3
4
5
6
7
model = DDP(
    model,
    device_ids=[local_rank],
    output_device=local_rank,
    # 带 BatchNorm 或其它运行时 buffer 的模型通常保留 True,避免各 rank 状态漂移。
    broadcast_buffers=True,
)

命令/API/函数
find_unused_parameters

说明
让 DDP 在反向图里查找“这一步没参与 loss 的参数”。只在动态图、MoE、条件分支这类场景下启用;结构固定的模型尽量保持关闭,避免额外遍历和同步开销。

示例

Python
1
2
3
4
5
6
model = DDP(
    model,
    device_ids=[local_rank],
    # 只有确实存在条件分支/按样本走不同子图时才打开。
    find_unused_parameters=True,
)

命令/API/函数
static_graph

说明
告诉 DDP 每一步参与反向的参数集合与图结构都稳定不变。对固定结构训练,这能减少内部图分析开销,也更适合长时间稳定运行的预训练/微调任务。

示例

Python
1
2
3
4
5
6
model = DDP(
    model,
    device_ids=[local_rank],
    # 只有在前向图和参数参与关系稳定时才启用。
    static_graph=True,
)

命令/API/函数
gradient_as_bucket_view

说明
让参数梯度直接视作通信 bucket 的视图,以减少梯度副本开销。它有助于省显存,但也要求脚本不要依赖“随手对 grad 做就地奇技淫巧”的旧习惯。

示例

Python
1
2
3
4
5
6
model = DDP(
    model,
    device_ids=[local_rank],
    # 想进一步压缩梯度内存时可尝试;改梯度的自定义逻辑要先核对兼容性。
    gradient_as_bucket_view=True,
)

命令/API/函数
bucket_cap_mb

说明
控制梯度 bucket 的大小。bucket 太小会导致 all-reduce 次数变多,太大又会推迟通信启动时机;它本质上是在平衡“通信碎片化”与“通信重叠启动时机”。

示例

Python
1
2
3
4
5
6
model = DDP(
    model,
    device_ids=[local_rank],
    # 大模型通信调优时常会显式试几个桶大小,而非完全沿用默认值。
    bucket_cap_mb=50,
)

命令/API/函数
register_comm_hook

说明
为 DDP 梯度通信注册自定义 hook,例如 fp16 梯度压缩、PowerSGD 或你自己的 bucket 处理逻辑。它属于高级优化位点,适合通信已经明显成为瓶颈时再动。

示例

Python
1
2
3
4
5
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook
 
model = DDP(model, device_ids=[local_rank])
# 把 bucket 梯度先做 fp16 压缩再通信,牺牲部分数值冗余换带宽。
model.register_comm_hook(state=None, hook=fp16_compress_hook)
新版 DDP 语义补充:init_sync、forward_sync_buffers、skip_all_reduce_unused_params

旧经验里常把 DDP 的关键参数概括成 broadcast_buffers、 find_unused_parameters 和 static_graph。但在新版实现里,还有三项更贴近工程语义的开关:初始化是否先同步一次完整状态、前向时是否同步 buffer、以及 unused 参数是否直接跳过 all-reduce。

Python
1
2
3
4
5
6
7
8
9
10
11
model = DDP(
    model,
    device_ids=[local_rank],
    output_device=local_rank,
    # 启动时先做一次参数/缓冲同步,保证所有 rank 从同一份初始权重起跑。
    init_sync=True,
    # 前向阶段同步运行时 buffer;对带 BatchNorm 或其它状态型 buffer 的模型更稳。
    forward_sync_buffers=True,
    # 只有当所有 rank 的 unused 参数集合恒定一致时,这个优化才安全。
    skip_all_reduce_unused_params=False,
)

这三项分别解决三类问题。 init_sync 解决“各 rank 初始状态是否真的一致”; forward_sync_buffers 解决“运行时 buffer 会不会逐步漂移”; skip_all_reduce_unused_params 解决“未参与本轮反向的参数还要不要同步”。最后这一项要格外保守,因为只要不同 rank 的 unused 参数集合不一致,就有卡死风险。

no_sync:梯度累积时避免白做 all-reduce

梯度累积下,如果仍然每个 micro-step 都让 DDP 正常反向,同步就会发生在每一次 backward() 上,前 \(N-1\) 个 micro-step 的 all-reduce 都是白做。 model.no_sync() 的作用,就是把这些中间步的通信推迟到最后一次真正需要更新前再发生。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
accum_steps = 8
 
for step, batch in enumerate(loader):
    batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
    is_last_micro = (step + 1) % accum_steps == 0
 
    if is_last_micro:
        # 最后一个 micro-step 正常反向;这一步才执行真正的梯度同步。
        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            loss = model(**batch).loss / accum_steps
        scaler.scale(loss).backward()
    else:
        # 中间 micro-step 只累计本地梯度,不触发 all-reduce。
        with model.no_sync():
            with torch.amp.autocast("cuda", dtype=torch.bfloat16):
                loss = model(**batch).loss / accum_steps
            scaler.scale(loss).backward()
 
    if is_last_micro:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

如果你的训练已经使用了高阶框架(例如 Accelerate、Lightning、DeepSpeed),需要先确认框架是否已经替你做了这件事。手工再包一层 no_sync(),容易把同步节奏搞乱。

torchrun 多机 rendezvous 参数

单机时 --standalone 足够;一旦进入多机,决定“能不能稳定拉起”的是 rendezvous 配置。关键参数包括:节点数、当前节点序号、主节点地址与端口,以及弹性重启相关设置。

Shell
1
2
3
4
5
6
7
8
torchrun \
  --nnodes=2 \
  --nproc_per_node=8 \
  --node_rank=0 \
  --rdzv_backend=c10d \
  --rdzv_endpoint=10.0.0.1:29500 \
  --max_restarts=0 \
  train_ddp.py --config config.yaml

排障时,先确认三件事:

  • 所有节点看到的 --rdzv_endpoint 完全一致。
  • --node_rank 从 \(0\) 开始连续编号,没有重复也没有跳号。
  • 防火墙、容器网络、作业调度器没有把 rendezvous 端口挡住。
脚本组织方式(训练/推理共用)

可维护性来自“把易变部分隔离出来”:模型定义、数据定义、运行时策略(AMP/compile/DDP)、以及 I/O(checkpoint/logging)。一个简单但可扩展的组织方式如下:

1
2
3
4
5
6
7
8
9
project/
  src/
    models.py        # nn.Module 定义与构建函数
    data.py          # Dataset/Tokenizer/Collate
    train_step.py    # 单步训练逻辑(支持 AMP/compile)
    ddp.py           # 分布式初始化与 rank 工具函数
    ckpt.py          # save/load_checkpoint(含 weights_only 策略)
  train.py           # 单机/单卡入口
  train_ddp.py       # torchrun 入口
Transformers 详解

Transformers 在工程上提供了一套可组合的入口:模型与 tokenizer/processor 的加载与保存( from_pretrained/ save_pretrained)、架构无关的 Auto* 工厂、训练循环( Trainer/ TrainingArguments)、以及推理生成( generate)与对话模板(Chat Template)。这一节只讲“如何编程接入与部署落地”,不展开算法原理。

安装与最小依赖
Shell
1
2
3
4
pip install -U transformers
 
# 需要 device_map="auto" / offload / 分布式等能力时通常还需要
pip install -U accelerate
from_pretrained / save_pretrained(装载与交付)

from_pretrained 负责把“模型仓库或本地目录”解析成 Python 对象;save_pretrained 把对象序列化回一个可复用的目录。工程上把这个目录当作“可交付模型包”(artifact),它应当可被推理服务直接加载。

从 Hub 或本地目录加载
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
 
# 同一入口既能接 Hub repo id,也能接本地导出的模型目录。
model_id_or_path = "Qwen/Qwen3-0.6B"   # 也可以是 ./models/prod 这类本地目录
 
# tokenizer 和模型都从同一目录加载,避免词表版本漂移。
tok = AutoTokenizer.from_pretrained(model_id_or_path, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
  model_id_or_path,
  torch_dtype="auto",
  # device_map="auto" 让 accelerate 自动做设备放置和必要的 offload。
  device_map="auto",          # 需要 accelerate
)
 
# 生成前先切到 eval,并关闭梯度,避免无意义的显存开销。
model.eval()
with torch.inference_mode():
  # tokenizer 的返回值本身就是模型 forward/generate 所需的张量字典。
  out = model.generate(**tok("Hello", return_tensors="pt").to(model.device), max_new_tokens=32)
  print(tok.decode(out[0], skip_special_tokens=True))
保存到本地目录(模型包)
Python
1
2
3
4
5
6
7
8
from pathlib import Path
out_dir = Path("models/registry/model_v0001")
out_dir.mkdir(parents=True, exist_ok=True)
 
# safe_serialization=True 会把权重写成 safetensors,适合作为部署产物默认格式。
model.save_pretrained(out_dir, safe_serialization=True)  # 推荐 safetensors
# tokenizer 也必须随模型一并导出;缺词表或 special tokens 会直接破坏推理语义。
tok.save_pretrained(out_dir)
本地权重目录结构(读写约定)

Transformers 的加载逻辑依赖“目录里有哪些标准文件”。同一个目录既可以来自 Hub 下载缓存,也可以来自 save_pretrained 导出。

1
2
3
4
5
6
7
8
9
10
11
model_dir/
  config.json
  generation_config.json                # 可选:生成参数默认值
  model.safetensors                     # 或 pytorch_model.bin
  model.safetensors.index.json          # 可选:分片索引(大模型常见)
  model-00001-of-00002.safetensors      # 可选:分片权重文件
  tokenizer.json                        # fast tokenizer 常见
  tokenizer_config.json
  special_tokens_map.json
  vocab.json / merges.txt               # BPE 类 tokenizer 常见
  spiece.model                          # SentencePiece tokenizer 常见
离线加载与缓存目录

离线/内网环境的最小做法是提前把模型仓库下载到本地目录,然后用本地路径调用 from_pretrained。需要让缓存落到指定盘符时,优先设置 Hugging Face Hub 的缓存环境变量(例如 HF_HUB_CACHE / HF_HOME)。

示例:改变缓存目录
Shell
1
2
export HF_HOME=/data/hf
export HF_HUB_CACHE=/data/hf/hub

只用本地文件(不触网)
Python
1
2
tok = AutoTokenizer.from_pretrained("./model_dir", local_files_only=True)
model = AutoModelForCausalLM.from_pretrained("./model_dir", local_files_only=True)
trust_remote_code、revision 与可审计加载

当模型仓库包含自定义 Python 实现,而该实现不在 Transformers 内建模型类集合里时,加载阶段往往需要显式打开 trust_remote_code=True。这核心是在允许仓库里的 Python 代码参与本地执行,因此应同时固定 revision 或 commit,避免同一个 repo name 在不同时间拉到不同行为的代码。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from transformers import AutoConfig, AutoModelForCausalLM
 
# 先单独拿 config,是为了把“远端代码 + revision”这类审计边界先固定下来。
cfg = AutoConfig.from_pretrained(
    "org/custom-model",
    # 固定到具体 revision,避免远端代码更新后加载行为漂移。
    revision="8d4c9d7",
    # 允许仓库里的自定义 Python 类参与实例化;
    # 只有源码已经审过、且来源可信时才打开。
    trust_remote_code=True,
)
 
model = AutoModelForCausalLM.from_pretrained(
    "org/custom-model",
    # 复用上面已经审过并固定 revision 的配置对象,
    # 避免“config 来自一个版本,权重来自另一个版本”。
    config=cfg,
    revision="8d4c9d7",      # 模型权重与配置都固定到同一提交
    trust_remote_code=True,  # 只在审计过源码的受控环境里启用
    torch_dtype="auto",      # 按仓库推荐 dtype 落地,减少手工指定精度带来的不兼容
    device_map="auto",       # 原型阶段先自动分配设备;正式部署再切到显式映射
)

线上环境更稳的做法通常是:先在受控机器上把模型拉到本地、审计并冻结目录,再由服务侧只加载本地目录而非直接联网拉取。

Auto* 家族(统一入口)

Auto* 是“按配置自动选择具体实现”的工厂。工程上把它当作跨架构的稳定入口:你不需要在代码里硬编码某个模型类名,尤其是在需要频繁替换基座模型时。

命令/API/函数
AutoConfig

说明
读取/改写模型配置(层数、rope、token id 等)

示例

Python
1
2
from transformers import AutoConfig
cfg = AutoConfig.from_pretrained("Qwen/Qwen3-0.6B")

命令/API/函数
AutoTokenizer

说明
加载 tokenizer(文本 → input_ids/attention_mask)

示例

Python
1
2
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", use_fast=True)

命令/API/函数
AutoProcessor

说明
多模态 processor(文本+图像/音频等统一预处理)

示例

Python
1
2
from transformers import AutoProcessor
proc = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

命令/API/函数
AutoModel

说明
只要 backbone 表示(不带任务头)

示例

Python
1
2
from transformers import AutoModel
m = AutoModel.from_pretrained("bert-base-uncased")

命令/API/函数
AutoModelForCausalLM

说明
Decoder-only 生成(LLM 推理/微调)

示例

Python
1
2
from transformers import AutoModelForCausalLM
m = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")

命令/API/函数
AutoModelForSeq2SeqLM

说明
Encoder-Decoder 生成(翻译、摘要等)

示例

Python
1
2
from transformers import AutoModelForSeq2SeqLM
m = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-small")

命令/API/函数
AutoModelForSequenceClassification

说明
文本分类

示例

Python
1
2
from transformers import AutoModelForSequenceClassification
m = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)

命令/API/函数
AutoModelForTokenClassification

说明
序列标注(NER/词性标注等)

示例

Python
1
2
from transformers import AutoModelForTokenClassification
m = AutoModelForTokenClassification.from_pretrained("bert-base-uncased", num_labels=9)
Tokenizer 与 Processor(输入标准化)

Tokenizer/Processor 把“原始输入”变成模型可消费的张量字典。文本模型通常用 tokenizer;视觉/语音/多模态模型往往用 processor,它内部可能组合 tokenizer + image/audio processor。

Tokenizer 的返回结构
Python
1
2
3
4
5
6
7
8
inputs = tok(
  ["a", "b"],
  padding=True,
  truncation=True,
  max_length=128,
  return_tensors="pt",
)
# inputs 通常包含:input_ids, attention_mask(以及 token_type_ids 等,视模型而定)
Processor 的典型用法(以 CLIP 为例)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from PIL import Image
import requests
from transformers import AutoProcessor, AutoModel
 
# processor 统一封装图像预处理和文本 tokenization,保证两路输入对齐
proc = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
# CLIP 的裸 AutoModel 输出图文表示,可继续做相似度计算
model = AutoModel.from_pretrained("openai/clip-vit-base-patch32")
 
img = Image.open(
  requests.get(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/cat.jpg",
    stream=True,
  ).raw
)  # 直接从远端流里读图,示例重点放在多模态接线而非本地文件处理
 
inputs = proc(
  text=["a photo of a cat"],  # CLIP 推理时文本提示和图像通常成对出现
  images=[img],
  return_tensors="pt",        # 返回 PyTorch tensor,才能直接喂给 model(**inputs)
  padding=True,               # 文本长度不一致时由 processor 统一补齐
)
out = model(**inputs)  # out 里包含图像与文本 embedding,可继续算相似度或做检索
Trainer / TrainingArguments(训练循环)

Trainer 把训练循环、评估、保存 checkpoint、日志与分布式协同做成统一入口。工程上最关键的是把 TrainingArguments 固化成可追溯的配置(写入 run_meta.json 或随 checkpoint 一起存档),并严格区分 “best checkpoint” 与 “last checkpoint”。

训练控制高频 API

命令/API/函数
EarlyStoppingCallback

说明
当评估指标持续不改善时提前结束训练。它依赖 eval_strategy、 metric_for_best_model 与 load_best_model_at_end 形成闭环。

示例

Python
1
2
3
4
5
6
7
8
9
10
11
12
from transformers import EarlyStoppingCallback
 
callbacks = [
    EarlyStoppingCallback(
        # 连续 3 次评估都没有实质改善才停;
        # 这个数字应该与 eval_steps / eval_strategy 一起理解。
        early_stopping_patience=3,
        # 0.0 表示“只要更好一点就算改善”;
        # 若指标噪声很大,可以抬高阈值避免把抖动误判成提升。
        early_stopping_threshold=0.0,
    )
]

命令/API/函数
trainer.train(resume_from_checkpoint=...)

说明
从最近一次或指定 checkpoint 恢复训练。恢复对象包括模型权重、optimizer、scheduler 与 trainer 状态。

示例

Python
1
trainer.train(resume_from_checkpoint="out_sst2/checkpoint-1200")

命令/API/函数
trainer.push_to_hub

说明
把训练好的模型包、tokenizer 与元数据直接推送到 Hub,适合把“训练完成 → 共享/部署制品”做成固定交付动作。

示例

Python
1
trainer.push_to_hub(commit_message="ship best checkpoint")
HfArgumentParser:把训练脚本做成 dataclass CLI

Transformers 官方 example 几乎都把训练脚本写成“若干 dataclass + HfArgumentParser”。这种写法的价值核心是把命令行、JSON 配置文件与 Python 对象统一到同一套字段定义上,便于实验复现、配置审阅与批量跑任务。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from dataclasses import dataclass, field
from transformers import HfArgumentParser, TrainingArguments
 
@dataclass
class ScriptArguments:
    # 脚本自定义参数单独放这里,和 HF 标准训练参数解耦
    model_name_or_path: str = field(default="distilbert-base-uncased")
    # 让“训练入口”直接知道该加载哪份数据
    dataset_name: str = field(default="glue")
    # 同一数据集常有多个子配置,显式列出来更可复现
    dataset_config: str = field(default="sst2")
 
parser = HfArgumentParser((ScriptArguments, TrainingArguments))
# CLI 会自动映射到 dataclass 字段,类型转换也由 parser 处理
script_args, training_args = parser.parse_args_into_dataclasses()
最小训练骨架(分类任务)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import numpy as np
from datasets import load_dataset
from transformers import (
  AutoTokenizer,
  AutoModelForSequenceClassification,
  DataCollatorWithPadding,
  Trainer,
  TrainingArguments,
)
 
# 先取一个标准数据集,把重点放在 Trainer 的工程接入方式上。
ds = load_dataset("glue", "sst2")
tok = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
 
def tokenize(batch):
  # tokenize 只做输入标准化,不在这里引入 label 或 padding 逻辑。
  return tok(batch["sentence"], truncation=True)
 
# batched=True 能让 tokenizer 一次处理一批样本,减少 Python 开销。
ds = ds.map(tokenize, batched=True)
# 动态 padding 放到 collator 层,避免预处理阶段把所有样本 pad 到同一长度。
collator = DataCollatorWithPadding(tokenizer=tok)
 
# TrainingArguments 把评估、保存、日志和 best checkpoint 选择集中到一个配置对象里。
args = TrainingArguments(
  output_dir="out_sst2",
  per_device_train_batch_size=32,
  per_device_eval_batch_size=64,
  num_train_epochs=1,
  evaluation_strategy="steps",
  eval_steps=200,
  save_strategy="steps",
  save_steps=200,
  save_total_limit=3,
  load_best_model_at_end=True,
  metric_for_best_model="eval_loss",
  greater_is_better=False,
  report_to="none",
)
 
def compute_metrics(eval_pred):
  # Trainer 会把 logits 和 labels 打包给 compute_metrics,这里只保留最小 acc 示例。
  logits, labels = eval_pred
  preds = np.argmax(logits, axis=-1)
  return {"acc": (preds == labels).mean().item()}
 
# Trainer 把模型、数据、padding 规则和 metric 计算统一到同一个训练入口。
trainer = Trainer(
  model=model,
  args=args,
  train_dataset=ds["train"],
  eval_dataset=ds["validation"],
  tokenizer=tok,
  data_collator=collator,
  compute_metrics=compute_metrics,
)
# 调用 train() 后,Trainer 会自动接管训练循环、评估、保存和日志。
trainer.train()
官方 example 的恢复套路:get_last_checkpoint + save_metrics

Transformers 官方 example 脚本的共同特点是:它们不仅调用一次 trainer.train() 就结束,还把“发现已有 checkpoint”“决定从哪里恢复”“把 metrics 写进磁盘”“把 trainer 状态单独持久化”做成固定套路。这样训练目录才既能续跑,又能给后续回归分析留下证据。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import os
 
from transformers.trainer_utils import get_last_checkpoint
 
# 先看 output_dir 里是否已经存在未完成训练留下的 checkpoint。
last_checkpoint = None
if os.path.isdir(args.output_dir):
    last_checkpoint = get_last_checkpoint(args.output_dir)
 
# 用户显式指定的恢复点优先级最高;
# 否则才回退到“自动发现的最后一个 checkpoint”。
resume_ckpt = training_args.resume_from_checkpoint or last_checkpoint
 
# train_result.metrics 是训练阶段的聚合指标,包含 loss 以及 Trainer 汇总出的其它统计量。
train_result = trainer.train(resume_from_checkpoint=resume_ckpt)
metrics = train_result.metrics
# 把样本数也写进去,后续比对不同运行时才知道这些指标基于多大数据规模。
metrics["train_samples"] = len(train_dataset)
 
# log_metrics 负责打印/上报;
# save_metrics 负责把指标固化到 output_dir 下的 JSON 文件。
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
# save_state 会保存 trainer_state.json 等运行元数据,
# 其中包括 global_step、最佳 checkpoint 路径和随机状态摘要。
trainer.save_state()
# save_model 则负责导出当前模型包;它和 save_state 并非一回事。
trainer.save_model(training_args.output_dir)

save_metrics、 save_state 和 save_model 对应三种不同产物:指标、训练状态、模型制品。把它们混成“反正都保存一下”会让训练目录变得难以审计。

大评估集的显存治理:eval_on_start、eval_accumulation_steps、preprocess_logits_for_metrics

Trainer 做评估时,真正容易炸显存的环节常常是“把所有 logits 攒起来再交给 compute_metrics”。官方脚本里更稳的做法是把这三件事配合起来:训练前先做一次 sanity eval,评估阶段分批把张量搬回 CPU,并在进入 metric 计算前先把巨大的 logits 压缩成更小的统计表示。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import evaluate
import torch.nn.functional as F
from transformers import Trainer, TrainingArguments
 
metric = evaluate.load("accuracy")
 
def my_loss_fn(outputs, labels, num_items_in_batch):
    logits = outputs["logits"]
    # reduction='sum' 后再除以真实样本数,避免最后一个小 batch 改变 loss 标尺。
    loss = F.cross_entropy(logits, labels, reduction="sum")
    return loss / num_items_in_batch
 
def preprocess_logits_for_metrics(logits, labels):
    # 评估阶段不必把整块 logits 都搬到 CPU;
    # 如果指标只看 argmax,这里先压成类别 id,可显著降低内存与通信量。
    if isinstance(logits, tuple):
        logits = logits[0]
    return logits.argmax(dim=-1)
 
def compute_metrics(eval_preds):
    preds, labels = eval_preds
    return metric.compute(predictions=preds.reshape(-1), references=labels.reshape(-1))
 
args = TrainingArguments(
    output_dir="out_eval_safe",
    eval_strategy="epoch",
    # 开训前先跑一遍评估,尽早发现标签对齐、metric 键名或数据切分问题。
    eval_on_start=True,
    # 分批把评估结果从 GPU 挪回 CPU,避免整轮评估的中间张量长期堆在显存里。
    eval_accumulation_steps=16,
)
 
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    compute_loss_func=my_loss_fn,                   # 把损失定义显式化,适合加权/自定义任务头
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    compute_metrics=compute_metrics,
)

这组配置的工程价值很高。 eval_on_start=True 解决的是“脚本刚启动就知道评估链路通不通”; eval_accumulation_steps 解决的是“评估张量怎么分批落回 CPU”; preprocess_logits_for_metrics 解决的是“没必要把整块 logits 都存下来”。三者配合后,大评估集上的 Trainer 稳定性会明显好很多。

collator 的分工:padding、MLM、Seq2Seq 标签处理

数据集预处理阶段负责“样本级转换”,collator 负责“把一批样本组装成可喂给模型的 batch”。这层分工非常关键:padding、mask 构造、label pad id 处理都应该放在 collator,而非硬塞进 Dataset.map。

命令/API/函数
default_data_collator

说明
几乎不做智能处理,只把同名字段堆起来。适合样本长度已经统一、或数据本身就是张量的任务。

示例

Python
1
2
3
4
5
6
7
8
9
from transformers import default_data_collator
 
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_ds,
    # 数据已经预先 pad 好时,用最朴素的拼 batch 方式即可
    data_collator=default_data_collator,
)

命令/API/函数
DataCollatorForLanguageModeling

说明
语言模型任务的专用 collator。对 MLM 会随机打 mask;对 Causal LM 则常用于统一 padding 与 label 对齐。

示例

Python
1
2
3
4
5
6
from transformers import DataCollatorForLanguageModeling
 
collator = DataCollatorForLanguageModeling(
    tokenizer=tok,
    mlm=False,   # decoder-only Causal LM 训练时不做 masked LM,改为直接预测下一个 token
)

命令/API/函数
DataCollatorForSeq2Seq

说明
Encoder-Decoder 任务的专用 collator,会同时处理 encoder 输入与 decoder labels,并把 label padding 位置改成 -100 以避开 loss。

示例

Python
1
2
3
4
5
6
7
from transformers import DataCollatorForSeq2Seq
 
collator = DataCollatorForSeq2Seq(
    tokenizer=tok,
    model=model,            # 传入模型后,collator 能结合模型配置处理 decoder 侧细节
    label_pad_token_id=-100 # 交叉熵会忽略 -100,对变长目标序列尤为关键
)
Seq2SeqTrainer:摘要、翻译、问答式生成不要硬套普通 Trainer

摘要、翻译这类 Encoder-Decoder 任务,评估时往往需要真的跑一次 generate() 再计算 ROUGE/BLEU。此时更合适的入口是 Seq2SeqTrainer 与 Seq2SeqTrainingArguments,因为它们把“验证阶段是否调用生成”做成了显式配置。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from transformers import DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
 
args = Seq2SeqTrainingArguments(
    output_dir="out_t5_sum",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # 验证阶段直接调用 model.generate(),而非只看 teacher forcing loss
    predict_with_generate=True,
    generation_max_length=128,  # 限制验证生成长度,防止评估阶段拖垮吞吐
    eval_strategy="epoch",
    save_strategy="epoch",
)
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    # 新版本文档逐步把 tokenizer/processor 收敛到 processing_class 语义
    processing_class=tok,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tok, model=model),
)
CLM 预处理:group_texts 与 block_size

做 Causal LM 预训练或继续预训练时,数据并不总是一条样本对应一条训练序列。官方 example 更常见的做法是先 tokenize,再把多个短文本拼成长 token 流,按固定 block_size 切块。这一步决定了上下文利用率,也决定了 label 的构造方式。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
def group_texts(examples, block_size=1024):
    # 先把一批 token 列表拼成连续 token 流,减少短样本浪费
    concatenated = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated["input_ids"])
    # 只保留能整除 block_size 的部分,避免尾部残块长度不齐
    total_length = (total_length // block_size) * block_size
 
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }
    # Causal LM 最常见的数据约定是 labels 和 input_ids 同形,由模型内部完成右移
    result["labels"] = result["input_ids"].copy()
    return result

这类训练里真正的“右移一位”通常由模型内部 loss 逻辑完成,因此数据侧只需要准备与 input_ids 形状一致的 labels。不要在数据预处理阶段再手工 shift 一次,否则会错位两次。

断点续训与导出
Python
1
2
3
4
5
6
7
8
9
# 断点续训:resume_from_checkpoint 可以传具体 checkpoint 路径
# 这里进入真正的续训流程;Trainer 会同时恢复 optimizer/scheduler/trainer_state。
trainer.train(resume_from_checkpoint=True)
 
# 导出最终模型包(建议使用 best checkpoint 对应的权重)
# save_model 会把当前模型权重写成标准 Transformers 模型包结构。
trainer.save_model("models/registry/model_v0001")
# tokenizer 需要和模型目录保持同一路径,部署侧才能直接 from_pretrained。
tok.save_pretrained("models/registry/model_v0001")
best checkpoint、callback 与 early stopping

Trainer 的保存逻辑至少有三种语义:最近一次保存的 last checkpoint,用于恢复训练;指标最优的 best checkpoint,用于上线或离线评测;以及人为指定导出的最终目录。不要把它们混成一个概念,否则“能续训”与“该上线谁”会互相污染。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
 
args = TrainingArguments(
    output_dir="out_cls",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    num_train_epochs=10,              # 上限给足,真正何时停交给 metric + callback 判定
    eval_strategy="epoch",            # 每轮评估一次,适合样本量中等的分类微调
    save_strategy="epoch",            # 保存节奏和评估节奏对齐,best model 才有明确参照点
    load_best_model_at_end=True,      # 训练结束后自动把内存中的权重切回 best checkpoint
    metric_for_best_model="f1",       # 明确“谁定义最好”;不要默认把 last 当 best
    greater_is_better=True,           # F1 越大越好;loss/perplexity 这类则应设为 False
    save_total_limit=2,               # 保留少量近期 checkpoint,避免长训练把磁盘写爆
)
 
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    tokenizer=tok,
    data_collator=collator,
    compute_metrics=compute_metrics,  # compute_metrics 必须返回包含 "f1" 的字典
    callbacks=[
        EarlyStoppingCallback(
            early_stopping_patience=3,   # 连续 3 次评估不改善再停,避免因一次抖动误停
            early_stopping_threshold=0.0 # 只有真正提升才算改善
        )
    ],
)

如果任务真正关心的是生成质量而非 teacher forcing loss, metric_for_best_model 应切到 ROUGE、BLEU、F1、EM 之类更贴近业务目标的指标,而非机械盯住 eval_loss。

generate 与 GenerationConfig(推理生成)

generate 把“下一 token 分布 → 序列”的解码策略(贪心、beam、采样等)封装成统一入口。工程上建议把生成策略固化为 GenerationConfig(或写入服务端配置),避免在业务代码里散落大量参数。

最小生成示例
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
 
# 分词器必须和生成模型共享同一套 special tokens 与词表
tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
model = AutoModelForCausalLM.from_pretrained(
  "Qwen/Qwen3-0.6B",
  torch_dtype="auto",  # 让库优先采用权重推荐的加载精度,减少手工选 dtype 带来的不确定性
  device_map="auto",   # 对小模型和原型验证来说最省事;更复杂部署再改显式映射
)
 
gen_cfg = GenerationConfig(
  max_new_tokens=128,  # 限制新生成 token 数,防止测试脚本因为意外长输出拖慢或打爆上下文
  # 打开采样后,temperature/top_p 这类随机性参数才会真正参与解码。
  do_sample=True,
  temperature=0.7,     # 略低于 1 的温度更稳,适合作为“有变化但不过度发散”的默认值
  top_p=0.9,           # nucleus sampling 砍掉长尾 token,常与 temperature 联合使用
)
 
# 整个 batch 搬到模型所在设备,避免 device mismatch
inputs = tok("Explain KV cache in one paragraph.", return_tensors="pt").to(model.device)
with torch.inference_mode():
  # 把解码策略显式放进 generation_config,便于复用与上线固化
  out = model.generate(**inputs, generation_config=gen_cfg)
 
print(tok.decode(out[0], skip_special_tokens=True))
常见参数与含义(部署侧最常用)
参数 作用 工程建议
max_new_tokens 限制生成 token 数 优先用它而非 max_length(后者包含 prompt token)。
do_sample 采样开关 需要稳定输出时关闭采样,并把 temperature=0 或直接不用 temperature。
temperature / top_p 采样随机性与截断 线上服务通常把它们做成可配置策略,按业务风险控制随机性。
eos_token_id / pad_token_id 结束与 padding 的 token id Decoder-only 模型常需要显式设置 pad_token(一般等于 eos_token)。
GenerationConfig 作为可交付配置

GenerationConfig 可以作为独立配置保存和加载。这样“模型权重”与“默认生成策略”就能一起版本化,服务端也能明确区分“模型默认值”和“请求级覆盖值”。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from transformers import GenerationConfig
 
gen_cfg = GenerationConfig(
    max_new_tokens=256,   # 把线上默认回复长度写进配置,而非散落在业务代码
    # 这三项一起定义“服务默认输出风格”。
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
)
 
# 单独把生成策略保存到模型目录;后续 from_pretrained 会自动读取它。
gen_cfg.save_pretrained("models/registry/model_v0001")
 
# 服务或离线脚本再次加载时,先恢复团队约定的默认策略。
gen_cfg = GenerationConfig.from_pretrained("models/registry/model_v0001")
# 然后再做请求级覆盖;
# 这里改成 128 的含义是“本次调用比服务默认值更保守”,而非修改模型默认配置文件。
gen_cfg.max_new_tokens = 128

推理栈里常见的优先级顺序是:请求级参数覆盖 GenerationConfig,而 GenerationConfig 再覆盖模型内建默认值。把这层关系说清楚,线上回归时才知道到底是谁改了输出风格。

streamer:把 generate 接到 CLI、WebSocket 或 SSE

generate() 默认等整段输出完成后才返回。做交互式 CLI、Web UI 或流式 API 时,更常见的写法是把 token 增量交给 streamer 对象,再由外层线程或事件循环持续消费。

命令/API/函数
TextStreamer

说明
最简单的 stdout streamer,适合命令行 demo 或快速验证 chat template 与生成配置是否正常。

示例

Python
1
2
3
4
5
6
7
8
from transformers import TextStreamer
 
streamer = TextStreamer(
    tok,
    skip_prompt=True,         # 交互场景通常不希望把原 prompt 再打印一遍
    skip_special_tokens=True, # 避免把 eos、role token 直接暴露到终端输出
)
_ = model.generate(**inputs, max_new_tokens=128, streamer=streamer)

命令/API/函数
TextIteratorStreamer

说明
把增量文本暴露成可迭代对象,适合接到 WebSocket、SSE 或自定义前端事件流。

示例

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from threading import Thread
from transformers import TextIteratorStreamer
 
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
thread = Thread(
    target=model.generate,
    # generate 放后台线程跑,前台持续取流
    kwargs={**inputs, "max_new_tokens": 128, "streamer": streamer},
)
thread.start()
 
for chunk in streamer:
    # Web 服务里这里通常会改成 yield SSE / WebSocket send
    print(chunk, end="", flush=True)
 
thread.join()

命令/API/函数
AsyncTextIteratorStreamer

说明
面向 async 应用,把流式输出对接到异步事件循环。适合 FastAPI、Starlette 这类 async 服务框架。

示例

Python
1
2
3
4
from transformers import AsyncTextIteratorStreamer
 
# async 服务用它更容易接 Response streaming
streamer = AsyncTextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

streamer 解决的是“输出如何增量交付”,不改变底层解码算法。真正接服务时,还需要额外处理取消请求、超时、中断清理和背压。

Chat Template(对话模板接入)

Chat Template 把“messages 列表”转换成模型需要的 prompt 格式,并确保 role token、分隔符与结束符一致。工程上建议统一通过 apply_chat_template 生成输入,避免手工拼 prompt 导致格式漂移。

chat template 的存储位置与模板本质

在 Hugging Face 生态里,chat template 往往直接存放在 tokenizer 配置中,本质上是一段模板字符串,很多模型实际使用的是 Jinja2 风格模板。它定义 role 顺序、system/user/assistant 分隔符、工具调用片段、结束符以及是否在末尾补 assistant 起始标记。

Python
1
2
3
4
# 模板通常直接挂在 tokenizer 上;同一模型不同 tokenizer 版本可能对应不同模板
template = tok.chat_template
# 排查输出格式漂移时,第一步常常就是确认线上模板和训练时是否同一份
print(template[:300])
apply_chat_template + generate(最小骨架)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
 
# chat template、special tokens 和词表都由 tokenizer 定义
tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
model = AutoModelForCausalLM.from_pretrained(
  "Qwen/Qwen3-0.6B",
  torch_dtype="auto",  # 优先沿用模型推荐 dtype,把示例重心放在模板接入而非精度兼容
  device_map="auto",   # 小模型/原型验证可直接自动分配设备
)
 
messages = [
  {"role": "system", "content": "You are a precise assistant."},
  {"role": "user", "content": "Summarize what gradient accumulation is."},
]
 
input_ids = tok.apply_chat_template(
  messages,
  # 在末尾补上 assistant 起始标记,让 generate 从正确的角色位置继续
  add_generation_prompt=True,
  return_tensors="pt",         # 直接返回 tensor,省去再手工编码一次
).to(model.device)
 
with torch.inference_mode():
  # 模板负责 prompt 结构,generate 只负责补全 assistant 回复
  out = model.generate(input_ids, max_new_tokens=128)
 
print(tok.decode(out[0], skip_special_tokens=True))
apply_chat_template 的返回值与工具变量

新版本 Transformers 中, apply_chat_template 不再只是返回 input_ids,而更倾向返回一个完整的 batch 结构。这样可以把 attention_mask、多模态输入和模板相关字段一起传给 generate()。如果模型模板支持工具调用,模板上下文里还可能读取 tools 之类的变量。

Python
1
2
3
4
5
6
7
8
9
formatted = tok.apply_chat_template(
    messages,
    # 让模板渲染和 tokenization 一次完成,避免手工 split/join 再编码
    tokenize=True,
    add_generation_prompt=True,  # 让输出停在 assistant 起始位置,方便 generate 直接续写
    return_tensors="pt",
)
# 新版本更适合把它当作完整 BatchEncoding 处理,而非只拿 input_ids
formatted = formatted.to(model.device)
continue_final_message:让模型续写最后一条消息

默认的 add_generation_prompt=True 语义是“在模板末尾再补一个 assistant 起始标记,然后让模型开始回答”。有些任务希望模型直接续写最后一条未完成消息,并不需要新开一条 assistant 消息,例如 JSON 片段补全、代码补全或工具参数半成品续写。此时更合适的入口是 continue_final_message=True。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
messages = [
    {"role": "user", "content": "Return a JSON object with city and weather."},
    # 最后一条 assistant 消息是半成品;希望模型直接从这里往后补,而非重新起一条 assistant。
    {"role": "assistant", "content": '{"city": "Paris", "weather": "'},
]
 
batch = tok.apply_chat_template(
    messages,
    tokenize=True,
    continue_final_message=True,  # 表示“续写最后一条消息”
    return_dict=True,
    return_tensors="pt",
).to(model.device)

这和 add_generation_prompt=True 是两套不同语义,通常不应同时使用。前者是在模板末尾新开一个 assistant 轮次,后者是在已有消息内部继续补全。

工具调用闭环:tools + parse_response

工具调用模型的真正难点不在于“把工具 schema 塞进 prompt”,而在于把一轮工具调用走完整:模板注入工具定义,模型返回结构化调用意图,应用侧执行真实工具,再把 tool 结果追加回消息历史,最后继续生成。只展示一次性 generate() 往往会漏掉最关键的执行语义。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.utils import get_json_schema
 
def get_current_temperature(city: str) -> str:
    """Get current temperature for a city in Celsius."""
    return f"18C in {city}"
 
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
 
messages = [{"role": "user", "content": "What is the weather in Paris?"}]
# 由函数签名和 docstring 生成工具 schema,避免手写另一份 JSON Schema。
tools = [get_json_schema(get_current_temperature)]
 
for _ in range(3):
    batch = tok.apply_chat_template(
        messages,
        tools=tools,                 # 模板在渲染时把工具定义一并注入上下文
        add_generation_prompt=True,
        return_dict=True,            # 返回完整 batch,便于直接喂给 generate
        return_tensors="pt",
    ).to(model.device)
 
    out = model.generate(**batch, max_new_tokens=256)
    # 只解码新增部分,避免把整个历史消息都再解一遍。
    delta = out[0][batch["input_ids"].shape[1]:]
    text = tok.decode(delta, skip_special_tokens=False)
    # parse_response 会把模型输出解析成 message dict,可能包含 tool_calls。
    msg = tok.parse_response(text)
    messages.append(msg)
 
    if "tool_calls" not in msg:
        break
 
    tool_call = msg["tool_calls"][0]
    result = get_current_temperature(**tool_call["arguments"])
    # 工具执行结果必须以 tool message 形式写回历史,模型下一轮才能消费它。
    messages.append({"role": "tool", "content": str(result)})

这条链路一旦确定,训练数据、离线评测和线上服务最好都沿用同一套消息格式。否则就会出现线上是工具调用模板,训练集却只是普通对话模板的格式漂移问题。

训练数据的 chat template 对齐

如果模型的 tokenizer 自带 chat template,SFT 数据建议按同一模板构造训练样本;否则推理时的对话格式会与训练时不一致,表现为“角色混淆”“结束符异常”“输出风格漂移”。

device_map / torch_dtype(加载策略与显存治理)

大模型加载的关键旋钮是:把权重放在哪(GPU/CPU/磁盘)与用什么 dtype(fp32/fp16/bf16)。 device_map="auto" 会尝试把层自动分配到设备上,通常需要安装 accelerate; torch_dtype="auto" 会按权重与硬件能力选择合适 dtype。

Python
1
2
3
4
5
model = AutoModelForCausalLM.from_pretrained(
  "./model_dir",       # 本地目录需要是 save_pretrained 导出的标准模型包结构
  torch_dtype="auto",  # 优先按权重声明与硬件能力选择精度,通常是最稳的起点
  device_map="auto",   # 自动把权重切到 GPU/CPU;原型验证方便,精细部署再手工接管
)
常见坑(高频报错与修复动作)
现象 根因 修复动作
ImportError: Using device_map requires Accelerate 启用了 device_map,但环境缺少 accelerate 安装 pip install -U accelerate,或移除 device_map 并手动 model.to(device)。
Decoder-only 推理报 padding 相关错误 tokenizer 没有 pad_token tok.pad_token = tok.eos_token,并设置 pad_token_id。
推理输出乱码或 EOS 提前结束 tokenizer 与模型不匹配,或 chat template 不一致 确保 tokenizer 与模型来自同一目录;推理统一用 apply_chat_template。
本地目录加载失败(找不到 config/tokenizer) 目录并非标准模型包结构 用 model.save_pretrained 与 tok.save_pretrained 导出;检查是否存在 config.json。
OOM 或极慢 dtype/设备放置策略不合理 优先 torch_dtype="auto" + device_map="auto";必要时启用更强的推理引擎(vLLM/TensorRT-LLM)。
加载第三方模型需要 trust_remote_code=True 模型仓库包含自定义 Python 代码 在受控环境中审计代码后再开启;离线导出时固定 commit hash,避免代码漂移。
PEFT 与微调技术详解

参数高效微调(Parameter-Efficient Fine-Tuning, PEFT)的工程核心是把“可训练参数”从完整模型权重中剥离出来:训练与分发只关心适配器(adapter)的小文件,线上推理再把适配器挂到同一个 base checkpoint 上复用。这样既减少训练时的显存与优化器状态开销,也把多任务/多域适配的存储成本压到可控范围。

安装与版本对齐

PEFT、Transformers、TRL 与 bitsandbytes 在接口上是强耦合组合。工程上以“同一套 requirements 锁定版本”作为默认策略,避免出现 PEFT 与 Transformers 的适配器注入逻辑不一致、或 TRL 的 Trainer 参数签名变化导致脚本失效。

Shell
1
2
3
4
5
# 训练/微调常见最小集合
pip install -U transformers accelerate datasets peft trl safetensors
 
# QLoRA / 4bit 量化微调需要
pip install -U bitsandbytes
PEFT 的对象模型:base 与 adapter

PEFT 的存储与加载分两层:

  • base:Transformers 模型原始 checkpoint(通常很大、可复用、版本需固定)。
  • adapter:PEFT 生成的小文件(含 adapter_config 与 adapter weights),可多份并存,用于不同任务/域。

标准做法是:训练输出目录只保存 adapter;上线推理时先加载 base,再加载 adapter。这样 adapter 目录可被当作“制品”(artifact)管理,支持灰度、回滚与多 adapter 切换。

常用API

命令/API/函数
LoraConfig

说明
LoRA/QLoRA 的配置对象

示例

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from peft import LoraConfig, TaskType
 
cfg = LoraConfig(
    # 指明这是 decoder-only 语言模型,PEFT 会按自回归任务布置 adapter
    task_type=TaskType.CAUSAL_LM,
    # rank 决定低秩分支容量;r 越大,参数量和适配能力越强
    r=16,
    # alpha 控制低秩更新的缩放强度,避免 adapter 更新过弱
    lora_alpha=8,
    # 对 LoRA 分支做轻度 dropout,用来缓和微调过拟合
    lora_dropout=0.05,
    # 优先覆盖注意力投影层;模块名必须和模型源码一致
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)

命令/API/函数
get_peft_model

说明
把 base model 包装成可训练的 PeftModel

示例

Python
1
2
from peft import get_peft_model
peft_model = get_peft_model(base_model, cfg)

命令/API/函数
PeftModel.from_pretrained

说明
给已加载的 base model 挂载某个 adapter

示例

Python
1
2
from peft import PeftModel
model = PeftModel.from_pretrained(base_model, "adapter_dir")

命令/API/函数
model.load_adapter / model.set_adapter

说明
在同一 base 上继续挂载其它 adapter,并显式切换当前激活的 adapter。多域推理与灰度切换时非常高频。

示例

Python
1
2
model.load_adapter("adapter_b", adapter_name="b")
model.set_adapter("b")

命令/API/函数
model.add_adapter

说明
在当前 base 上新增一份全新的 adapter 配置,常用于“一个底座上继续开第二条训练线”。

示例

Python
1
2
3
4
5
6
7
8
cfg_b = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
)
model.add_adapter("domain_b", cfg_b)

命令/API/函数
model.save_pretrained

说明
保存 adapter(不覆盖 base)

示例

Python
1
model.save_pretrained("adapter_out")

命令/API/函数
model.merge_and_unload

说明
把 adapter 合并进 base 权重并卸载 adapter(用于导出单体权重)

示例

Python
1
merged = model.merge_and_unload()

命令/API/函数
model.merge_adapter / model.unmerge_adapter

说明
临时把当前 adapter 合并进 base,再按需撤销。适合做“先合并测一下吞吐/精度,再回到可切换 adapter 结构”的实验。

示例

Python
1
2
3
model.merge_adapter()
# ... 在当前进程里做一轮延迟/显存/质量验证 ...
model.unmerge_adapter()

命令/API/函数
prepare_model_for_kbit_training

说明
把量化后的 base 调整到可训练状态,通常在 QLoRA 里和 BitsAndBytesConfig 一起出现。

示例

Python
1
2
3
from peft import prepare_model_for_kbit_training
 
base = prepare_model_for_kbit_training(base)

命令/API/函数
AutoPeftModelForCausalLM

说明
把“adapter 目录 + 记录在配置里的 base 身份”直接还原成完整可推理对象,适合把 adapter 目录当成制品交付。

示例

Python
1
2
3
4
5
6
7
from peft import AutoPeftModelForCausalLM
 
model = AutoPeftModelForCausalLM.from_pretrained(
    "adapter_out",
    torch_dtype="auto",
    device_map="auto",
)

命令/API/函数
get_peft_model_state_dict / set_peft_model_state_dict

说明
只提取或恢复 adapter 相关参数,便于接自定义 checkpoint、FSDP 或分片保存系统。

示例

Python
1
2
3
4
from peft import get_peft_model_state_dict, set_peft_model_state_dict
 
adapter_state = get_peft_model_state_dict(model)
set_peft_model_state_dict(model, adapter_state)

命令/API/函数
model.print_trainable_parameters

说明
自检:确认“只训练 adapter”而非误训全参

示例

Python
1
model.print_trainable_parameters()
save_pretrained 的高级选项:save_embedding_layers

当 LoRA 训练同时动到了 embedding 层,或者训练过程中做过 resize_token_embeddings,单纯保存 adapter 增量并不总是足够。PEFT 的 save_pretrained(..., save_embedding_layers=...) 用来显式控制“是否把 embedding 层也随 adapter 一起保存”。这在加新 token、改词表或把 embedding 本身纳入 target_modules 时尤其关键。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# auto 会根据当前 adapter 是否覆盖 embedding、以及 embedding 是否在微调中被调整过来判断。
model.save_pretrained(
    "adapter_out",
    save_embedding_layers="auto",
)
 
# 当你明确知道 embedding 层已经被修改过时,可以直接强制保存。
model.save_pretrained(
    "adapter_out_with_embed",
    save_embedding_layers=True,
)
 
# 如果确认 embedding 没被动过,只想让 adapter 目录尽可能轻量,可以显式关闭。
model.save_pretrained(
    "adapter_out_small",
    save_embedding_layers=False,
)

这并非“目录大小优化”这么简单。若训练时扩过词表,但导出时没把相关 embedding 状态带走,推理侧最常见的后果就是:tokenizer 已经认识新 token,模型权重却没有与之对应的 embedding 或输出头权重。

ephemeral_gpu_offload:让 adapter 装配借道 GPU

有些 adapter 路线在加载、合并或切换阶段会出现“CPU 太慢,但常驻 GPU 又太贵”的矛盾。PEFT 提供的 ephemeral_gpu_offload 属于折中方案:平时仍把主要状态放在 CPU/低成本位置,但在关键装配步骤临时借用 GPU,加速权重处理过程。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from transformers import AutoModelForCausalLM
from peft import PeftModel
 
base = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map={"": "cpu"},      # 先把底座稳定落在 CPU,避免一开始就抢占线上 GPU
)
 
model = PeftModel.from_pretrained(
    base,
    "adapter_dir",
    # 打开后,PEFT 会在需要时临时把相关计算借道到 GPU,而非全程只走 CPU。
    ephemeral_gpu_offload=True,
    # 给 CPU 明确内存预算,防止大模型加载时被系统 OOM killer 直接杀掉。
    max_memory={"cpu": "256GiB"},
    device_map={"": "cpu"},
)

这类参数优化的是“加载/装配路径”,并非训练吞吐主路径。只有当你真的遇到 adapter 加载或 merge 很慢、而机器又有可借用 GPU 时,它才值得作为工程旋钮引入。

LoRA:PEFT 的主力路径

LoRA(Low-Rank Adaptation)通过对线性层权重施加低秩增量,让可训练参数规模与显存开销显著下降。实际工程难点集中在两处:target_modules 怎么选,以及如何保存/加载/合并。

target_modules:如何定位注入点

target_modules 是 LoRA 注入的“模块名匹配规则”,用于指定 base 模型里哪些子模块会被替换/包裹。这一选择与“按层类型挑选线性层”不同:它依赖模型内部的命名约定。不同架构的模块命名差异很大(Llama 系列常见 q_proj/k_proj/v_proj/o_proj,也有模型用 Wq/Wk/Wv/Wo 或把投影层藏在自定义块里)。稳定做法是先枚举可疑线性层,再按名字筛选。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
import torch
 
def list_linear_module_names(model: torch.nn.Module):
    names = []
    for name, m in model.named_modules():
        if isinstance(m, torch.nn.Linear):
            # LoRA 最常挂在线性层上,先把所有候选层名字列出来再筛 target_modules
            names.append(name)
    return names
 
# 先看前几十个,判断该模型到底叫 q_proj 还是别的名字
for n in list_linear_module_names(model)[:50]:
    print(n)

LoRA 常见的注入策略是“注意力投影层优先”:先只覆盖 attention 的 Q/K/V/O,再根据效果与显存预算扩展到 FFN 的投影层(例如 gate/up/down)。

modules_to_save:适配器之外还要训练哪些层

adapter 之外偶尔需要训练额外模块(例如分类头、语言模型的 lm_head、或新加 token 的 embedding)。这类模块可以通过 modules_to_save 声明为可训练并随 adapter 一起保存,避免“训练时更新了头部,但保存的 adapter 不包含它”的上线故障。

LoRA 最小训练骨架(Transformers Trainer)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from peft import LoraConfig, TaskType, get_peft_model
 
# 选 1B 级模型做骨架示例,普通单卡更容易跑通完整流程
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# fast tokenizer 预处理吞吐更高,适合训练前批量 map
tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
base = AutoModelForCausalLM.from_pretrained(
    model_id,
    # 先沿用模型推荐 dtype,把精力放在 LoRA 训练链路而非精度兼容问题上
    torch_dtype="auto",
    device_map="auto",   # 原型阶段直接自动分配设备;正式训练再切换到显式分布式配置
)
 
cfg = LoraConfig(
    task_type=TaskType.CAUSAL_LM,                         # 按自回归语言模型注入 adapter
    # 给 1B 级模型一个中等容量的 LoRA 分支
    r=16,
    # alpha 控制 adapter 更新幅度;这里偏保守,先求训练稳定
    lora_alpha=8,
    # 小比例 dropout 用来缓和 SFT 对训练集的死记硬背。
    lora_dropout=0.05,
    # 覆盖注意力主干投影层,是最常见的 LoRA 起点。
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
# 把 LoRA 配置真正注入到 base 模型里;
# 之后训练入口看到的是 PeftModel,优化器也应只更新 adapter 参数。
model = get_peft_model(base, cfg)
# 训练前先核对参数占比,防止 target_modules 写错导致零参或全参训练
model.print_trainable_parameters()
 
# 用公开 SFT 数据验证数据流;真实项目通常还要先做模板标准化
ds = load_dataset("trl-lib/Capybara", split="train")
 
def tokenize(example):
    # 多轮 messages 在真实项目里应先渲染成统一训练文本
    text = example["text"] if "text" in example else str(example)
    # 先裁到 1024,避免示例因为极长样本直接 OOM
    return tok(text, truncation=True, max_length=1024)
 
# 删除原始列,只保留模型 forward 真正需要的输入字段
ds = ds.map(tokenize, remove_columns=ds.column_names)
 
args = TrainingArguments(
    # adapter checkpoint 和 trainer 状态单独存放,便于和 merge 产物区分
    output_dir="out_lora_adapter",
    per_device_train_batch_size=1,      # LoRA 虽省显存,但长序列下单卡通常仍从 1 起步
    gradient_accumulation_steps=8,      # 用微步累积把有效 batch 拉大,让优化更平滑
    # LoRA 常用学习率显著高于全参微调,因为真正更新的参数规模小
    learning_rate=2e-4,
    num_train_epochs=1,                 # 骨架示例先保证完整闭环,不在这里讨论最优停点
    # 展示最常见的混合精度开关;新卡也可根据硬件改成 bf16
    fp16=True,
    logging_steps=10,                   # 高频看 loss,更容易发现模板或 label 对齐问题
    save_steps=200,                     # 长任务中断时至少能从最近 adapter checkpoint 恢复
)
 
trainer = Trainer(
    model=model,       # 这里传的是 PeftModel;训练循环只会更新 LoRA 参数
    args=args,         # 保存、日志和 batch 行为统一由 args 接管
    train_dataset=ds,  # ds 已经只剩模型输入字段,Trainer 可以直接喂给 forward 计算 loss
)
 
trainer.train()
 
# 推荐默认只保存 adapter;部署时再决定是否 merge 成单体权重
model.save_pretrained("out_lora_adapter")
保存、加载与多 adapter 切换

PEFT 支持一个 base 上挂多份 adapter,并在推理时切换 active adapter。这种能力适合“同一底座,多业务域”的线上形态。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from transformers import AutoModelForCausalLM
from peft import PeftModel
 
base = AutoModelForCausalLM.from_pretrained(
  model_id,
  torch_dtype="auto",  # base 先按常规方式加载;多个 adapter 共享这一份底座权重
  device_map="auto",
)
 
# 先挂第一份 adapter,并给它取业务可读名字
model = PeftModel.from_pretrained(base, "adapter_a", adapter_name="a")
model.set_adapter("a")  # 显式声明当前激活哪份 adapter,避免多 adapter 并存时弄错生效对象
 
# 第二份 adapter 叠挂到同一个 base 上,便于线上快速切换域能力
model.load_adapter("adapter_b", adapter_name="b")
 
# 推理前切到目标 adapter;没有这一步时,模型仍可能沿用上一份激活配置
model.set_adapter("b")
adapter 生命周期补全:禁用、删除与梯度开关

多 adapter 系统里,真正困难的是在训练、评估、A/B 和回收阶段精确控制它们。PEFT 在这方面给了几组很实用的生命周期接口:

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from transformers import AutoModelForCausalLM
from peft import PeftModel
 
# 先恢复一份共享底座;后续多份 adapter 都挂在这一个 base 上。
# torch_dtype="auto" 让权重按模型仓库推荐精度加载,避免例子把焦点带偏到 dtype 手调。
# device_map="auto" 让加载器自动把权重分配到可见设备,便于直接进入切换与对照流程。
base = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
model = PeftModel.from_pretrained(
    base,
    "adapter_a",       # 第一份 adapter 的目录;内部应包含 adapter 权重与 adapter_config。
    adapter_name="a",  # 给当前 adapter 取一个短名字,后续切换、导出与删除都用它引用。
    # 先以推理态挂载,避免恢复模型时误把旧 adapter 直接放进训练图。
    is_trainable=False,
    # 默认会把部分 adapter 权重提升到更稳的 dtype,减少低精度训练/推理异常。
    autocast_adapter_dtype=True,
)
# 第二份 adapter 继续挂到同一个 base 上。
# 这样两个域能力共享一份底座权重,切换成本只落在 adapter 层。
model.load_adapter("adapter_b", adapter_name="b", is_trainable=False)
 
# disable_adapter() 是最直接的 base 对照实验入口:
# 同一份请求,不切模型对象,只临时绕过 adapter 路径。
with model.disable_adapter():
    base_only = model.generate(**inputs)  # 这里拿到的是“同一底座、关闭全部 adapter”的基线输出。
 
# 只打开目标 adapter 的梯度,适合“多 adapter 并存,但本轮只训练其中一份”。
model.set_requires_grad(["b"], requires_grad=True)
model.set_adapter("b")  # 把 active adapter 切到 b;后续 forward / generate 都走这份 adapter。
 
# 导出时可以只挑某几个 adapter,避免把实验性分支一起打包。
model.save_pretrained("out_adapters", selected_adapters=["b"])
 
# delete_adapter() 用于清理不再需要的驻留 adapter,减少对象复杂度。
model.delete_adapter("a")

is_trainable 决定加载完成后 adapter 是“默认参与训练”还是“默认按推理态挂载”; autocast_adapter_dtype 决定是否把 adapter 权重提升到更稳的 dtype; selected_adapters 则决定最终导出哪些 adapter。三者分别控制的是训练态、数值态和制品态。

这组接口的工程意义很直接:线上多租户场景要切域能力,实验阶段要做 base vs adapter 对照,持续训练时要只放开某几份 adapter,回收旧版本时要清理驻留对象。没有这层生命周期控制,最终只会得到一堆“目录里能存文件,但系统行为不可控”的 adapter 包。

多 adapter 组合与加权

当两份 adapter 分别学习了不同域能力,而业务又希望在同一个请求里组合它们时,PEFT 提供了“先加载多份 adapter,再按权重生成一份组合 adapter”的路线。这种做法常用于实验性融合、域能力叠加或上线前的快速拼接验证。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 先把两份 adapter 都挂到同一个 base 上。
model.load_adapter("adapter_a", adapter_name="a")
model.load_adapter("adapter_b", adapter_name="b")
 
# 生成一份新的加权 adapter;weights 控制每份 adapter 的影响强度。
model.add_weighted_adapter(
    adapters=["a", "b"],
    weights=[0.7, 0.3],         # 让 domain_a 保持主导,domain_b 只补充局部能力
    # 新组合会以一份独立 adapter 的形式注册,方便后续切换与导出。
    adapter_name="blend_ab",
)
 
# 推理时切到组合后的 adapter;业务代码不需要感知内部是多份融合而来。
model.set_adapter("blend_ab")

加权组合的本质仍然是“线性组合若干已存在的 adapter 参数”。它并非魔法平均,也不会自动解决词表、模板、任务定义不一致的问题;只有当几份 adapter 共享同一个 base、相近注入位置和可兼容任务语义时,这条路才有意义。

add_weighted_adapter 的组合类型与约束

add_weighted_adapter 不只有“给两份 adapter 配个权重”这么简单。它背后对应的是一组不同的组合算法,而不同算法对 rank、一致性和内存峰值有明确要求。

组合类型 适用语义 工程约束
linear 最直接的线性加权,适合“几份 LoRA 语义接近,只想做平滑融合” 参与融合的 adapter 通常需要相同 rank,否则组合会直接失去可比性
cat 把多个 adapter 的低秩空间直接拼接到更高 rank 的新 adapter 结果 rank 约等于各 adapter rank 之和,容量会变大,显存与导出体积也会同步增加
ties / dare_ties / magnitude_prune 更偏“从多个 adapter 中抽取共识或高价值权重” 通常要配合 density 一起用,density 表示保留多大比例的权重信息
svd / *_svd 先合成,再通过 SVD 压回目标 rank 更像“压缩后的融合”,但对 dtype 与数值稳定性更敏感,低精度环境里要额外验证

选择组合类型时,先问清楚目标是什么:是做离线实验、希望快速看到两域混合效果;还是想交付一份稳定可部署的新 adapter。前者可以容忍更激进的组合方式,后者通常更偏向 linear 或经过充分验证的稀疏化融合。 cat 的风险尤其工程化,因为它会直接拉高 rank,导致后续训练、保存与推理都变重。

如果组合完成后还要按样本级切换 adapter_names,需要先确认当前模型没有处于 merged 状态。部分组合、merge 与多 adapter 路径是互斥的,先 merge 后再按请求细切换,通常只会把系统带进不可逆状态。

AutoPeftModel:把“base + adapter”当作一个可加载制品

如果 adapter 目录已经完整记录了 base 模型身份与 PEFT 配置,加载时不一定要先手工恢复 base,再显式 PeftModel.from_pretrained。PEFT 提供了 AutoPeftModel* 家族,把“还原整个 adapter 制品”做成统一入口。

Python
1
2
3
4
5
6
7
8
9
10
from peft import AutoPeftModelForCausalLM
 
model = AutoPeftModelForCausalLM.from_pretrained(
    "adapter_out",      # adapter 目录里应当已经记录 base 模型来源与适配器配置
    # 仍沿用常规加载策略;AutoPeftModel 解决的是对象装配,而非精度策略
    torch_dtype="auto",
    device_map="auto",
)
# 返回的对象已经是“可直接推理/继续训练的 PEFT 模型”,
# 不必再手工恢复 base 然后 from_pretrained 一次。
adapter 的 state_dict 工具

复杂训练栈里,adapter 状态并不总是通过 save_pretrained 直存直取。做 FSDP、分片保存、或和自定义 checkpoint 体系对接时,往往需要直接拿到“只包含 PEFT 参数”的状态字典。

Python
1
2
3
4
5
6
7
8
9
from peft import get_peft_model_state_dict, set_peft_model_state_dict
 
# 只提取 adapter 相关参数,便于塞进自定义 checkpoint 结构
adapter_state = get_peft_model_state_dict(model)
 
# ... 这里可以把 adapter_state 写进你自己的 checkpoint 目录或分布式存储 ...
 
# 恢复时只把 adapter 参数写回,不污染 base 权重
set_peft_model_state_dict(model, adapter_state)
merge:合并成单体权重(用于导出/部署)

merge_and_unload() 用于把 LoRA 权重写回 base 权重,再移除 adapter 结构,得到“标准 Transformers 模型结构”。工程上这通常用于:

  • 导出到单体 checkpoint(例如给不支持 adapter 的推理引擎)。
  • 降低线上复杂度(不需要两段加载与 adapter 切换)。

合并前需要确认 base 权重处于可写入的浮点 dtype(fp16/bf16/fp32)。对 4-bit 量化权重,合并通常不作为默认路径。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
from transformers import AutoModelForCausalLM
from peft import PeftModel
 
base = AutoModelForCausalLM.from_pretrained(
  model_id,
  # merge 需要可写入的浮点权重;这里显式用 fp16,避免量化权重无法直接写回
  torch_dtype="float16",
  device_map="cpu",       # merge 常放在 CPU 或大显存机器上做,避免和线上推理实例争 GPU
)
model = PeftModel.from_pretrained(base, "adapter_out")  # 先恢复“base + adapter”的组合形态
merged = model.merge_and_unload()  # 把 LoRA 增量写回 base,并移除 adapter 包装层
# 导出成标准 Transformers 单体模型包,便于不支持 adapter 的后端加载
merged.save_pretrained("out_merged_full")
merge_adapter vs merge_and_unload vs unload

这三个接口名字接近,但交付语义完全不同。实际项目里最常见的错误,就是把 merge_and_unload() 当作一个“随时可以撤销”的性能开关。

接口 会发生什么 适合什么场景
merge_adapter() 把当前 adapter 临时合到 base 权重里,但仍保留 adapter 结构与回退能力 做基准测试、对比 merged 与 unmerged 性能,或暂时消除推理期 adapter 额外开销
unmerge_adapter() 撤销前一次 merge_adapter() 的效果,回到“base + adapter”分离态 实验阶段切回可编辑、可切换 adapter 的状态
merge_and_unload() 返回一个新的、已经写回 adapter 权重的标准模型对象,并移除 PEFT 包装 导出单体权重,交给不理解 adapter 的推理后端或交付团队
unload() 直接去掉 adapter,不做权重合并 只想回到 base 模型做对照,或清理对象状态
Python
1
2
3
4
5
6
7
8
9
10
11
12
# 这条路线适合做“可回退”的 merged benchmark。
model.merge_adapter()
bench = model.generate(**inputs)
model.unmerge_adapter()
 
# 这条路线适合做最终导出。
# 返回值是一个新的普通 Transformers 模型对象,并非原对象原地变化。
merged_model = model.merge_and_unload()
merged_model.save_pretrained("out_merged_full")
 
# 如果目标只是临时回到 base 路径,不需要做 merge。
base_only_model = model.unload()

从对象生命周期看, merge_adapter() 仍然站在“PEFT 世界”里,保留了继续切换、反向撤销与再保存 adapter 的空间; merge_and_unload() 则直接跨到“普通 Transformers 模型”的世界里。后者更适合制品交付,前者更适合实验与排障。

QLoRA:量化权重 + LoRA 的组合

QLoRA 的工程语义是:base 权重以 4-bit 量化形式加载并冻结,反向传播只更新 LoRA 参数;量化层负责把反向信号“传递”到 LoRA,而非更新量化权重本身。

QLoRA 关键 API(BitsAndBytesConfig + prepare_model_for_kbit_training)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
 
model_id = "mistralai/Mistral-7B-v0.1"
 
bnb_cfg = BitsAndBytesConfig(
    # base 权重按 4bit 量化加载,把显存压力压到 LoRA 可训练场景能承受的范围
    load_in_4bit=True,
    # NF4 是 QLoRA 常见默认量化格式,对权重分布更友好
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,            # 再做一层量化压缩量化常数,进一步省显存
    # 前向/反向实际计算用 bf16,兼顾速度和数值稳定性
    bnb_4bit_compute_dtype=torch.bfloat16,
)
 
# tokenizer 仍按原模型加载;量化不影响词表
tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
base = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_cfg,  # 把 4bit 配置传进加载流程,base 会以量化层形式构建
    device_map="auto",            # 让量化后的权重自动落到可见设备
)
 
# 这一步会冻结量化底座、修正部分层的 dtype/训练标志,
# 让后续注入的 LoRA 建立在“已经准备好训练”的量化骨架上。
base = prepare_model_for_kbit_training(base)
 
cfg = LoraConfig(
    # QLoRA 在任务语义上仍然是 LoRA,只是 base 换成量化权重
    task_type=TaskType.CAUSAL_LM,
    r=16,                                                 # 给量化底座配一个中等 LoRA 容量
    # 控制 adapter 更新幅度,先用保守值保证训练稳定
    lora_alpha=8,
    # 继续对 LoRA 分支做轻度正则化。
    lora_dropout=0.05,
    # 注入点仍优先选注意力投影层。
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
# 注入 LoRA 后,训练仍只更新 adapter;
# 不同之处在于底层 base 已经是 4bit 量化形式。
model = get_peft_model(base, cfg)
# 再次确认训练参数规模,确保没有误把量化 base 放开
model.print_trainable_parameters()

prepare_model_for_kbit_training 必须发生在 LoRA 注入之前,因为它负责冻结量化底座、处理某些层的 dtype 与训练标志,让后续注入的 adapter 建立在“已经准备好训练”的量化骨架上。顺序颠倒后最常见的问题是:LoRA 已经挂上了,但底层量化模块仍处于不适合训练的状态。

prepare_model_for_kbit_training 的真实副作用

这个函数的重要性不在“名字看起来像初始化工具”,而在它确实会改模型状态。工程上至少要把它理解成四件事:

  • 冻结 base 参数。QLoRA 的核心前提就是“量化底座不更新,只让 LoRA 学”。如果这一步没发生,训练就会悄悄滑向“低比特全参微调”,显存、数值稳定性和保存语义都会变坏。
  • 处理 layer norm、embedding 等层的 dtype 与训练标志。这样做的目的核心是让低比特底座上的前向与反向路径更稳定。
  • 为 gradient checkpointing 和输入梯度路径做准备。很多长上下文 QLoRA 脚本把这一步和 gradient checkpointing 绑定考虑,本质上是在给“可训练 LoRA + 低比特底座”这条路径收拾好反向传播现场。
  • 建立后续保存/恢复的前提。LoRA 注入后之所以还能清晰地区分“哪些是 adapter,哪些是 base”,前提正是底座已经被预处理为冻结状态。

因此, prepare_model_for_kbit_training 不该被理解成一个可有可无的 helper。它更接近“把普通量化模型变成可做 QLoRA 训练的骨架”的入口。

量化 + adapter 的工程约束
  • 合并策略:4-bit 量化权重通常不作为合并目标。需要单体权重时,常见流程是“重新加载 fp16/bf16 base → 挂载 adapter → merge → 导出”。
  • 训练开关:Decoder-only 模型训练时常需要关闭 use_cache,并配合 gradient checkpointing 控制显存。
  • 部署形态:adapter 目录天然适合做制品;量化 base 属于环境相关资产(与推理后端、算子实现与硬件强相关),需要独立版本管理。
量化后端 训练期常见形态 导出与 merge 判断
bitsandbytes 4-bit / 8-bit 最常见的 QLoRA 训练底座 训练期可以挂 LoRA;需要单体权重时,通常回到浮点 base 再 merge,而非直接在量化权重上产出最终制品
GPTQ / AWQ 更偏推理期量化制品 很多真实工作流把它们当部署格式而非训练骨架,LoRA 合并与继续训练路径通常更受限
AQLM / HQQ / 其它非常规低比特后端 需要核对专门兼容路径 优先把 adapter 作为独立制品管理,除非文档明确说明支持 merge;否则更稳的做法是分离保存而非强行写回量化权重

这一层判断的核心只有一句话:训练骨架和部署制品并非同一个概念。很多量化后端非常适合把推理显存压下来,却并不天然适合做“最终 merge 交付”的承载格式。

IA3:更轻量的向量型适配器

IA3(Infused Adapter by Inhibiting and Amplifying Inner Activations)通过在注意力与前馈模块中注入少量可训练向量来缩放激活值。它的可训练参数通常比 LoRA 更少,适合“极低成本的快速适配”场景。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from peft import IA3Config, TaskType, get_peft_model
 
model_id = "bigscience/mt0-large"
# 仍然沿用原模型 tokenizer;IA3 只改网络内部缩放向量
tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
base = AutoModelForSeq2SeqLM.from_pretrained(
  model_id,
  torch_dtype="auto",  # 先按模型建议精度加载,避免把示例重点带偏到 dtype 调参
  device_map="auto",
)
 
cfg = IA3Config(
    task_type=TaskType.SEQ_2_SEQ_LM,  # 这里是 encoder-decoder 模型,不再是 Causal LM
    # IA3 通过缩放关键激活通路工作;具体名字必须按模型实现核对
    target_modules=["k", "v", "wo"],
)
model = get_peft_model(base, cfg)  # 把 IA3 向量注入到指定模块,参数量通常比 LoRA 还小
model.print_trainable_parameters()  # 先确认确实只打开了少量 IA3 参数
Prompt Tuning:软提示词(soft prompt)

Prompt tuning 把“要学习的东西”压缩为一段可训练的虚拟 token embedding(virtual tokens),base 权重保持冻结。它的工程接口与 LoRA 不同:训练的是提示向量,并非注入线性层的低秩权重。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import (
    PromptTuningConfig,
    PromptTuningInit,
    TaskType,
    get_peft_model,
)
 
model_id = "bigscience/bloomz-560m"
# prompt tuning 依赖 tokenizer 对虚拟 token 前后边界的处理
tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
base = AutoModelForCausalLM.from_pretrained(
  model_id,
  torch_dtype="auto",  # base 保持冻结,加载精度主要影响推理/前向显存而非训练参数规模
  device_map="auto",
)
 
cfg = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,                          # 软提示仍服务于自回归生成任务
    # 用一段真实文本初始化虚拟 token,比完全随机初始化更容易收敛
    prompt_tuning_init=PromptTuningInit.TEXT,
    prompt_tuning_init_text="Classify if the tweet is a complaint or not:",
    # 虚拟 token 数决定软提示容量;太少表达不够,太多又增加上下文占用
    num_virtual_tokens=8,
    # 明确使用哪套 tokenizer 来把初始化文本映射到 embedding 空间
    tokenizer_name_or_path=model_id,
)
 
# 把可训练对象限制为虚拟 token embedding,而非网络层权重
model = get_peft_model(base, cfg)
model.print_trainable_parameters()  # 先确认只打开了 prompt 参数

Prompt tuning 的数据组织更接近“提示 + 输入 + 输出”的模板化任务。工程上需要明确:提示文本、虚拟 token 数量、以及 tokenizer 的特殊 token 处理方式,三者必须一致,否则会出现训练可收敛但推理表现异常的对齐问题。

与 Transformers / TRL 的集成方式
Transformers:PeftAdapterMixin

Transformers 在模型类上集成了适配器管理接口,典型能力包括 add_adapter、 load_adapter、 set_adapter 与适配器保存。LoRA/IA3/AdaLoRA 属于常见的“直接集成”方法;prompt tuning 等提示类方法通常直接用 PEFT 库完成更稳定。

TRL:把 PEFT 当作模型构造步骤

TRL 的 Trainer(SFT/DPO/GRPO/PPO 等)工程上更适合把 PEFT 当作“模型构造步骤”:先用 Transformers 加载 base,再用 PEFT 包一层 adapter,把得到的 PeftModel 直接传给 TRL Trainer。这样可以绕开不同 TRL 版本对 peft_config 参数支持度不一致的问题。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
 
# 先拿公开数据验证 TRL+PEFT 训练栈;真实项目通常先做模板清洗
train_ds = load_dataset("trl-lib/Capybara", split="train")
 
cfg = SFTConfig(
    # 保存 adapter checkpoint 与 trainer 状态;TRL 仍沿用 HF 的落盘约定
    output_dir="out_trl_lora",
    per_device_train_batch_size=1,      # 大模型 + LoRA 的单卡安全起点通常仍然是 1
    gradient_accumulation_steps=8,      # 通过累积把有效 batch 拉大,减少小 batch 更新抖动
    learning_rate=2e-4,                 # adapter 参数少,学习率通常可以比全参微调更高
    logging_steps=10,                   # 早期高频看 loss,便于快速发现模板或 masking 配错
    save_steps=200,                     # 周期性保存 adapter,避免长任务中断后完全重来
    # 示例保留最小闭环;真实项目应按验证集或 reward 指标决定停点
    num_train_epochs=1,
)
 
trainer = SFTTrainer(
    # 这里直接传已经挂好 LoRA 的 PeftModel,绕开不同 TRL 版本 peft_config 支持差异
    model=model,
    # tokenizer 决定模板拼接后的切分方式,也决定 labels 的 token 对齐
    tokenizer=tok,
    # 数据集需要已经整理成 TRL 能消费的字段;否则会在格式化阶段报错
    train_dataset=train_ds,
    args=cfg,               # SFTConfig 是 TRL 对 TrainingArguments 的任务化扩展
)
 
trainer.train()
# 继续保留 adapter 形态,便于后续叠加、切换或 merge 导出
model.save_pretrained("out_trl_lora_adapter")
配置化微调工作台

PEFT 原生 API 适合需要精确控制模型对象、adapter 生命周期和 checkpoint 语义的工程;Unsloth、LLaMA-Factory、Axolotl 这类工作台则把常见微调路径封装成配置和命令。它们的价值在于减少重复脚本,让团队用同一套配置描述模型、数据、模板、LoRA、量化、训练参数和导出路径。

工作台 主要能力 适合用法
Unsloth 低显存 LoRA/QLoRA、快速 SFT、本地推理与 GGUF/Ollama/vLLM 导出路径。 单卡或少卡快速验证数据、模板和 adapter 效果。
LLaMA-Factory 用 YAML、CLI 和 WebUI 管理 SFT、RM、DPO、PPO、导出与 merge。 希望把训练配方交给配置文件管理,减少手写 Trainer 脚本。
Axolotl 复杂 QLoRA/FSDP 配方、sample packing、DPO/GRPO、vLLM 协作。 需要更细粒度控制训练数据拼接、分布式策略和后训练组合。
Unsloth:快速 LoRA/QLoRA 原型

Unsloth 更适合把“底座加载、量化、LoRA 注入、训练、保存、导出”压成短路径。它常用于先验证数据清洗、chat template、target modules 和学习率是否合理,再决定是否迁移到更完整的分布式训练系统。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from unsloth import FastLanguageModel
 
model, tokenizer = FastLanguageModel.from_pretrained(
    "unsloth/Meta-Llama-3.1-8B",
    # 先给出训练时最大上下文长度,后续 RoPE/attention 路径会按这个预算准备。
    max_seq_length=2048,
    # 4bit 加载把底座显存压下来,训练时只更新 LoRA 分支。
    load_in_4bit=True,
)
 
model = FastLanguageModel.get_peft_model(
    model,
    # rank 决定 LoRA 容量;小数据先从 8/16 起步,避免 adapter 过度记忆训练集。
    r=16,
    # target_modules 要和模型内部线性层命名匹配。
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    # alpha 控制 LoRA 更新幅度,通常和 rank 配套调。
    lora_alpha=16,
    # LoRA 分支 dropout 用来缓和小数据 SFT 的过拟合。
    lora_dropout=0.05,
)
 
dataset = load_dataset("trl-lib/Capybara", split="train")
 
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    args=SFTConfig(
        output_dir="out_unsloth_lora",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        learning_rate=2e-4,
        num_train_epochs=1,
        logging_steps=10,
    ),
)
 
trainer.train()
 
# 保存 adapter 目录,后续可以继续挂到底座上评估、切换或 merge。
model.save_pretrained("out_unsloth_adapter")
tokenizer.save_pretrained("out_unsloth_adapter")
 
# 如果目标是本地推理格式,再单独做 merged/GGUF 导出。
# 导出前应先确认目标后端支持的量化类型和 tokenizer 模板。
model.save_pretrained_gguf("out_gguf", tokenizer, quantization_method="q4_k_m")
LLaMA-Factory:YAML 管理训练配方

LLaMA-Factory 的核心使用方式是把训练语义写进 YAML。模型、数据集、chat template、微调方法、LoRA 参数、batch、学习率、输出目录都在同一个文件中声明,CLI 只负责读取配置并启动训练。

qwen_lora_sft.yaml
YAML
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# 声明训练阶段;sft 表示监督微调。
stage: sft
 
# 声明使用 LoRA,而非全参更新。
finetuning_type: lora
 
# 底座模型目录或 Hub id。
model_name_or_path: Qwen/Qwen2.5-7B-Instruct
 
# 训练数据集名称;具体字段映射由 LLaMA-Factory 的 dataset_info 管理。
dataset: my_sft_dataset
 
# 模板决定 system/user/assistant 消息如何渲染成训练文本。
template: qwen
 
# 训练输出目录,通常保存 adapter、日志和配置快照。
output_dir: saves/qwen2_5_7b/lora/sft
 
# all 表示由工具按常见规则覆盖主要线性投影层。
lora_target: all
 
# LoRA rank 控制 adapter 容量。
lora_rank: 16
 
# alpha 控制 LoRA 更新量缩放。
lora_alpha: 32
 
# 长上下文会显著增加显存,先用明确上限约束训练预算。
cutoff_len: 2048
 
# 单卡微批大小;长上下文 LoRA 常从 1 起步。
per_device_train_batch_size: 1
 
# 用多个微步累计成更大的有效 batch。
gradient_accumulation_steps: 8
 
# LoRA 常用学习率通常高于全参微调。
learning_rate: 2.0e-4
 
# bf16 在支持的 GPU 上通常比 fp16 更稳。
bf16: true
 
# 最小闭环先跑 1 轮;正式实验按验证指标决定停点。
num_train_epochs: 1.0

Shell
1
llamafactory-cli train qwen_lora_sft.yaml

这类配置化路线的关键是把“训练语义”和“运行方式”分开。YAML 描述这次训练到底做什么;Accelerate、DeepSpeed 或 FSDP 配置描述它如何在硬件上运行。排查问题时两份配置都要看,不能只看 Python 调用栈。

Axolotl:复杂配方与 sample packing

Axolotl 更偏向把复杂微调配方显式化。它适合需要 QLoRA、FSDP、sample packing、多数据集混合和后训练路线组合的场景。配置能力越强,对字段语义的要求越高,尤其是 chat_template、 sequence_len、 sample_packing、 adapter 这些字段。

axolotl_qlora.yml
YAML
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# 底座模型身份;tokenizer、config 与权重加载都围绕它展开。
base_model: meta-llama/Meta-Llama-3-8B
 
# 明确 tokenizer 工厂类型,避免模型仓库默认配置不完整时解析失败。
tokenizer_type: AutoTokenizer
 
# qlora 表示底座按低比特加载,同时只训练 LoRA 分支。
adapter: qlora
 
# 打开 4bit 权重量化,降低底座显存占用。
load_in_4bit: true
 
datasets:
  # 数据来源;type 决定字段如何被解释成监督样本。
  - path: tatsu-lab/alpaca
    type: alpaca
 
# 模板决定消息如何被渲染成模型实际看见的训练文本。
chat_template: llama3
 
# sample packing 会把多个短样本拼进同一序列,提高 token 利用率。
sample_packing: true
 
# 验证阶段常关闭 packing,便于逐条样本解释指标。
eval_sample_packing: false
 
# 上下文长度直接决定显存、吞吐和可学习的长程依赖。
sequence_len: 4096
 
# 单卡微批大小,长上下文 QLoRA 通常从 1 起步。
micro_batch_size: 1
 
# 梯度累积把有效 batch 拉大,减少小 batch 更新噪声。
gradient_accumulation_steps: 8
 
# adapter 参数少,常用比全参微调更高的学习率。
learning_rate: 2.0e-4
 
# 本次运行的 adapter、日志和配置快照目录。
output_dir: ./outputs/llama3-8b-qlora

Shell
1
axolotl train axolotl_qlora.yml

sample_packing 追求的是 token 利用率。开启后,一个训练序列里可能包含多个样本,吞吐通常更好,但逐条样本排错、label 对齐和 loss 解释会更难。正式实验应同时保留一份不 packing 的小验证集,用来排查模板和答案边界。

语言模型强化学习

语言模型强化学习(Reinforcement Learning for Language Models)发生在 SFT 之后。SFT 让模型学会按照示例回答,RL 后训练让模型在生成过程中接受奖励信号约束,把“回答是否有用、是否正确、是否符合格式、是否遵守工具协议”转成可优化的目标。

从 SFT 到 RL 后训练

SFT(Supervised Fine-Tuning)使用人工写好的目标答案做交叉熵训练。训练样本通常是一个 prompt 和一个 target answer,模型只需要提高目标 token 的似然。RL 后训练的输入仍然可以是 prompt,优化对象变为模型自己生成的回答集合、奖励函数、参考模型约束和策略优化算法。

一条典型链路如下:

1
2
3
4
5
6
7
8
9
10
11
12
Pre-training
  基座模型从大规模语料中学习语言建模能力。
 
SFT
  模型学习指令跟随、回答格式、基础安全边界和任务模板。
 
RL post-training
  模型对同一个 prompt 采样多个回答,由 reward function 或 reward model 打分。
  训练器根据 reward、KL、advantage 更新 policy。
 
Serving
  部署时使用更新后的 policy;reference、critic、reward 只在训练阶段参与。

这一步的核心变化是训练目标从“模仿一个答案”变为“在可采样空间里提高高奖励回答的概率,同时限制策略漂移”。因此 RL 后训练同时牵涉推理引擎、分布式训练、奖励计算、样本队列和 checkpoint 恢复。

Rollout:语言模型里的行动轨迹

Rollout 指策略模型在当前参数下对 prompt 进行采样,生成一个或多个完整 response,并记录训练需要的中间量。对于机器人控制,轨迹由状态、动作、奖励组成;对于语言模型,动作就是生成下一个 token,状态就是已有上下文,轨迹就是 prompt 后面的一串 response tokens。

语言模型 rollout 的最小结构可以写成:

字段 含义 工程用途
prompt 输入问题、工具调用上下文或多轮对话历史。 用于构造模型输入,也用于 reward function 解析任务条件。
response 当前 policy 采样出的回答 token 序列。 reward、KL、log probability、长度统计都围绕 response 计算。
old_log_probs 采样当时当前 policy 对每个 response token 的对数概率。 PPO/GRPO 更新时计算概率比值,避免策略一次更新过大。
ref_log_probs reference model 对同一 response token 的对数概率。 计算 KL 惩罚,限制 RL 把模型推离 SFT 分布太远。
reward / score reward function 或 reward model 给出的标量或 token-level 分数。 决定这条回答应该被鼓励还是压低概率。
advantage 相对同组回答、baseline 或 value function 的优势值。 把“绝对分数”转成“这条回答比预期好多少”。
response_mask 标记哪些位置属于 response,哪些位置只是 padding。 loss、KL、reward 聚合时排除 prompt 和 padding。

Rollout 的质量直接决定 RL 训练质量。采样温度过高会带来大量低质量噪声,温度过低会导致同一个 prompt 的多个 response 太相似,GRPO/RLOO 这类组内比较方法就缺少有效差异。工程上常同时监控 response 长度、成功率、格式错误率、重复率、KL 和 reward 方差。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# 一个极简 rollout 数据结构,用于说明 RL 训练器真正需要保存什么。
# 实际框架会把这些字段放进 TensorDict、DataProto 或 Ray object store。
rollout_batch = {
    # prompt_tokens 是已经应用 chat template 后的输入 token。
    # 使用 token 作为跨组件边界,可以避免训练端和推理端模板不一致。
    "prompt_tokens": prompt_tokens,
 
    # response_tokens 是 policy 采样出的动作序列。
    # 语言模型 RL 中,每个 token 都是一次动作。
    "response_tokens": response_tokens,
 
    # old_log_probs 固定为采样时 policy 的概率。
    # PPO/GRPO 更新时用它计算 ratio,判断新 policy 是否偏离过大。
    "old_log_probs": old_log_probs,
 
    # ref_log_probs 来自冻结 reference model。
    # 它提供 SFT 分布的锚点,避免 reward 把模型推向奇怪输出。
    "ref_log_probs": ref_log_probs,
 
    # rewards 可以是每条样本一个标量,也可以展开成 token-level reward。
    # 数学题、代码题常用规则奖励;开放式偏好任务常用 reward model。
    "rewards": rewards,
 
    # response_mask 让 loss 只作用在 response token 上。
    # prompt token 只是条件,不应该被策略梯度当作模型动作更新。
    "response_mask": response_mask,
}
Policy、Reference、Reward、Critic

语言模型 RL 是多角色系统。一个完整 PPO/RLHF 系统至少包含 policy、reference、reward,使用 PPO 时通常还包含 critic。

角色 训练状态 作用
Policy / Actor 可训练 真正被更新的语言模型。rollout、log probability 和最终部署都围绕它。
Reference 冻结 通常是 SFT 后的初始模型。计算 KL 约束,防止 policy 为了 reward 牺牲语言质量。
Reward Function / Reward Model 通常冻结 把回答映射成分数。可由规则、单元测试、人工偏好模型、格式检查、多目标加权组成。
Critic / Value Model 可训练 估计当前状态下的期望回报,为 PPO 的 GAE 提供 baseline,降低梯度方差。
Rollout Engine 服务组件 vLLM、SGLang 或 HF generation。负责高吞吐生成和新权重同步。

Reference 和 reward 的存在解释了 RLHF 框架为什么比 SFT 框架复杂。SFT 只需要训练模型前向、反向和优化器;RLHF 还要在训练期间反复调用推理引擎生成样本,并把 actor 的新权重同步给 rollout engine。

Reward Function:把业务目标变成训练信号

奖励函数(Reward Function)是 RL 后训练最关键的设计点。它决定模型优化什么,也决定模型会钻哪些空子。奖励函数可以来自规则、模型、工具执行结果、人工偏好、检索一致性或多目标加权。

奖励来源 适用任务 主要风险
规则奖励 数学答案、JSON 格式、工具协议、正则可验证输出。 覆盖不完整时容易 reward hacking。
单元测试 / 沙箱 代码生成、SQL、Agent 工具调用。 执行成本高,超时和安全隔离必须严格。
Reward Model 开放式问答、摘要、偏好对齐。 reward model 偏差会被 policy 放大。
LLM-as-a-Judge 难以写硬规则的质量评估。 成本、稳定性、提示词泄漏和 judge 偏差。
多目标加权 正确性、格式、简洁性、安全性同时约束。 权重尺度不一致时,某个目标会吞掉其他目标。

奖励函数需要先单元测试,再接入训练。单元测试至少覆盖正确答案、错误答案、格式正确但内容错误、格式错误但内容接近、空回答、超长回答和恶意输出。训练中还要记录分项 reward,不能只看总 reward,否则很难定位 reward hacking。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import re
 
 
def math_box_reward(response: str, label: str) -> float:
    # 规则奖励先解析最终答案,不直接比较整段推理文本。
    # CoT 推理链存在大量等价表达,直接字符串比较会错误惩罚正确解法。
    match = re.search(r"\\boxed\{([^}]*)\}", response)
 
    # 没有遵守 \boxed{} 输出协议时给 0 分。
    # 这会同时训练模型的答案格式和可解析性。
    if match is None:
        return 0.0
 
    # strip 只去掉答案两侧空白,避免空格导致可验证答案误判。
    pred = match.group(1).strip()
 
    # label 应该来自数据集的标准答案字段,避免从 prompt 中再次抽取。
    # 这样 reward function 的输入边界清晰,便于离线单元测试。
    gold = label.strip()
 
    # 二值奖励适合可验证任务;开放式任务通常需要连续分数或分项奖励。
    return 1.0 if pred == gold else 0.0
Advantage、Baseline 与 Credit Assignment

原始 reward 只说明一条回答的得分,advantage 说明这条回答相对预期好多少。策略梯度更新依赖 advantage:正 advantage 提高该 response token 的概率,负 advantage 降低概率。

\[A_t = R_t - b_t\]

其中 \(A_t\) 是第 \(t\) 个 token 或轨迹位置的优势值,读作 advantage;\(R_t\) 是从该位置开始的回报;\(b_t\) 是 baseline,可以来自 critic、同组样本均值或规则估计。baseline 不改变期望梯度方向,但能显著降低方差。

GRPO 常在同一个 prompt 下采样多个 response,用组内均值做 baseline:

\[A_i = r_i - \frac{1}{G}\sum_{j=1}^{G} r_j\]

其中 \(G\) 是同一 prompt 的采样条数,读作 group size;\(r_i\) 是第 \(i\) 条 response 的奖励。这个形式不需要单独训练 critic,适合数学、代码、格式检查等规则奖励任务。

PPO、GRPO、RLOO、REINFORCE++ 与 DPO
算法 核心思路 适用场景
PPO 用 old policy 和 new policy 的概率比做 clipped objective,配合 critic 估计 value。 经典 RLHF、reward model 训练链路完整、需要稳定收敛的场景。
GRPO 同一 prompt 采样多条 response,用组内相对分数构造 advantage,减少 critic 依赖。 RLVR、数学/代码等可验证任务,尤其适合多样本采样。
RLOO Leave-one-out baseline。每条样本用同组其他样本均值作为 baseline。 同 prompt 多采样、希望降低组内估计偏差的场景。
REINFORCE++ 基于 REINFORCE 的语言模型后训练改造,常配合 baseline、KL 和长度/格式控制。 规则奖励、推理任务、希望简化 value model 的场景。
DPO 直接使用偏好对,不做在线 rollout;优化 chosen 相对 rejected 的概率差。 偏好数据充足、希望避免在线 RL 系统复杂度的场景。

在线 RL 框架通常覆盖 PPO、GRPO、RLOO、REINFORCE++。DPO 更接近离线偏好优化,常放在 TRL、LLaMA-Factory、OpenRLHF 的 non-RL 训练入口里。工程选型时先判断是否需要在线采样:需要 rollout、工具交互、沙箱执行或动态 reward,就进入在线 RL;只有 chosen/rejected 数据,DPO/IPO/KTO 往往更简单。

监控指标与失败模式
指标 正常含义 异常信号
reward/score 任务目标正在改善。 快速上升但人工评估下降,通常是 reward hacking。
KL policy 与 reference 的距离保持在预算内。 KL 爆炸表示策略漂移;KL 长期接近 0 表示学习太弱。
entropy 生成分布保留一定探索。 entropy 快速坍缩通常对应模板化、重复或过早收敛。
response_length 回答长度处在任务合理范围。 长度持续增长可能是模型学会用长答案骗 reward。
grad_norm 梯度规模可控。 突然尖峰常见于 reward 尺度失控、坏 batch 或数值溢出。
clipfrac PPO 中被 clip 的 token 比例。 过高表示学习率或 advantage 尺度过大,更新被大量截断。
rollout throughput 推理引擎采样效率。 GPU 空转通常来自 vLLM batch、tensor parallel、sleep mode 或 Ray placement 配错。

RL 训练健康性不能只看 reward。reward、KL、长度、格式错误率和人工抽检需要一起读。reward 上升且 KL 稳定、长度稳定、格式错误下降,才更接近真实改进。

框架选择边界
框架 优势 边界
TRL Hugging Face 生态内最容易本地验证 DPO、PPO、GRPO、SFT。 大规模 Ray/vLLM/DeepSpeed 一体化能力较弱。
OpenRLHF Ray + vLLM + DeepSpeed 组合明确,Hybrid Engine 和 Agent Paradigm 适合大规模在线 RL。 命令参数多,资源拓扑和版本组合需要严格管理。
verl HybridFlow 把 RL 控制流与模型计算流解耦,适合研究新算法和复杂 rollout。 Hydra 配置体系较大,首次接入需要理解 DataProto 和 WorkerGroup。
OpenRLHF / verl + vLLM rollout 吞吐高,适合多样本采样和长 response。 权重同步、KV cache、GPU memory utilization 是主要工程瓶颈。
OpenRLHF 详解

OpenRLHF 面向大语言模型 RLHF/RLVR 后训练。它把 Ray、vLLM、DeepSpeed、Transformers 组合成一条在线 RL 管线:Ray 做角色调度,vLLM 做 rollout,DeepSpeed 做 actor/critic 训练,Transformers 负责模型加载与 Hugging Face checkpoint 兼容。

官方文档结构与源码入口
入口 阅读重点 工程意义
quick_start.rst 安装、层级 CLI、Qwen3 RLVR 首跑命令。 确认版本边界和最小可运行训练链路。
architecture.rst Ray、vLLM、DeepSpeed、NCCL 的职责划分。 理解为什么 RLHF 训练需要多个角色和多个引擎。
agent_paradigm.rst Single-turn、multi-turn、token-in-token-out、算法解耦。 理解工具调用和环境交互如何接入同一 RL loss。
hybrid_engine.rst sleep mode、colocate_all、NCCL weight sync。 解决小集群上生成与训练互相空转的问题。
async_training.rst async queue、partial rollout、off-policy correction。 在吞吐优先场景下重叠 rollout 与训练。
examples/python/math_reward_func.py Python reward function 的输入输出协议。 把规则奖励接到在线 RLVR。
examples/python/agent_func.py MultiTurnAgentExecutor 与 AgentInstanceBase。 把多轮环境、工具调用、反馈回路接入 rollout。
安装与版本边界

OpenRLHF 官方推荐在 NVIDIA PyTorch 容器中安装,原因是 RLHF 同时依赖 CUDA、NCCL、vLLM、FlashAttention、DeepSpeed 和 Ray。容器能减少二进制依赖冲突。

Shell
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# 使用 NVIDIA runtime 启动 PyTorch 容器,确保容器内能访问 GPU。
# --shm-size 给 Ray、DataLoader、vLLM 共享内存留空间。
docker run --runtime=nvidia -it --rm --shm-size="10g" --cap-add=SYS_ADMIN \
    -v $PWD:/openrlhf nvcr.io/nvidia/pytorch:25.11-py3 bash
 
# 基础镜像里可能预装与 vLLM / flash-attn ABI 不兼容的包。
# 先卸载这些包,可以避免安装 OpenRLHF extras 时出现二进制冲突。
pip uninstall xgboost transformer_engine flash_attn pynvml opencv-python-headless -y
 
# 只安装核心训练入口,适合先阅读 SFT/RM/DPO 或自带推理后端的场景。
pip install openrlhf
 
# 安装 vLLM 集成,是在线 RL rollout 的常用组合。
pip install openrlhf[vllm]
 
# 跟随较新的 vLLM 版本,适合需要新推理特性但愿意处理兼容性问题的环境。
pip install openrlhf[vllm_latest]
 
# 增加 ring attention 与 Liger kernel,通常用于长上下文和训练吞吐优化。
pip install openrlhf[vllm,ring,liger]
 
# 源码安装便于阅读 examples、修改 reward、调试 trainer 内部实现。
git clone https://github.com/OpenRLHF/OpenRLHF.git
cd OpenRLHF
pip install -e .

版本选择上,vLLM 是 rollout 性能核心,DeepSpeed 是 ZeRO-3 和大模型训练核心。使用 Muon 优化器时还要满足 DeepSpeed 的版本要求。生产环境不要混用随机升级的 vLLM、DeepSpeed、CUDA 镜像,应把容器 tag、pip freeze、训练脚本和 checkpoint 绑定保存。

核心架构:Ray + vLLM + DeepSpeed

OpenRLHF 的核心思想是角色拆分。Actor、Reference、Reward、Critic 和 vLLM engine 都可以独立映射到 GPU 资源池,也可以通过 Hybrid Engine 放在同一批 GPU 上分时运行。

组件 职责 常见瓶颈
Ray 提交 job、启动远程 actor、分配 GPU、管理 role placement。 placement 配错导致某些 GPU 空闲,或主节点内存被 driver 吃满。
vLLM 批量生成 rollout,维护 KV cache,接受 actor 权重同步。 KV cache 不够、batch 太小、tensor parallel 配置不匹配。
DeepSpeed Actor/Critic 的 ZeRO-3 训练、参数/梯度/优化器状态分片。 ZeRO stage、offload、micro batch、gradient checkpointing 组合不当。
NCCL 训练端与 vLLM 端权重同步、GPU 间通信。 网络/NCCL 环境变量错误会卡在同步或 reduce。
Transformers 加载 Hugging Face 模型、tokenizer、chat template、checkpoint 结构。 special tokens、chat template 和 checkpoint 目录不一致。
角色模型:Actor / Reference / Reward / Critic / vLLM

OpenRLHF 的层级 CLI 把不同角色的参数放在不同前缀下,例如 --actor.*、 --reward.*、 --vllm.*、 --ds.*。这能避免大规模训练脚本里参数归属混乱。

前缀 控制对象 示例
--actor 可训练 policy 模型、优化器、梯度检查点。 --actor.model_name_or_path, --actor.adam.lr
--ref 冻结 reference policy 的节点和 GPU 数。 --ref.num_nodes, --ref.num_gpus_per_node
--reward reward model 或 Python reward function。 --reward.remote_url
--vllm rollout engine 数量、TP、显存预算、权重同步。 --vllm.num_engines, --vllm.tensor_parallel_size
--rollout 采样 batch、每 prompt 采样数、最长生成长度。 --rollout.batch_size, --rollout.n_samples_per_prompt
--algo advantage、KL、动态过滤、off-policy correction。 --algo.advantage.estimator, --algo.kl.init_coef
--ds DeepSpeed ZeRO、精度、packing、sleep mode。 --ds.zero_stage, --ds.param_dtype
--ckpt checkpoint 保存、恢复、HF 导出。 --ckpt.path, --ckpt.save_hf
QuickStart 命令逐行解读

下面基于官方 Qwen3-4B RLVR 示例重写为工程注释版。命令展示的是“如何把 Ray job、actor、reward、dataset、vLLM、算法、DeepSpeed 和 checkpoint 接起来”。

Shell
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# Ray head 节点负责接收 job、维护集群资源视图和调度远程 worker。
# --num-gpus 4 告诉 Ray 当前节点有 4 张 GPU 可分配给 OpenRLHF 的各个角色。
ray start --head --node-ip-address 0.0.0.0 --num-gpus 4
 
# 通过 Ray Jobs API 提交训练,训练进程由 Ray 在集群内启动。
# 这样提交脚本可以和训练 worker 解耦,适合多节点或容器环境。
ray job submit --address="http://127.0.0.1:8265" \
 
  # working_dir 会被 Ray 打包分发到 worker,reward 脚本和本地源码才能被远程进程找到。
  --runtime-env-json='{"working_dir": "/openrlhf"}' \
 
  # 双横线后面是 Ray job 内真正执行的命令。
  # train_ppo_ray 是 OpenRLHF 在线 RL 的 Ray 入口。
  -- python3 -m openrlhf.cli.train_ppo_ray \
 
  # actor 是被训练的 policy,初始权重来自 Qwen3-4B-Thinking。
  --actor.model_name_or_path Qwen/Qwen3-4B-Thinking-2507 \
 
  # reward.remote_url 可以指向 HTTP reward 服务,也可以指向本地 Python 文件。
  # 数学 RLVR 常先使用规则 reward,省掉单独训练 reward model 的步骤。
  --reward.remote_url examples/python/math_reward_func.py \
 
  # prompt_dataset 提供训练 prompt;label_key 提供标准答案或验证目标。
  --data.prompt_dataset zhuzilin/dapo-math-17k \
  --data.input_key prompt \
  --data.label_key label \
 
  # 使用 tokenizer 自带 chat template,保证训练端输入格式与模型指令格式一致。
  --data.apply_chat_template \
 
  # packing_samples 把变长样本打包,提高长上下文场景下的 token 利用率。
  --ds.packing_samples \
 
  # reference policy 使用 1 个节点、4 张 GPU。
  # 它通常冻结,只负责提供 ref_log_probs 和 KL 约束。
  --ref.num_nodes 1 --ref.num_gpus_per_node 4 \
 
  # actor policy 使用同样的 GPU 规模训练。
  # Hybrid Engine 下它会与 vLLM 分时共享资源。
  --actor.num_nodes 1 --actor.num_gpus_per_node 4 \
 
  # 启动 2 个 vLLM engine,每个 engine 用 2 张卡做 tensor parallel。
  # 2 engines * TP=2 正好覆盖 4 张 GPU。
  --vllm.num_engines 2 --vllm.tensor_parallel_size 2 \
 
  # colocate_all 把 Actor、Reference、Reward、Critic、vLLM 放到同一批 GPU 上分时运行。
  --train.colocate_all \
 
  # 控制 vLLM 可用显存比例;值太高会挤压训练侧,值太低会限制 KV cache。
  --vllm.gpu_memory_utilization 0.7 \
 
  # rollout 阶段 DeepSpeed sleep,训练阶段 vLLM sleep,降低同卡共存的显存压力。
  --vllm.enable_sleep --ds.enable_sleep \
 
  # 使用 NCCL 同步 actor 新权重到 vLLM,比 CPU 中转更适合多 GPU。
  --vllm.sync_backend nccl --vllm.enforce_eager \
 
  # reinforce_baseline 使用 baseline 降低 REINFORCE 方差,适合规则奖励 RLVR。
  --algo.advantage.estimator reinforce_baseline \
 
  # KL loss 把 policy 拉回 reference 附近;init_coef 是初始约束强度。
  --algo.kl.use_loss --algo.kl.estimator k2 --algo.kl.init_coef 1e-5 \
 
  # rollout.batch_size 是每轮采样的 prompt 数。
  # n_samples_per_prompt=8 表示每个 prompt 生成 8 条候选回答,用于组内比较。
  --rollout.batch_size 128 --rollout.n_samples_per_prompt 8 \
 
  # train.batch_size 是进入策略更新的样本总量,需要和 rollout 产物规模匹配。
  --train.batch_size 1024 \
 
  # data.max_len 限制 prompt 总长度;rollout.max_new_tokens 限制生成长度。
  # 推理题常需要更长 response,但长度上限必须和显存预算一起调。
  --data.max_len 8192 --rollout.max_new_tokens 4096 \
 
  # ZeRO-3 分片参数、梯度和优化器状态,是大模型训练的基本内存手段。
  --ds.zero_stage 3 --ds.param_dtype bf16 \
 
  # 梯度检查点用额外计算换显存,常用于长上下文或较大 actor。
  --actor.gradient_checkpointing_enable \
 
  # RL 阶段学习率通常比 SFT 更小,避免 reward 噪声导致策略剧烈漂移。
  --actor.adam.lr 5e-7 \
 
  # 输出目录需要和实验名绑定,避免覆盖 SFT 或旧 RL checkpoint。
  --ckpt.output_dir ./exp/Qwen3-4B-Thinking
自定义 Reward Function 逐行解读

OpenRLHF 的 Python reward function 接收完整 query、prompt 和 label,返回 rewards、scores 与 extra_logs。rewards 用于策略优化,scores 常用于动态过滤和指标记录。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from typing import List
 
import torch
 
from openrlhf.utils import extract_boxed_answer, grade_answer
 
 
def reward_func(queries: List[str], prompts: List[str], labels: List[str], **kwargs) -> dict:
    # queries 是 prompt + response 的完整文本,来自 rollout engine 的生成结果。
    # prompts 是原始输入,用于从 query 中切出模型生成的 response。
    # labels 是数据集中的标准答案字段,来自 --data.label_key。
    rewards = []
 
    # zip 保证每条 query、prompt、label 一一对应。
    # reward function 必须保持 batch 内样本顺序,否则 reward 会错配到别的 response。
    for query, prompt, label in zip(queries, prompts, labels):
        # prompt 可能已经被 chat template 改写;只有确认 prompt 在 query 中才切片。
        # 这能避免模板不一致时错误截断 response。
        if isinstance(prompt, str) and prompt in query:
            response = query[len(prompt) :]
        else:
            # 如果无法可靠切出 response,就退化为对完整 query 抽取答案。
            # 这比抛异常更适合长时间训练,但需要在 extra_logs 中监控异常比例。
            response = query
 
        # 数学奖励只检查最终 boxed answer,不直接比较推理链。
        # 推理链可以有多种等价写法,最终答案才是可验证目标。
        pred_answer = extract_boxed_answer(response)
 
        # grade_answer 负责归一化和判等,例如数字格式、LaTeX 表达式等。
        # 具体容错能力由 OpenRLHF 工具函数实现。
        is_correct = grade_answer(pred_answer, label)
 
        # 二值 reward 清晰稳定,适合作为 RLVR 的第一版奖励。
        # 如果任务需要部分分,可以改成分项连续 reward。
        rewards.append(1.0 if is_correct else 0.0)
 
        # 训练早期保留少量打印有助于验证 reward 是否和人类判断一致。
        # 大规模训练时应改成采样日志,避免 stdout 成为瓶颈。
        print(f"[Math Reward] Pred: {pred_answer}, Gold: {label}, Match: {is_correct}")
 
    # OpenRLHF 期望 rewards 是 tensor,便于直接进入分布式训练管线。
    rewards_tensor = torch.tensor(rewards, dtype=torch.float)
 
    # accuracy 是 batch 级别的平均正确率,适合写入 logger 观察训练趋势。
    accuracy = rewards_tensor.mean()
 
    return {
        # rewards 进入 advantage 与 policy loss,是训练信号。
        "rewards": rewards_tensor,
 
        # scores 通常用于动态过滤、评估展示和额外统计。
        # 对简单规则奖励,可以直接与 rewards 保持一致。
        "scores": rewards_tensor,
 
        # extra_logs 会进入日志系统,建议放可解释的分项指标。
        "extra_logs": {
            "math_accuracy": accuracy,
        },
    }
Multi-turn Agent 代码逐行解读

OpenRLHF 的 Agent Paradigm 把“如何收集经验”和“如何更新 policy”拆开。Single-turn 只生成一次回答;multi-turn 会在环境中多步交互,每一步把模型 action、环境反馈、reward 继续拼成 token-level trajectory。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import random
from typing import Any, Dict
 
import torch
 
from openrlhf.utils.agent import AgentInstanceBase, MultiTurnAgentExecutor
 
 
class AgentInstance(AgentInstanceBase):
    # 每个 AgentInstance 对应一个 prompt 的环境实例。
    # 状态变量放在实例上,避免不同 prompt 的交互步数互相污染。
    def __init__(self, *args, **kwargs):
        # 当前已经交互到第几步。
        self.step_idx = 0
 
        # 示例环境随机设置最大步数。
        # 真实工具任务会用任务结束条件、测试结果或环境 done 信号决定。
        self.max_steps = random.randint(1, 3)
 
    async def reset(self, states: dict, **kwargs):
        # reset 在一条新轨迹开始时调用。
        # states 通常包含原始 prompt、label、metadata 和采样参数。
        return {"observation": states["observation"]}
 
    async def step(self, states: dict, **kwargs) -> Dict[str, Any]:
        # step 接收模型刚生成的 action_text,以及上一轮 observation。
        # 多轮 Agent 的核心就是 action -> environment_feedback -> next action。
        observation_text = states["observation_text"]
        action_text = states["action_text"]
        label = states["label"]
 
        # 示例代码没有真正使用 observation_text/action_text/label。
        # 真实任务会在这里运行工具、调用判题器、访问检索系统或比对答案。
        _ = observation_text, action_text, label
 
        # done 表示当前轨迹是否结束。
        # 结束时 reward 才通常给到非零值,未结束步骤更多提供环境反馈。
        done = self.step_idx >= self.max_steps
 
        # 示例用随机 0/1 奖励模拟环境结果。
        # 真实 reward 应来自可复现的规则、模型评分或工具执行结果。
        reward = torch.randint(0, 2, (1,)).float() if done else torch.tensor(0)
 
        # environment_feedback 会作为下一轮模型输入的一部分。
        # 它必须保持模板稳定,否则训练端和推理端会出现分布漂移。
        if done:
            environment_feedback = "\n\nHuman: [CORRECT]\n</s>"
        else:
            environment_feedback = (
                "\n\nHuman: [INCORRECT]\n"
                "Please analyze the issues and try again.\n</s>\n\nAssistant: "
            )
 
        # 每调用一次 step,环境步数递增。
        self.step_idx += 1
 
        return {
            # rewards 用于 advantage 和策略更新。
            "rewards": reward,
 
            # scores 用于动态过滤和统计;简单环境可与 rewards 相同。
            "scores": reward,
 
            # environment_feedback 被追加到上下文,让下一次生成看到环境反馈。
            "environment_feedback": environment_feedback,
 
            # done 控制 multi-turn rollout 是否继续。
            "done": done,
 
            # sampling_params 允许环境按步调整温度、max_tokens 等采样参数。
            "sampling_params": states.get("sampling_params", None),
 
            # extra_logs 记录任务相关指标,便于训练时定位 reward 问题。
            "extra_logs": {"dummy_scores": reward},
        }
 
 
class AgentExecutor(MultiTurnAgentExecutor):
    # MultiTurnAgentExecutor 负责把每个 prompt 包装成 AgentInstance,
    # 并把多步交互结果转成 OpenRLHF 统一的 token-level trajectory。
    def __init__(self):
        super().__init__(AgentInstance)
Hybrid Engine 配置逐行解读

Hybrid Engine 解决 RLHF 的资源空转问题。生成阶段 vLLM 使用 GPU,训练阶段 DeepSpeed 使用 GPU。sleep mode 让二者在同一批 GPU 上分时占用显存。

Shell
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# colocate_all 让 Actor、Reference、Reward、Critic 与 vLLM 共用同一组 GPU。
# 这降低 GPU 数量需求,但要求 sleep mode 和显存预算配置正确。
--train.colocate_all \
 
# vLLM 的 KV cache 显存预算。
# 值越大,rollout 越能处理长上下文和大 batch;值过大会挤压训练侧。
--vllm.gpu_memory_utilization 0.7 \
 
# generation 阶段 vLLM 醒着,训练阶段 vLLM 释放大部分显存。
--vllm.enable_sleep \
 
# training 阶段 DeepSpeed 醒着,generation 阶段 DeepSpeed 降低显存占用。
--ds.enable_sleep \
 
# actor 权重更新后,通过 NCCL 同步到 vLLM engine。
# RL 训练必须持续同步,否则 rollout 使用的 policy 会越来越旧。
--vllm.sync_backend nccl \
 
# enforce_eager 关闭部分 CUDA graph 行为,常用于提高兼容性和降低图捕获显存压力。
--vllm.enforce_eager

Hybrid Engine 的调参顺序通常是先让脚本稳定跑通,再提高 vllm.gpu_memory_utilization、rollout batch、max tokens。OOM 时先降低 vLLM 显存比例和 micro batch,再考虑分离资源池。

Async 与 Partial Rollout 配置解读

同步训练按 rollout -> train -> rollout 交替执行。Async 让 rollout 和训练通过队列并行,Partial Rollout 进一步在权重同步时暂停和恢复 vLLM 请求。吞吐更高,但样本可能带有轻微 off-policy 噪声。

Shell
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 开启异步管线后,rollout worker 和 trainer 同时运行。
# 训练端从队列消费样本,rollout 端继续生成下一批。
--train.async_enable \
 
# 队列越深,GPU 越不容易空转,但样本越可能来自旧 policy。
# 初次验证建议从 1 开始。
--train.async_queue_size 1 \
 
# partial rollout 允许 vLLM 在权重同步时暂停正在生成的请求。
# 新权重加载后再恢复生成,提高 rollout 与同步的重叠程度。
--train.partial_rollout_enable \
 
# 异步样本可能偏离当前 policy。
# off-policy correction 用于降低旧样本带来的优化偏差。
--algo.advantage.is_correction_enable \
--algo.advantage.is_correction_type icepop

Async 不适合作为第一轮实验默认配置。先用同步 Hybrid Engine 验证 reward、KL、长度和正确率曲线,再切换 async 观察吞吐收益和收敛差异。

Checkpoint、导出与恢复

OpenRLHF 的 checkpoint 需要同时考虑 DeepSpeed 分片状态和 Hugging Face 可部署权重。训练中断恢复依赖 optimizer、scheduler、dataset progress;上线部署通常需要导出 HF 格式。

Shell
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# DeepSpeed/Ray 训练状态保存目录,用于断点恢复。
# 这个目录通常包含分片权重、优化器状态、scheduler 状态和训练进度。
--ckpt.path ./exp/Qwen3-4B-Thinking/ckpt \
 
# output_dir 是实验产物根目录,日志、HF 导出和最终模型通常都挂在这里。
--ckpt.output_dir ./exp/Qwen3-4B-Thinking \
 
# 每隔多少 training steps 保存一次 checkpoint。
# RL 训练成本高,save_steps 不宜过大,否则失败后回滚太多 rollout。
--ckpt.save_steps 10 \
 
# 最多保留几个 checkpoint,避免长时间训练把磁盘写满。
--ckpt.max_num 3 \
 
# 同时导出 Hugging Face 格式模型,方便后续用 transformers/vLLM 加载评估。
--ckpt.save_hf \
 
# 从已有 checkpoint 恢复训练时打开。
# 恢复前要保证训练脚本、模型路径、world size 和关键并行配置兼容。
--ckpt.load_enable
Agent Paradigm 与 token-in-token-out

OpenRLHF 的 Agent Paradigm 把执行方式和 RL 算法拆开。执行方式负责产生轨迹,算法负责消费轨迹。single-turn、自定义 reward、multi-turn 工具环境都被统一成 token-level trajectory,后续 PPO、GRPO、RLOO、REINFORCE++ 使用同一套 loss 入口。

维度 Single-turn Multi-turn Agent
执行过程 prompt 生成一次 response,然后 reward 打分。 prompt 进入环境,模型多次 action,环境多次反馈。
配置入口 默认模式,常配合 --reward.remote_url。 --train.agent_func_path 指向 AgentExecutor 文件。
轨迹结构 prompt tokens + response tokens + reward。 observation/action/feedback 多轮拼成一条 token trajectory。
典型任务 数学答案、代码单测、格式校验、偏好奖励。 工具调用、网页环境、代码调试、交互式游戏、搜索增强。

token-in-token-out 的工程价值很高。模型生成结果不先还原成字符串再重新 tokenize,能避免 BOS/EOS、chat template、特殊 token、工具标记在训练端和 rollout 端不一致。多轮 Agent 仍然可以返回文本反馈,但框架最终保存和优化的是 token 级轨迹。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
Single-turn RLVR
  prompt tokens
    -> vLLM generate response tokens
    -> reward function reads response text and label
    -> trainer stores response tokens, log_probs, rewards, masks
    -> RL loss updates actor
 
Multi-turn Agent RL
  prompt tokens
    -> model action tokens
    -> environment returns feedback and reward
    -> model reads feedback and emits next action tokens
    -> executor packs all turns into one token-level trajectory
    -> RL loss updates actor with the same algorithm interface
算法入口与层级 CLI

OpenRLHF 0.10.2 之后使用层级 CLI。算法切换主要通过 --algo.advantage.estimator 完成,KL、动态过滤、off-policy correction、batch 和长度预算则通过各自前缀组合。

配置 含义 常见选择
--algo.advantage.estimator 选择 advantage 估计器,也就是在线 RL 的核心算法形态。 gae, reinforce, reinforce_baseline, rloo, group_norm, dr_grpo
--algo.kl.use_loss 把 KL 作为 actor loss 的约束项。 推理/RLVR 中常用,避免 reward 直接吞掉语言质量。
--algo.kl.init_coef KL 初始系数。 从小值开始,结合实际 KL 曲线调。
--algo.dynamic_filtering_enable 过滤全错、全对或超出分数范围的 rollout group。 规则奖励任务常用,减少无学习信号 batch。
--train.dynamic_batch_enable 按 token 数动态组织 batch。 长短 response 混合时降低 OOM 和吞吐抖动。
--train.max_tokens_per_gpu 训练阶段每张 GPU 的 token 上限。 显存控制阈值,优先级高于样本条数。
--rollout.max_tokens_per_gpu rollout 阶段每张 GPU 的 token 上限。 控制 vLLM 侧批量生成规模。
Shell
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# 切换为 GRPO 风格的组内相对优势。
# OpenRLHF 文档中 group_norm 对应 GRPO 这一类按组归一化的 estimator。
--algo.advantage.estimator group_norm \
 
# 动态过滤会丢弃没有学习信号的 group。
# 例如同一 prompt 的 8 条回答全错或全对,组内 advantage 接近无效。
--algo.dynamic_filtering_enable \
 
# 只保留 score 落在 0 到 1 范围内的样本组。
# 对二值数学 reward,这通常覆盖正常正确率区间。
--algo.dynamic_filtering_range 0.0 1.0 \
 
# 动态 batch 让 batch 按 token 预算组织,样本条数只作为次级约束。
# 长 response 场景下,这是控制显存和吞吐的关键开关。
--train.dynamic_batch_enable \
 
# 训练阶段每张 GPU 最多处理的 token 数。
# OOM 时先降低这个值,比盲目降低学习率更有效。
--train.max_tokens_per_gpu 16192 \
 
# rollout 阶段每张 GPU 最多处理的 token 数。
# 该值影响 vLLM batching 和 KV cache 压力。
--rollout.max_tokens_per_gpu 32768
训练脚本组织模板

OpenRLHF 命令很长,生产脚本应把资源、算法、数据、checkpoint 分区写清楚。下面是一个结构化模板,重点展示每组参数的职责。

Shell
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#!/usr/bin/env bash
 
# 严格模式让脚本在变量缺失、命令失败、管道失败时立刻停止。
# RL 训练成本高,静默失败会浪费大量 GPU 时间。
set -euo pipefail
 
# 模型、数据、输出目录集中放在顶部,便于实验管理系统覆盖。
ACTOR_MODEL=${ACTOR_MODEL:-Qwen/Qwen3-4B-Thinking-2507}
PROMPT_DATASET=${PROMPT_DATASET:-zhuzilin/dapo-math-17k}
OUTPUT_DIR=${OUTPUT_DIR:-./exp/qwen3_math_rlvr}
 
# Ray head 只需在集群启动时执行一次。
# 本地单机调试可以把它放进脚本;生产集群通常由调度器提前启动。
ray start --head --node-ip-address 0.0.0.0 --num-gpus 4
 
# 使用数组组织参数,避免一条命令无限延伸且难以审查。
DATA_ARGS=(
  # prompt_dataset 是在线 rollout 的问题来源。
  --data.prompt_dataset "$PROMPT_DATASET"
 
  # input_key/label_key 明确数据字段,reward function 才能拿到标准答案。
  --data.input_key prompt
  --data.label_key label
 
  # chat template 保持 prompt 格式与模型训练格式一致。
  --data.apply_chat_template
)
 
ROLLOUT_ARGS=(
  # rollout batch 控制每轮采样多少 prompt。
  --rollout.batch_size 128
 
  # 每个 prompt 多采样,给 GRPO/RLOO/REINFORCE baseline 提供比较对象。
  --rollout.n_samples_per_prompt 8
 
  # 限制生成长度,防止模型靠超长输出拖垮吞吐或骗 reward。
  --rollout.max_new_tokens 4096
)
 
ALGO_ARGS=(
  # 规则奖励任务首选无 critic 或弱 critic 的 estimator 做快速验证。
  --algo.advantage.estimator reinforce_baseline
 
  # KL loss 保留 SFT 模型的语言分布和安全边界。
  --algo.kl.use_loss
  --algo.kl.estimator k2
  --algo.kl.init_coef 1e-5
)
 
ENGINE_ARGS=(
  # ZeRO-3 是大模型在线训练的主要显存手段。
  --ds.zero_stage 3
  --ds.param_dtype bf16
 
  # vLLM 负责采样,TP=2 时一个 engine 跨两张卡。
  --vllm.num_engines 2
  --vllm.tensor_parallel_size 2
  --vllm.sync_backend nccl
 
  # Hybrid Engine 共用 GPU,sleep mode 降低同时驻留显存。
  --train.colocate_all
  --vllm.enable_sleep
  --ds.enable_sleep
)
 
CKPT_ARGS=(
  # ckpt.path 保存可恢复训练状态。
  --ckpt.path "$OUTPUT_DIR/ckpt"
 
  # output_dir 保存实验输出和可选 HF 导出。
  --ckpt.output_dir "$OUTPUT_DIR"
 
  # 频繁保存降低长时间 RL 训练失败后的回滚成本。
  --ckpt.save_steps 10
  --ckpt.max_num 3
  --ckpt.save_hf
)
 
ray job submit --address="http://127.0.0.1:8265" \
  --runtime-env-json='{"working_dir": "/openrlhf"}' \
  -- python3 -m openrlhf.cli.train_ppo_ray \
  --actor.model_name_or_path "$ACTOR_MODEL" \
  --reward.remote_url examples/python/math_reward_func.py \
  "${DATA_ARGS[@]}" \
  "${ROLLOUT_ARGS[@]}" \
  "${ALGO_ARGS[@]}" \
  "${ENGINE_ARGS[@]}" \
  "${CKPT_ARGS[@]}"
OpenRLHF 选型边界

OpenRLHF 适合已经明确要做在线 RLHF/RLVR 的团队,尤其是需要 Ray 调度、vLLM rollout、DeepSpeed ZeRO-3 和 multi-turn agent 的项目。它的优势是角色边界清楚、命令行可组合、Hybrid Engine 对中小 GPU 集群友好。代价是参数多、版本组合敏感、训练脚本需要严格工程化管理。

verl 详解

verl 是面向大语言模型后训练的强化学习框架,核心设计来自 HybridFlow。它把 RL 算法的控制流放在单进程 driver 中,把模型前向、rollout、反向和优化器放在 Ray worker 上执行。这个拆分让研究者能更容易改 PPO/GRPO 主循环,同时复用 FSDP、Megatron-LM、vLLM、SGLang 等计算后端。

官方文档结构与源码入口
入口 阅读重点 工程意义
docs/hybrid_flow.rst 控制流、计算流、driver/worker 拆分。 理解 verl 与一体化多进程 trainer 的差异。
docs/examples/ppo_code_architecture.rst main_ppo、RewardManager、WorkerGroup、ResourcePool。 定位新增算法或新 worker 应该改哪里。
docs/start/quickstart.rst GSM8K PPO 首跑、parquet 数据、model_merger。 建立最小可运行工程链路。
docs/preparation/reward_function.rst custom_reward_function.path/name 和 RewardManager。 接入自定义规则奖励或 reward model。
verl/trainer/main_ppo.py Hydra 入口、Ray 初始化、TaskRunner。 理解训练 job 如何启动和分配角色。
verl/trainer/ppo/ray_trainer.py apply_kl_penalty、compute_advantage、RayPPOTrainer.fit。 理解 rollout 到 actor update 的主循环。
examples/grpo_trainer/run_qwen3_8b_fsdp.sh GRPO、FSDP、vLLM、NPU/GPU 的真实配置组织。 把 Hydra 参数分组管理,避免巨型命令不可维护。
安装与后端选择

verl 官方推荐容器环境。训练后端和 rollout 后端可以独立选择:研究和原型常用 FSDP/FSDP2,超大规模可选 Megatron-LM;rollout 常用 vLLM,也支持 SGLang、TensorRT-LLM 或 Hugging Face 调试路径。

Shell
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 克隆 verl 源码,便于阅读 trainer、worker 和 examples。
git clone https://github.com/verl-project/verl.git
cd verl
 
# 使用 no-deps 源码安装,依赖通常由官方镜像或项目 requirements 管理。
# 这样可以避免 pip 自动升级 CUDA 相关包导致 ABI 冲突。
pip3 install --no-deps -e .
 
# 预处理 GSM8K 为 parquet。
# verl 的 RLHFDataset 默认从 parquet 读取 prompt、label、data_source 等字段。
python3 examples/data_preprocess/gsm8k.py --local_save_dir ~/data/gsm8k
 
# 预下载模型可以提前暴露网络、权限、模型卡依赖问题。
# 真正训练时 actor 和 critic 会通过配置项再次加载该模型。
python3 -c "import transformers; transformers.pipeline('text-generation', model='Qwen/Qwen2.5-0.5B-Instruct')"
HybridFlow:控制流与计算流

HybridFlow 把 RL 系统看成两层 dataflow。控制流决定先 rollout、再算 logprob/reward/advantage、再 update actor/critic;计算流决定每个模型操作如何在多 GPU 上执行。verl 让控制流保持单进程 Python 逻辑,计算流交给 Ray worker 和模型引擎。

层次 包含内容 在 verl 中的位置
控制流 rollout 顺序、reward 计算、advantage、actor/critic 更新时机。 RayPPOTrainer.fit、main_ppo.py、算法扩展。
计算流 模型 forward、backward、optimizer、FSDP/Megatron/vLLM 并行。 ActorRolloutRefWorker、TrainingWorker、engine backend。
数据协议 prompt、response、mask、log_probs、reward、advantage。 DataProto、TensorDict、non_tensor_batch。
资源映射 哪些角色放到哪些 GPU pool。 ResourcePoolManager、Role、RayWorkerGroup。
DataProto 与 WorkerGroup

DataProto 是 verl 在 driver 和 worker 之间传递 batch 的核心容器。tensor 字段放在 batch 中,字符串、数据源、ground truth 等非 tensor 字段放在 non_tensor_batch 中。WorkerGroup 对外暴露看起来像本地函数的方法,内部负责把 DataProto 拆分到多个 worker,再收集结果。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# 下面是 verl 文档中 WorkerGroup 调用模式的压缩示意。
# 真实实现由 @register(dispatch_mode=...) 自动处理 split、remote call 和 gather。
 
# data 是一个 DataProto,里面同时包含 tensor batch 与非 tensor metadata。
data = build_prompt_dataproto(batch)
 
# generate_sequences 在 rollout worker 上远程执行。
# driver 只描述控制流,不直接管理每张 GPU 的推理细节。
output = actor_rollout_ref_wg.generate_sequences(data)
 
# compute_log_prob 使用当前 actor 重新计算生成 token 的 log probability。
# PPO/GRPO 需要它和 old_log_probs 共同构造策略更新目标。
old_log_prob = actor_rollout_ref_wg.compute_log_prob(output)
 
# compute_ref_log_prob 使用冻结 reference policy 计算同一批 response 的概率。
# KL 惩罚依赖 current policy 与 reference policy 的差异。
ref_log_prob = actor_rollout_ref_wg.compute_ref_log_prob(output)
 
# critic 估计 value,用于 PPO 的 GAE 或 return 计算。
# GRPO 这类算法可以不启用 critic。
values = critic_wg.compute_values(output)
 
# reward worker 或函数式 reward 计算训练信号。
# 输出通常会变成 token_level_scores 或 sequence-level score。
rewards = reward_wg.compute_scores(output)
 
# advantage 在 driver 控制流中计算。
# 这样新增算法可以直接改 Python 逻辑,不必重写底层 FSDP/vLLM worker。
advantages = compute_advantages(values, rewards)
 
# union 把不同 worker 产出的字段合并回同一个 DataProto。
# 后续 actor/critic update 会读取这些字段计算 loss。
output = output.union(old_log_prob)
output = output.union(ref_log_prob)
output = output.union(values)
output = output.union(rewards)
output = output.union(advantages)
 
# actor update 执行策略梯度更新。
# 具体反向传播、梯度裁剪、optimizer step 由 actor worker 的后端实现负责。
actor_rollout_ref_wg.update_actor(output)
 
# critic update 只在启用 critic 的算法中执行。
# GRPO/RLOO/部分 REINFORCE 变体可以省略 value model。
critic_wg.update_critic(output)
RayPPOTrainer 主循环逐行解读

verl 的 ray_trainer.py 中,KL penalty 和 advantage 是连接 reward 与 policy loss 的两个关键函数。下面保留核心逻辑并加工程注释。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
def apply_kl_penalty(data, kl_ctrl, kl_penalty="kl"):
    # response_mask 标记 response token 的有效位置。
    # prompt 和 padding 不应该参与 KL reward penalty。
    response_mask = data.batch["response_mask"]
 
    # token_level_scores 是 reward function 或 reward model 产生的原始分数。
    # 它可能只在最后一个 token 非零,也可能已经被展开到每个 token。
    token_level_scores = data.batch["token_level_scores"]
 
    # batch_size 用于更新自适应 KL controller。
    # controller 需要知道本次统计覆盖了多少条样本。
    batch_size = data.batch.batch_size[0]
 
    # old_log_probs 是 actor 对采样 response 的概率。
    # ref_log_prob 是 reference policy 对同一 response 的概率。
    kld = core_algos.kl_penalty(
        data.batch["old_log_probs"],
        data.batch["ref_log_prob"],
        kl_penalty=kl_penalty,
    )
 
    # 只保留 response token 的 KL,避免 prompt token 污染策略约束。
    kld = kld * response_mask
 
    # beta 是当前 KL 惩罚系数,可由 controller 动态调整。
    beta = kl_ctrl.value
 
    # reward 被扣掉 beta * KL。
    # 这让模型提高任务得分时仍受 reference 分布约束。
    token_level_rewards = token_level_scores - beta * kld
 
    # masked_mean 先对每条 response 求平均 KL,再对 batch 求平均。
    # 该指标用于日志和自适应调整 beta。
    current_kl = masked_mean(kld, mask=response_mask, axis=-1)
    current_kl = torch.mean(current_kl, dim=0).item()
 
    # KL controller 根据实际 KL 和目标 KL 调整惩罚强度。
    # KL 过高时 beta 增大,KL 过低时 beta 可以降低。
    kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)
 
    # 后续 advantage 计算读取 token_level_rewards,原始 scores 只保留为未加 KL 的任务分数。
    data.batch["token_level_rewards"] = token_level_rewards
 
    # 返回 metrics 供 logger 展示训练是否偏离 reference。
    metrics = {
        "actor/reward_kl_penalty": current_kl,
        "actor/reward_kl_penalty_coeff": beta,
    }
    return data, metrics

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
def compute_advantage(data, adv_estimator, gamma=1.0, lam=1.0, config=None):
    # response_mask 是所有 advantage estimator 的公共输入。
    # 如果前面没有显式计算,这里根据 attention_mask 补出来。
    if "response_mask" not in data.batch.keys():
        data.batch["response_mask"] = compute_response_mask(data)
 
    # GAE 需要 critic values,适合标准 PPO。
    # gamma 控制未来 reward 折扣,lam 控制 bias-variance tradeoff。
    if adv_estimator == AdvantageEstimator.GAE:
        advantages, returns = core_algos.compute_gae_advantage_return(
            token_level_rewards=data.batch["token_level_rewards"],
            values=data.batch["values"],
            response_mask=data.batch["response_mask"],
            gamma=gamma,
            lam=lam,
        )
        data.batch["advantages"] = advantages
        data.batch["returns"] = returns
 
    # GRPO 用同一 prompt 的多条 response 做组内归一化。
    # index 通常是 prompt uid,用来判断哪些 response 属于同一组。
    elif adv_estimator == AdvantageEstimator.GRPO:
        advantages, returns = core_algos.compute_grpo_outcome_advantage(
            token_level_rewards=data.batch["token_level_rewards"],
            response_mask=data.batch["response_mask"],
            index=data.non_tensor_batch["uid"],
            norm_adv_by_std_in_grpo=True,
        )
        data.batch["advantages"] = advantages
        data.batch["returns"] = returns
 
    else:
        # 其他 estimator 通过注册表分发,便于扩展 RLOO、REINFORCE++ 等算法。
        adv_estimator_fn = core_algos.get_adv_estimator_fn(adv_estimator)
 
        # 所有 estimator 至少需要 token reward、mask 和算法配置。
        adv_kwargs = {
            "token_level_rewards": data.batch["token_level_rewards"],
            "response_mask": data.batch["response_mask"],
            "config": config,
        }
 
        # uid 允许 estimator 做组内 baseline 或按 prompt 聚合。
        if "uid" in data.non_tensor_batch:
            adv_kwargs["index"] = data.non_tensor_batch["uid"]
 
        # reward_baselines 可来自数据集、reward model 或外部估计。
        if "reward_baselines" in data.batch:
            adv_kwargs["reward_baselines"] = data.batch["reward_baselines"]
 
        advantages, returns = adv_estimator_fn(**adv_kwargs)
        data.batch["advantages"] = advantages
        data.batch["returns"] = returns
 
    # 返回的 DataProto 已经带上 actor/critic update 所需字段。
    return data
RewardManager 与自定义奖励

verl 的自定义 reward 通过 Hydra 配置接入: custom_reward_function.path 指向 Python 文件, custom_reward_function.name 指向函数名。函数通常接收 data_source、solution_str、ground_truth 和 extra_info。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def compute_score(data_source, solution_str, ground_truth, extra_info=None):
    # data_source 用于区分不同数据集。
    # 同一个训练任务混合 GSM8K、MATH、代码题时,reward 逻辑通常不同。
    if data_source == "openai/gsm8k":
        return score_gsm8k(solution_str, ground_truth)
 
    # MATH 数据集常需要 LaTeX 归一化和更复杂的答案等价判断。
    if data_source == "lighteval/MATH":
        return score_math(solution_str, ground_truth)
 
    # extra_info 可以携带测试用例、rubric、样本难度或工具参数。
    # 代码任务通常会从 extra_info 中读取 hidden tests 或 sandbox 配置。
    if extra_info and extra_info.get("task_type") == "code":
        return run_sandbox_tests(solution_str, extra_info["tests"])
 
    # 未覆盖的数据源直接报错,避免静默给 0 分导致训练信号损坏。
    raise NotImplementedError(f"No reward function for data_source={data_source}")

Shell
1
2
3
4
5
6
7
# 指定自定义 reward 文件路径。
# 该文件会被 trainer 加载,函数必须在 worker 可访问的路径下。
custom_reward_function.path=/workspace/rewards/math_reward.py \
 
# 指定 reward 文件中的函数名。
# 如果函数就叫 compute_score,可以省略 name,使用默认入口。
custom_reward_function.name=compute_score
QuickStart 命令逐行解读

verl 的 QuickStart 使用 Hydra 覆盖参数。命令行里的 a.b.c=value 会覆盖配置树中的对应字段。

Shell
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# PYTHONUNBUFFERED=1 让日志实时输出,训练异常时不用等缓冲区刷新。
PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
 
  # 训练集和验证集使用 parquet,便于保存 prompt、answer、data_source、extra_info。
  data.train_files=$HOME/data/gsm8k/train.parquet \
  data.val_files=$HOME/data/gsm8k/test.parquet \
 
  # train_batch_size 是每次 PPO 外层迭代消费的 prompt 数。
  # 真实进入 actor update 的样本数还会乘以 rollout.n。
  data.train_batch_size=256 \
 
  # prompt 和 response 分开限制长度,便于控制 KV cache 和训练显存。
  data.max_prompt_length=512 \
  data.max_response_length=512 \
 
  # actor、rollout、reference 共用同一个模型路径。
  # actor 会训练,reference 通常冻结,rollout 用于高吞吐采样。
  actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \
 
  # actor 学习率控制 policy 更新幅度。
  # RL 阶段通常小于 SFT 学习率,降低 reward 噪声放大风险。
  actor_rollout_ref.actor.optim.lr=1e-6 \
 
  # PPO mini batch 控制一次 rollout batch 被切成多少策略更新子批次。
  actor_rollout_ref.actor.ppo_mini_batch_size=64 \
 
  # 每张 GPU 的 micro batch 控制显存峰值。
  # OOM 时优先降低它,再考虑总 batch 或模型规模。
  actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
 
  # rollout.name=vllm 使用 vLLM 作为生成后端。
  actor_rollout_ref.rollout.name=vllm \
 
  # log_prob micro batch 控制重新计算 response token 概率时的显存。
  actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
 
  # tensor_model_parallel_size 控制 vLLM 推理时一个模型切到几张 GPU。
  actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
 
  # vLLM KV cache 显存比例,值越大 rollout 能容纳更长 response 或更多并发。
  actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
 
  # reference policy 只计算 ref_log_prob,也需要 micro batch 控制显存。
  actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
 
  # critic 学习率通常可以高于 actor,因为 value 拟合不直接改变生成分布。
  critic.optim.lr=1e-5 \
 
  # critic.model.path 通常与 actor 初始模型一致,也可以使用更小 value model。
  critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \
 
  # critic micro batch 控制 value model 训练显存。
  critic.ppo_micro_batch_size_per_gpu=4 \
 
  # KL 系数限制 policy 偏离 reference。
  # 系数过大模型学不动,过小容易 reward hacking。
  algorithm.kl_ctrl.kl_coef=0.001 \
 
  # console logger 适合本地首跑;长期实验通常加 wandb。
  trainer.logger=console \
 
  # 跳过训练前验证可以更快暴露训练链路问题。
  trainer.val_before_train=False \
 
  # 单节点单卡配置;多卡时同步修改 n_gpus_per_node 和并行参数。
  trainer.n_gpus_per_node=1 \
  trainer.nnodes=1 \
 
  # checkpoint 和验证频率。
  # RL 训练波动大,早期建议更频繁保存和验证。
  trainer.save_freq=10 \
  trainer.test_freq=10 \
 
  # total_epochs 控制遍历数据集轮数,不等于 PPO update 总数。
  trainer.total_epochs=15 2>&1 | tee verl_demo.log
GRPO / FSDP / vLLM 配置逐行解读

verl 的真实训练脚本通常把 Hydra 参数分组为 DATA、MODEL、ACTOR、ROLLOUT、REF、TRAINER。这样比一条超长命令更可维护,也便于不同硬件分支复用。

Shell
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# 失败即退出,未定义变量即报错,管道中任何一步失败都会失败。
# 训练脚本应默认严格模式,避免静默跑出错误实验。
set -xeuo pipefail
 
# DEVICE 默认通过 torch_npu 探测 NPU,否则使用 GPU。
# 这让同一份脚本可以覆盖 NVIDIA GPU 和 Ascend NPU 环境。
DEVICE=${DEVICE:-$(python3 -c 'import torch_npu' 2>/dev/null && echo npu || echo gpu)}
 
# rollout backend 默认 vLLM,也可以改成 sglang 或 trtllm。
# 选择后端会影响可用参数、吞吐和硬件兼容性。
INFER_BACKEND=${INFER_BACKEND:-vllm}
 
# 模型路径可以是 Hugging Face repo,也可以是本地 checkpoint。
MODEL_PATH=${MODEL_PATH:-Qwen/Qwen3-8B}
 
# train_batch_size 是每轮外层训练消费的 prompt 数。
# GRPO 中每个 prompt 还会采样 rollout_n 条 response。
train_batch_size=${TRAIN_BATCH_SIZE:-1024}
 
# ppo_mini_batch_size 控制一次 rollout 产物切成多少策略更新批次。
ppo_mini_batch_size=${PPO_MINI_BATCH_SIZE:-256}
 
# max_response_length 对推理题很关键,过短会截断思考,过长会拖垮吞吐。
max_response_length=${MAX_RESPONSE_LENGTH:-2048}
 
# rollout_n 是每个 prompt 的采样条数。
# GRPO 依赖组内比较,rollout_n 太小会让 advantage 估计不稳定。
rollout_n=${ROLLOUT_N:-5}
 
DATA=(
    # 使用 GRPO advantage estimator,不需要单独训练 critic。
    algorithm.adv_estimator=grpo
 
    # KL 放在 actor loss 中处理,保留 reward 的原始任务语义。
    algorithm.use_kl_in_reward=False
 
    # 混合 GSM8K 和 MATH parquet,reward 需要根据 data_source 区分逻辑。
    data.train_files="['$HOME/data/gsm8k/train.parquet', '$HOME/data/math/train.parquet']"
    data.val_files="['$HOME/data/gsm8k/test.parquet', '$HOME/data/math/test.parquet']"
 
    # 控制 prompt batch,真实 response 数量约为 train_batch_size * rollout_n。
    data.train_batch_size=${train_batch_size}
 
    # prompt/response 长度上限同时影响 vLLM KV cache 和 FSDP 训练显存。
    data.max_prompt_length=${MAX_PROMPT_LENGTH:-1024}
    data.max_response_length=${max_response_length}
 
    # 过长 prompt 直接过滤,避免训练中途被截断破坏题意。
    data.filter_overlong_prompts=True
    data.truncation='error'
)
 
MODEL=(
    # actor、rollout、reference 的初始权重。
    actor_rollout_ref.model.path="$MODEL_PATH"
 
    # remove padding 能减少无效 token 计算,提高长短样本混合时的吞吐。
    actor_rollout_ref.model.use_remove_padding=True
 
    # gradient checkpointing 用计算换显存,适合 8B 以上模型和长 response。
    actor_rollout_ref.model.enable_gradient_checkpointing=True
)
 
ACTOR=(
    # actor 学习率决定 policy 更新幅度。
    actor_rollout_ref.actor.optim.lr=${ACTOR_LR:-1e-6}
 
    # GRPO/PPO update 的 mini batch 大小。
    actor_rollout_ref.actor.ppo_mini_batch_size=${ppo_mini_batch_size}
 
    # 动态 batch 按 token 数组织 micro batch,减少长短样本造成的显存波动。
    actor_rollout_ref.actor.use_dynamic_bsz=True
 
    # 每张 GPU 的最大训练 token 数,是防 OOM 的核心阈值。
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${PPO_MAX_TOKEN_LEN_PER_GPU:-24576}
 
    # KL loss 约束 actor 不要偏离 reference 过远。
    actor_rollout_ref.actor.use_kl_loss=True
    actor_rollout_ref.actor.kl_loss_coef=${KL_LOSS_COEF:-0.001}
    actor_rollout_ref.actor.kl_loss_type=low_var_kl
 
    # entropy 系数控制探索;推理任务常设为 0,依赖 rollout sampling 提供多样性。
    actor_rollout_ref.actor.entropy_coeff=${ENTROPY_COEFF:-0}
)
 
ROLLOUT=(
    # rollout 后端选择 vLLM/SGLang/TensorRT-LLM。
    actor_rollout_ref.rollout.name=${INFER_BACKEND}
 
    # rollout tensor parallel size 控制推理模型切分。
    actor_rollout_ref.rollout.tensor_model_parallel_size=${ROLLOUT_TP:-2}
 
    # vLLM 显存比例主要用于 KV cache。
    actor_rollout_ref.rollout.gpu_memory_utilization=${ROLLOUT_GPU_MEM_UTIL:-0.6}
 
    # 每个 prompt 生成多条 response,GRPO 用这些 response 做组内相对优势。
    actor_rollout_ref.rollout.n=${rollout_n}
 
    # log_prob 动态 batch 避免长 response 重新算概率时 OOM。
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${PPO_MAX_TOKEN_LEN_PER_GPU:-24576}
)
 
REF=(
    # reference 只算 ref_log_prob,动态 batch 同样可以降低显存尖峰。
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True
    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${PPO_MAX_TOKEN_LEN_PER_GPU:-24576}
 
    # reference 冻结后可 offload 参数,给 actor/rollout 腾显存。
    actor_rollout_ref.ref.fsdp_config.param_offload=True
)
 
TRAINER=(
    # balance_batch 按序列长度平衡 batch,减少某张 GPU 被长样本拖慢。
    trainer.balance_batch=True
 
    # 同时输出 console 和 wandb,适合长期训练留存曲线。
    trainer.logger='["console","wandb"]'
 
    # project_name 和 experiment_name 决定 checkpoint 与日志目录层级。
    trainer.project_name=${PROJECT_NAME:-verl_grpo_gsm8k_math}
    trainer.experiment_name=${EXPERIMENT_NAME:-qwen3_8b_grpo_vllm_fsdp}
 
    # 资源规模,Ray 会据此分配 worker。
    trainer.n_gpus_per_node=${NGPUS_PER_NODE:-8}
    trainer.nnodes=${NNODES:-1}
 
    # checkpoint 和验证频率。
    trainer.save_freq=${SAVE_FREQ:-20}
    trainer.test_freq=${TEST_FREQ:-5}
    trainer.total_epochs=${TOTAL_EPOCHS:-15}
)
 
# 数组展开可以保持每组参数独立维护,同时传给 Hydra 入口。
python3 -m verl.trainer.main_ppo \
    "${DATA[@]}" \
    "${MODEL[@]}" \
    "${ACTOR[@]}" \
    "${ROLLOUT[@]}" \
    "${REF[@]}" \
    "${TRAINER[@]}" \
    "$@"
Checkpoint、合并与评估

verl 默认把 checkpoint 保存到 checkpoints/${trainer.project_name}/${trainer.experiment_name}。FSDP 保存的是分片状态,部署前通常需要用 verl.model_merger 合并成 Hugging Face 目录。

Shell
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# local_dir 指向某个 global_step 下 actor 的 FSDP 分片 checkpoint。
# 这个目录用于恢复训练,不一定能直接被 transformers/vLLM 加载。
LOCAL_DIR=checkpoints/${trainer_project}/${trainer_experiment}/global_step_1/actor
 
# target_dir 是合并后的 Hugging Face 格式目录。
# 部署、离线评估和继续 SFT 通常使用这个目录。
TARGET_DIR=checkpoints/${trainer_project}/${trainer_experiment}/global_step_1/actor/huggingface
 
# backend=fsdp 告诉 merger 按 FSDP 分片格式读取权重。
# 如果训练使用 Megatron 后端,backend 需要切换到对应合并路径。
python3 -m verl.model_merger merge \
    --backend fsdp \
    --local_dir "$LOCAL_DIR" \
    --target_dir "$TARGET_DIR"

评估时应固定 decoding 配置,并同时看任务分数、KL、长度和格式错误率。RL checkpoint 之间不能只按 reward 选最优;reward hacking 会让训练分数优于人工质量。

main_ppo 与 TaskRunner 代码解读

verl 的训练入口由 Hydra 管配置,Ray 管分布式任务。 main_ppo.py 里的 driver 负责启动 Ray、构造 TaskRunner,并由 TaskRunner 组装 worker、resource pool 和 trainer;模型前向、反向和生成由远程 worker 执行。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None)
def main(config):
    # Hydra 把 YAML 和命令行覆盖合成 config。
    # 训练脚本中的 a.b.c=value 最终都会落到这个对象里。
    auto_set_device(config)
 
    # 兼容旧 reward 配置,统一迁移到当前 RewardManager 接口。
    # 这一步让老实验脚本不必一次性全部重写。
    config = migrate_legacy_reward_impl(config)
 
    # 真正启动 Ray 和 PPO 训练逻辑。
    run_ppo(config)
 
 
def run_ppo(config, task_runner_class=None) -> None:
    # 如果外部没有提前 ray.init,这里创建本地或集群 Ray runtime。
    if not ray.is_initialized():
        # 默认 runtime_env 会设置 tokenizer、NCCL、vLLM 等环境变量。
        # 这些变量必须随着 Ray worker 一起分发。
        default_runtime_env = get_ppo_ray_runtime_env()
 
        # 用户可以通过 config.ray_kwargs.ray_init 覆盖 Ray 初始化参数。
        ray_init_kwargs = config.ray_kwargs.get("ray_init", {})
        runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {})
 
        # transfer_queue 用于某些高吞吐数据传输路径。
        # 开启后需要把环境变量注入所有 Ray worker。
        if config.transfer_queue.enable:
            runtime_env_vars = runtime_env_kwargs.get("env_vars", {})
            runtime_env_vars["TRANSFER_QUEUE_ENABLE"] = "1"
            runtime_env_kwargs["env_vars"] = runtime_env_vars
 
        # OmegaConf.merge 让默认 runtime_env 与用户覆盖项合并。
        runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)
 
        # ray.init 接收普通 dict,这里把 OmegaConf 转回容器。
        ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env})
        ray.init(**OmegaConf.to_container(ray_init_kwargs))
 
    # TaskRunner 是单进程 driver actor。
    # 官方建议不要调度到 Ray head,因为它会持有配置、数据迭代器和控制状态。
    if task_runner_class is None:
        task_runner_class = ray.remote(num_cpus=1)(TaskRunner)
 
    # 创建远程 TaskRunner,然后调用 run(config)。
    # ray.get 阻塞直到整个 PPO job 完成或失败。
    runner = task_runner_class.remote()
    ray.get(runner.run.remote(config))
Role、Worker 与 ResourcePool 映射

verl 的角色映射决定“谁负责 actor/rollout/ref/critic/reward”,资源池映射决定“这些角色放到哪些 GPU 上”。理解这层映射后,FSDP、Megatron、vLLM 或 SGLang 的替换才不会混乱。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
class TaskRunner:
    def __init__(self):
        # role_worker_mapping 把逻辑角色映射到 Ray remote worker class。
        # 例如 ActorRolloutRef 使用 ActorRolloutRefWorker。
        self.role_worker_mapping = {}
 
        # mapping 把逻辑角色映射到资源池 id。
        # 资源池 id 再对应具体节点和 GPU 数量。
        self.mapping = {}
 
    def add_actor_rollout_worker(self, config):
        from verl.single_controller.ray import RayWorkerGroup
        from verl.trainer.ppo.ray_trainer import Role
        from verl.workers.engine_workers import ActorRolloutRefWorker
 
        # ActorRolloutRefWorker 是统一 worker。
        # 它可以只做 actor,也可以融合 actor + rollout + reference。
        actor_rollout_cls = ActorRolloutRefWorker
        ray_worker_group_cls = RayWorkerGroup
 
        # LoRA PPO 中 reference 往往可以由 base model + adapter 状态表示。
        # 因此需要根据 LoRA 配置判断 reference 是否融合进 actor worker。
        lora_rank = config.actor_rollout_ref.model.get("lora", {}).get("rank", 0)
        if lora_rank <= 0:
            lora_rank = config.actor_rollout_ref.model.get("lora_rank", 0)
 
        # lora_adapter_path 存在时,也说明 reference 与 actor 的关系不同于全量模型。
        ref_in_actor = lora_rank > 0 or config.actor_rollout_ref.model.get("lora_adapter_path") is not None
 
        # 需要 reference 且不能融合时,使用 ActorRolloutRef 角色。
        # 否则只注册 ActorRollout 角色,减少不必要的 worker 拆分。
        if need_reference_policy(config) and not ref_in_actor:
            role = Role.ActorRolloutRef
        else:
            role = Role.ActorRollout
 
        # ray.remote 把 worker class 转成 Ray actor class。
        self.role_worker_mapping[role] = ray.remote(actor_rollout_cls)
 
        # actor/rollout/ref 默认放到 global_pool。
        # 后续 ResourcePoolManager 会把 global_pool 映射到具体 GPU 列表。
        self.mapping[role] = "global_pool"
 
        # 返回 worker class 和 worker group class,供 trainer 初始化时使用。
        return actor_rollout_cls, ray_worker_group_cls
 
    def init_resource_pool_mgr(self, config):
        # global_pool 是默认资源池 id。
        global_pool_id = "global_pool"
 
        # 每个节点分配 n_gpus_per_node 张卡,共 nnodes 个节点。
        # 例如 2 节点 * 8 卡会得到 [8, 8]。
        resource_pool_spec = {
            global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
        }
 
        # reward model 可以使用独立资源池,避免占用 actor/rollout 的 GPU。
        if config.reward.reward_model.enable_resource_pool:
            reward_pool = [config.reward.reward_model.n_gpus_per_node] * config.reward.reward_model.nnodes
            resource_pool_spec["reward_pool"] = reward_pool
 
        from verl.trainer.ppo.ray_trainer import ResourcePoolManager
 
        # ResourcePoolManager 最终把 Role -> pool id -> GPU 拓扑串起来。
        return ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=self.mapping)
Reward 数据流代码解读

verl 的 RewardManager 从 DataProto 中取出 response、ground_truth、data_source,解码 response 后调用具体 reward function。这个设计支持混合数据集:同一个 batch 里不同 data_source 可以走不同评分函数。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
class RewardManager:
    def __init__(self, tokenizer, num_examine=0, compute_score_fn=None):
        # tokenizer 用于把 response token 解码成文本。
        # function reward 通常基于字符串解析答案或执行测试。
        self.tokenizer = tokenizer
 
        # num_examine 控制打印多少条样本用于人工检查 reward 是否合理。
        # 训练时不要打印过多,否则日志会成为瓶颈。
        self.num_examine = num_examine
 
        # compute_score_fn 是真正的业务 reward 函数。
        # 它可以按 data_source 分发到 GSM8K、MATH、代码沙箱等逻辑。
        self.compute_score_fn = compute_score_fn
 
    def __call__(self, data):
        # responses 是模型生成的 token,不包含 prompt 部分。
        responses = data.batch["responses"]
 
        # attention_mask / response_mask 用于确定有效 token。
        # padding 位置不能参与 reward 或 loss 聚合。
        response_mask = data.batch["response_mask"]
 
        # ground_truth 和 data_source 是非 tensor metadata。
        # 它们通常来自 parquet 文件中的列。
        ground_truth = data.non_tensor_batch["ground_truth"]
        data_source = data.non_tensor_batch["data_source"]
 
        # token_level_scores 的形状与 responses 对齐。
        # 很多规则奖励只在最后一个有效 token 写入分数。
        token_level_scores = torch.zeros_like(responses, dtype=torch.float32)
 
        for i in range(len(responses)):
            # valid_response_tokens 去掉 padding,只保留模型真实生成的 token。
            valid_response_tokens = responses[i][response_mask[i].bool()]
 
            # 解码 response 文本给 reward function 使用。
            # skip_special_tokens 能减少 EOS/PAD 对正则解析的干扰。
            solution_str = self.tokenizer.decode(valid_response_tokens, skip_special_tokens=True)
 
            # 调用业务 reward。
            # extra_info 可携带测试用例、rubric、难度、答案解析等字段。
            score = self.compute_score_fn(
                data_source=data_source[i],
                solution_str=solution_str,
                ground_truth=ground_truth[i],
                extra_info=None,
            )
 
            # 把序列级 reward 放到最后一个有效 token 上。
            # 这样 policy loss 可以沿 response_mask 聚合,同时保持 reward 稀疏语义。
            last_token_idx = response_mask[i].nonzero()[-1]
            token_level_scores[i, last_token_idx] = score
 
        # 返回 token-level score,后续 KL penalty 和 advantage 会继续处理它。
        return token_level_scores
模型合并后的评估模板

RL checkpoint 合并成 Hugging Face 格式后,应单独写评估脚本,固定模型路径、解码参数和数据切分。不要直接拿训练 reward 日志当最终结论。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
 
# 合并后的 HF 目录,来自 verl.model_merger 或 OpenRLHF --ckpt.save_hf。
model_dir = "checkpoints/project/run/global_step_100/actor/huggingface"
 
# tokenizer 必须与训练时一致,尤其是 chat template 和 special tokens。
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
 
# 加载 actor 权重用于离线评估。
# torch_dtype=bfloat16 与训练精度保持一致,减少数值分布差异。
model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
 
# 评估 prompt 应使用和训练一致的消息结构。
messages = [{"role": "user", "content": "Solve: 12 + 30 = ?"}]
 
# apply_chat_template 复用模型 tokenizer 中的官方模板。
# add_generation_prompt=True 表示把输入停在 assistant 应该开始回答的位置。
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
 
# return_tensors="pt" 生成 PyTorch tensor,随后移动到模型所在设备。
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
# 评估时固定 decoding 参数,保证不同 checkpoint 可比较。
# do_sample=False 对应 greedy decoding,适合先做确定性回归测试。
with torch.no_grad():
    output_ids = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
 
# 只解码新生成部分,避免 prompt 混入答案解析。
response_ids = output_ids[0, inputs["input_ids"].shape[1] :]
response = tokenizer.decode(response_ids, skip_special_tokens=True)
print(response)
verl 选型边界

verl 适合需要频繁改 RL 算法主循环、接入复杂 reward、试验 FSDP/Megatron/vLLM/SGLang 后端组合的团队。它的核心优势是控制流清晰,研究者可以在 driver 侧修改 rollout、advantage 和 update 顺序;底层模型并行仍由 worker 后端负责。代价是 Hydra 配置树较深,DataProto、WorkerGroup、Role、ResourcePoolManager 需要先建立概念模型。

DeepSpeed 详解

DeepSpeed 的工程接口由两部分构成:一部分是 launcher + 分布式初始化,负责把“单机脚本”变成“多进程多卡训练”;另一部分是 DeepSpeedEngine + JSON 配置,负责把显存分片(ZeRO)、offload、混合精度、梯度累积、checkpoint 等能力落在可复现的配置上。本节围绕安装、启动、 deepspeed.initialize、 ds_config.json、ZeRO Stage 1/2/3、offload、checkpoint,以及与 Transformers/Accelerate 的集成路径展开,重点给出配置与代码的对应关系。

安装与环境验证
基础安装

DeepSpeed 的最小安装路径是先安装 PyTorch,再安装 DeepSpeed。DeepSpeed 包含若干 C++/CUDA 扩展(ops),默认采用 JIT 方式在运行期编译加载,因此环境里通常需要可用的编译链与 ninja。

Shell
1
2
3
4
pip install deepspeed
 
# 可选:Transformers 侧一次性装好集成依赖
pip install "transformers[deepspeed]"
环境报告(ds_report)

安装完成后优先跑环境报告,确认“哪些 ops 可用、哪些会在运行时编译、CUDA/通信栈是否匹配”。这个步骤在排查安装或性能差异时比直接跑训练更高效。

Shell
1
2
3
4
ds_report
 
# 等价入口
python -m deepspeed.env_report
预编译 ops(可选)

默认 JIT 编译适合研发迭代;在固定镜像或需要减少“首次运行抖动”的场景里,可以在安装期预编译部分或全部 ops。DeepSpeed 提供一组 DS_BUILD_* 环境变量控制构建范围。

Shell
1
2
3
4
5
# 尝试构建所有 ops(只会构建与当前机器兼容的部分)
DS_BUILD_OPS=1 pip install deepspeed
 
# 只构建某一类 op(示例:FusedLamb)
DS_BUILD_FUSED_LAMB=1 pip install deepspeed

预编译全部 ops 可能耗时较长,可通过并行编译加速:

Shell
1
DS_BUILD_OPS=1 pip install deepspeed --global-option="build_ext" --global-option="-j8"
Launcher 与进程启动
单机多卡

DeepSpeed launcher 的默认约定是“一进程一 GPU”。launcher 会为脚本注入 --local_rank,脚本侧需要能解析这个参数并把当前进程绑定到对应 GPU。

Shell
1
2
# 单机 8 卡
deepspeed --num_gpus=8 train.py --deepspeed --deepspeed_config ds_config.json
多机

多机训练通常由 launcher 读取 hostfile(节点列表与每节点 slots),并在每个节点上拉起相同脚本。hostfile 格式依赖部署系统(裸机/Slurm/K8s),工程上常见做法是先让调度系统分配机器与 GPU,再由 DeepSpeed 或 torchrun 建立通信。

hostfile(示例)
1
2
node0 slots=8
node1 slots=8

Shell
1
deepspeed --hostfile=hostfile train.py --deepspeed --deepspeed_config ds_config.json
脚本侧参数解析

DeepSpeed 提供 deepspeed.add_config_arguments 把 --deepspeed 与 --deepspeed_config 等参数接入到自定义 argparse 中。

Python
1
2
3
4
5
6
7
8
import argparse
import deepspeed
parser = argparse.ArgumentParser()
# local_rank 由 launcher 注入;脚本保留这个参数是为了兼容常见多卡启动方式。
parser.add_argument("--local_rank", type=int, default=-1)
# add_config_arguments 会补上 --deepspeed / --deepspeed_config 等标准参数。
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args()
deepspeed.initialize 与训练循环
最小接入骨架

deepspeed.initialize 是训练入口:负责(必要时)初始化 torch distributed,并返回一个可直接用于 forward/backward/step 的 DeepSpeedEngine。配置文件里的 optimizer/scheduler/dataloader 也可以被 DeepSpeed 构造与管理。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import argparse
import torch
import deepspeed
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=-1)
    parser = deepspeed.add_config_arguments(parser)
    args = parser.parse_args()
    model = MyModel()
 
    # optimizer 既可以在 ds_config.json 里声明,也可以像这里一样由代码显式传入覆盖配置。
    model_engine, optimizer, _, lr_scheduler = deepspeed.initialize(
        args=args,
        model=model,
        model_parameters=model.parameters(),
    )
    for batch in train_loader:
        # engine 负责前向、ZeRO 状态管理和梯度累积;不要再手工调原始 optimizer。
        loss = model_engine(batch)
        model_engine.backward(loss)
        # 何时真正更新参数由 DeepSpeedEngine 按配置决定。
        model_engine.step()
if __name__ == "__main__":
    main()
manual backward:什么时候该用 engine.scale(loss)

标准路径当然是 model_engine.backward(loss)。但真实项目里经常会遇到一种情况:损失核心是你先做了额外组合、裁剪、蒸馏、或跨模型共享,再想手工调 loss.backward()。这时就不能直接把 DeepSpeedEngine 绕过去。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# initialize 会返回 DeepSpeedEngine。
# 训练脚本后面真正应该交互的是 engine,而不再是原始 model / optimizer。
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    # model_parameters 告诉 DeepSpeed:哪些参数需要建立优化器与 ZeRO 状态。
    model_parameters=model.parameters(),
    # config 决定混合精度、ZeRO、梯度累积等运行时策略。
    config=ds_config,
)
 
for batch in train_loader:
    # 这里直接调用 engine,而非原始 model。
    # 原因是 forward 期间 DeepSpeed 还要管理 ZeRO gather/repartition 与 mixed precision。
    loss = engine(batch)
 
    # 如果不用 engine.backward(loss),就要先显式 scale。
    scaled_loss = engine.scale(loss)
    scaled_loss.backward()  # backward 作用在已经过 loss scaling 的张量上,数值路径才完整。
 
    # 只有到真正的梯度累积边界,这一步才会触发权重更新。
    engine.step()  # 非边界步只会累积梯度;边界步才会执行 optimizer update。

这里的关键点核心是混合精度与 loss scaling 的责任边界。直接调用 loss.backward() 而不经过 engine.backward() 或 engine.scale(loss),会把 DeepSpeed 的数值路径直接绕开。

梯度累积边界:is_gradient_accumulation_boundary 的用途

DeepSpeed 的 step() 每轮都可以调用,但并非每轮都会真正更新参数。很多训练逻辑只应发生在“真实更新边界”上,例如 EMA、外部 scheduler、吞吐统计、checkpoint tag 递增,以及某些 callback。判断这个边界的标准接口就是 is_gradient_accumulation_boundary()。

Python
1
2
3
4
5
6
7
8
9
10
for step, batch in enumerate(train_loader):
    loss = engine(batch)
    engine.backward(loss)
 
    if engine.is_gradient_accumulation_boundary():
        # 这类逻辑只应在“真实参数更新”那一步执行。
        ema.update()
        step_counter += 1
 
    engine.step()

没有这层判断时,最常见的问题是日志步数、学习率步数和真实优化步数错位。表面上训练还在跑,实际上很多围绕“step”的外围系统都已经对不上了。

多模型共享 loss:蒸馏、RLHF 与协同训练的骨架

DeepSpeed 并不要求一个进程里只能有一个 engine。蒸馏、actor-critic、reward model 协同训练,常常会在同一轮里维护多个 engine,然后围绕一个共享 loss 做反向传播。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# teacher 与 student 分别维护自己的 DeepSpeedEngine。
# 它们可以有不同的 ZeRO stage、精度策略和 checkpoint 目录。
teacher_engine, _, _, _ = deepspeed.initialize(
    model=teacher,
    model_parameters=teacher.parameters(),
    config=teacher_ds_config,
)
student_engine, _, _, _ = deepspeed.initialize(
    model=student,
    model_parameters=student.parameters(),
    config=student_ds_config,
)
 
for batch in train_loader:
    with torch.no_grad():
        # teacher 只负责给出目标分布,因此放在 no_grad 里,避免无意义的显存与反向开销。
        teacher_logits = teacher_engine(batch["input_ids"])
 
    student_logits = student_engine(batch["input_ids"])  # student 才是要更新的对象。
    loss = distill_loss(student_logits, teacher_logits)  # 上层任务 loss 由两边输出共同定义。
 
    # 共享 loss 由各自 engine 负责回传到各自参数分片。
    student_engine.backward(loss)
    student_engine.step()

一旦进入这种多 engine 场景,训练脚本就不能再把“模型对象”“优化器对象”“checkpoint 目录”混为一谈。每个 engine 都有自己的 ZeRO 状态、优化器与恢复语义,而共享的只是上层任务目标。

分布式初始化的边界

当脚本里已经显式调用了 torch.distributed.init_process_group,DeepSpeed 侧应改为 deepspeed.init_distributed 或直接移除显式初始化,让 deepspeed.initialize 自动完成分布式初始化。多重初始化是常见的 hang 根源。

配置与代码的覆盖规则

DeepSpeed 的核心覆盖规则是:配置文件定义默认行为,显式传入的 Python 对象覆盖配置。例如,当在 deepspeed.initialize 里传入 optimizer 时,会覆盖 ds_config.json 里 optimizer 段落的定义。

ds_config.json:把“训练策略”编码进配置
批大小三件套与推导关系

DeepSpeed 将 batch size 拆为三项参数:有效 batch( train_batch_size)、每卡 micro-batch( train_micro_batch_size_per_gpu)、梯度累积步数( gradient_accumulation_steps)。三者满足:

\[B = b \\times g \\times N\]

其中 \(B\) 对应 train_batch_size,\(b\) 对应 train_micro_batch_size_per_gpu,\(g\) 对应 gradient_accumulation_steps,\(N\) 是参与训练的 GPU 数量(即 world size)。

工程上通常只显式指定其中两个,剩下一个由 DeepSpeed 推导;这样可以减少多机扩容时的人工改动。

一份可跑的最小配置
ds_config_min.json
JSON
1
2
3
4
5
6
{
  "train_micro_batch_size_per_gpu": 1,
  "gradient_accumulation_steps": 8,
  "fp16": { "enabled": true },
  "zero_optimization": { "stage": 2 }
}
配置-代码对照表
配置项 DeepSpeed 行为 代码侧需要做什么
train_micro_batch_size_per_gpu 定义每次 forward/backward 的 micro-batch 大小 DataLoader 提供的 batch 必须与该值一致,或让上层框架(Trainer/Accelerate)保持一致
gradient_accumulation_steps 定义多少个 micro-step 后做一次参数更新 训练循环仍按 “每个 batch 一次 backward”,DeepSpeedEngine 内部按配置决定何时 step
fp16.enabled / bf16.enabled 启用混合精度与 loss scaling(若需要) 脚本不再手写 AMP 也能跑通;若与外部 AMP 同时启用,需明确由谁负责 autocast/scaler
optimizer DeepSpeed 构造优化器(可选) 如果在 deepspeed.initialize 显式传入 optimizer,则会覆盖该段配置
scheduler DeepSpeed 构造并在每步自动 step(可选) 当 scheduler 由 DeepSpeed 管理时,脚本不应额外调用 scheduler.step()
ZeRO:Stage 1/2/3 与 Offload

ZeRO 的全称是 Zero Redundancy Optimizer,中文可译作“零冗余优化器”。名字里的 redundancy 指的是数据并行里的重复状态:每张 GPU 都拿到完整参数、完整梯度、完整优化器状态。ZeRO 并没有改变模型的数学结构,也没有改变损失函数;它改变的是训练状态在不同 GPU 上的存放方式。

设模型参数量为 \(P\),data-parallel world size 为 \(N\)。如果按 bf16/fp16 参数、bf16/fp16 梯度、AdamW 的 FP32 一阶动量 \(m\)、FP32 二阶动量 \(v\)、FP32 master weights 粗略估算,单卡仅模型状态就可能接近:

\[\mathrm{memory}_{\mathrm{DP}}\approx P\cdot(2+2+4+4+4)\ \mathrm{bytes}=16P\ \mathrm{bytes}\]

其中 \(2\) bytes 来自 bf16/fp16 参数,另一个 \(2\) bytes 来自梯度,三个 \(4\) bytes 分别来自 AdamW 的 \(m\)、\(v\) 和 FP32 master weights。这个估算还没有算激活、临时 buffer、通信 bucket、KV cache 或框架额外开销。

ZeRO 的分片收益可以粗略理解为:

\[\mathrm{memory}_{\mathrm{ZeRO\text{-}1}}\approx P\cdot(2+2)+\frac{P\cdot(4+4+4)}{N}\]

Stage 1 只把优化器状态切到 \(N\) 张卡上;参数和梯度仍然每卡完整保存。

\[\mathrm{memory}_{\mathrm{ZeRO\text{-}2}}\approx P\cdot2+\frac{P\cdot(2+4+4+4)}{N}\]

Stage 2 再把梯度也切开;参数仍然每卡完整保存。

\[\mathrm{memory}_{\mathrm{ZeRO\text{-}3}}\approx \frac{P\cdot(2+2+4+4+4)}{N}\]

Stage 3 连参数也分片,单卡模型状态占用最低。代价是每层计算前需要把当前层参数 all-gather 到可计算形态,计算后再释放或重新分片。显存节省来自“少存重复状态”,额外成本来自“更多通信与更复杂的参数生命周期管理”。

把它类比成多人搬书更直观。普通数据并行像每个人都背着整套书,再一起读同一章节;ZeRO-1 先把笔记本和索引卡分给不同人保管;ZeRO-2 再把每章批注也分开保管;ZeRO-3 连书页本身也分开保管,读到某一章时临时把相关页面凑齐,读完再分回去。越往后越省背包空间,但传递页面的协调成本越高。

问题 优先尝试 原因
模型能加载,训练时 optimizer state 顶爆显存 ZeRO Stage 1 / Stage 2 AdamW 状态和梯度通常是第一波显存大头,Stage 1/2 的收益直接且通信代价相对温和。
模型参数本身太大,完整权重难以放进单卡 ZeRO Stage 3 或 PyTorch FSDP 必须切参数本身,Stage 1/2 已经不够。
ZeRO-3 仍然放不下,GPU 显存极端紧张 CPU / NVMe offload 把部分参数或优化器状态搬到主存/磁盘,牺牲带宽换容量。
初始化模型时就 OOM deepspeed.zero.Init 让模型构造阶段就按 ZeRO-3 语义分片,避免先完整创建再切分。
zero.Init:超大模型先分片再初始化

当模型大到“连在单卡或单进程里完整实例化一次都做不到”时,光靠训练阶段的 ZeRO-3 已经不够,因为程序会先死在 Python 对象创建与参数分配这一步。 deepspeed.zero.Init 的作用,就是在模型构造阶段就按 ZeRO-3 语义分片参数,把“初始化时的峰值内存”也压下来。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import deepspeed
 
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "offload_param": {"device": "cpu", "pin_memory": True},
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
    },
}
 
with deepspeed.zero.Init(
    # 直接复用训练态配置,让初始化与正式训练沿用同一套 ZeRO 语义。
    config_dict_or_path=ds_config,
    remote_device="cpu",   # 先在 CPU 侧构造,再按 ZeRO-3 规则搬运/分片
    enabled=True,          # 显式打开 zero.Init;便于按条件分支决定是否启用
):
    # 这里的模型参数不会先完整落到单卡显存里再分片。
    model = MyHugeTransformer(...)
 
# 模型对象创建完成后,再进入正常的 DeepSpeed initialize。
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    # 把真正需要训练的参数列表交给 DeepSpeed 建优化器与 ZeRO 状态。
    model_parameters=model.parameters(),
    config=ds_config,
)

它解决的是“模型创建阶段的峰值内存”,并非训练吞吐本身。工程上只有在模型规模真的碰到初始化内存墙时才需要上这一层;普通 7B/13B 微调不必默认引入。

Stage 1:分片优化器状态

ZeRO Stage 1 将优化器状态(例如 Adam 的一阶/二阶动量与 FP32 master 权重)在 data-parallel ranks 之间分片,降低“优化器状态显存/内存”的重复开销。对模型参数量中等但 optimizer state 占用成为瓶颈的训练很直接。

ds_zero_stage1.json
JSON
1
2
3
4
5
6
{
  "train_micro_batch_size_per_gpu": 1,
  "gradient_accumulation_steps": 8,
  "bf16": { "enabled": true },
  "zero_optimization": { "stage": 1 }
}
Stage 2:再分片梯度

ZeRO Stage 2 在 Stage 1 的基础上将梯度也分片。它通常在“模型能放下,但训练状态占用过高”或“希望进一步扩大 batch/seq_len”时成为默认选择;相较 Stage 3,它在通信与实现复杂度上更温和。

ds_zero_stage2.json
JSON
1
2
3
4
5
6
7
8
9
10
{
  "train_micro_batch_size_per_gpu": 1,
  "gradient_accumulation_steps": 8,
  "fp16": { "enabled": true },
  "zero_optimization": {
    "stage": 2,
    "overlap_comm": true,
    "contiguous_gradients": true
  }
}
Stage 3:再分片参数(最省显存)

ZeRO Stage 3 进一步把模型参数也分片,使得“单卡显存”主要由激活与少量 shard 状态构成,从而把可训练模型尺度推到显存上限之外。代价是更多通信与参数聚合/分片的复杂性,checkpoint 与恢复也会更敏感。

ds_zero_stage3.json
JSON
1
2
3
4
5
6
7
8
9
10
{
  "train_micro_batch_size_per_gpu": 1,
  "gradient_accumulation_steps": 8,
  "bf16": { "enabled": true },
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "contiguous_gradients": true
  }
}
Offload:把状态搬到 CPU/NVMe

offload 的目标是继续压缩 GPU 显存占用。常见组合是:ZeRO-2 offload optimizer states(CPU)与 ZeRO-3 offload params/optimizer(CPU 或 NVMe)。offload 会把瓶颈从显存转移到带宽与延迟,因此通常需要配合更细的 micro-batch、更高的梯度累积与更强的通信/计算重叠。

ds_zero3_cpu_offload.json
JSON
1
2
3
4
5
6
7
8
9
10
{
  "train_micro_batch_size_per_gpu": 1,
  "gradient_accumulation_steps": 8,
  "fp16": { "enabled": true },
  "zero_optimization": {
    "stage": 3,
    "offload_param": { "device": "cpu", "pin_memory": true },
    "offload_optimizer": { "device": "cpu", "pin_memory": true }
  }
}
ZeRO-3 动态模块:leaf modules 的必要性

ZeRO-3 会在 forward/backward 期间自动 gather 与 repartition 参数。对普通静态 Transformer 结构,这条路很顺;但对 MoE、动态路由器或不同 rank 可能走到不同子模块的网络,自动 gather 就可能在不同 rank 上走出不同分支,最后直接演变为 hang 或 all-gather 不一致。DeepSpeed 给这类模块准备了 leaf module 机制。

Python
1
2
3
4
5
6
7
8
9
10
from deepspeed.utils import set_z3_leaf_modules_by_suffix
 
# 把动态 expert 模块标成 leaf,告诉 ZeRO-3:到这里就当成一个整体 gather。
set_z3_leaf_modules_by_suffix(model, ["experts"])
 
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)

把模块标为 leaf 的含义,核心是“ZeRO-3 在这里停止继续向内递归协调参数收集”。对动态专家层而言,这往往是避免不同 rank 走出不同 gather 路径的关键。

GatheredParameters 与外部参数访问

ZeRO-3 下,参数默认处于分片态。只要你打算在模块外部直接读一个参数,或把某个参数借给别的模块 forward 使用,就必须显式告诉 DeepSpeed 你要把它临时 gather 出来。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
import deepspeed
from deepspeed.zero import GatheredParameters
 
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)
 
# 在 owner module 之外读取 lm_head.weight,就要显式 gather。
with GatheredParameters([engine.module.lm_head.weight], modifier_rank=0):
    if engine.global_rank == 0:
        snapshot = engine.module.lm_head.weight.detach().cpu().clone()

如果某个参数会在别的模块 forward 中被外部引用,还需要考虑 register_external_parameter() 这类接口。工程语义很简单:让 DeepSpeed 知道“这个参数虽属于 A 模块,但 B 模块 forward 也会碰它”。否则参数分片生命周期和实际访问路径会脱节。

safe_get / safe_set:分片状态的调试与修复接口

ZeRO 把参数、梯度和优化器状态分散到不同 rank 之后,普通的 param.grad、 state_dict() 式直觉就不再可靠。DeepSpeed 为此提供了 safe_get_full_grad、 safe_get_full_fp32_param、 safe_set_full_grad 等调试接口,用于在正确的阶段把分片状态安全收拢。

这些接口的工程价值主要出现在三类场景:排查某一层梯度是否真的在更新;在 ZeRO-3 下做权重修补或规则化操作;以及把大模型训练中的数值异常定位到具体参数张量。如果还沿用普通单卡时代的“直接 print param.grad”,你看到的常常只是一个局部 shard,而非完整状态。

运行时显存接口:empty_partition_cache 与 offload_states

大模型训练越来越常见的一个模式,是训练、评估、生成、蒸馏交替发生。同一进程既要做训练态 ZeRO,又要临时切进生成态或检查点转换。DeepSpeed 这时提供的核心是一些很朴素但非常关键的运行时接口:

  • empty_partition_cache() 用来释放 ZeRO 在分片过程中缓存的一些参数副本。
  • offload_states(...) 与 reload_states(...) 用来在 CPU/GPU 之间搬运优化器状态或参数状态,为临时生成窗口腾显存。

它们不会把一个本来放不下的训练 magically 变成能放下,但在“训练与生成共进程”这类高压场景里,经常能决定一个流程是稳定切换,还是偶发性 OOM。

ZeRO 高频配置项:哪些旋钮最常真正去调

真正的 ZeRO 调优很少只改 stage。更高频的是下面这些和通信形态、梯度布局、offload 粒度直接相关的参数。它们通常出现在官方教程、OpenRLHF/DeepSpeedExamples 以及大模型训练配置里。

命令/API/函数
overlap_comm

说明
让部分通信与计算重叠,减少纯等待时间。对通信占比高的多卡训练很常见。

示例

JSON
1
2
3
4
5
6
{
  "zero_optimization": {
    "stage": 2,
    "overlap_comm": true
  }
}

命令/API/函数
contiguous_gradients

说明
把梯度布局整理得更连续,减少碎片化与部分通信/拷贝开销。通常和 Stage 2/3 一起出现。

示例

JSON
1
2
3
4
5
6
{
  "zero_optimization": {
    "stage": 2,
    "contiguous_gradients": true
  }
}

命令/API/函数
reduce_scatter / allgather_bucket_size / reduce_bucket_size

说明
控制 ZeRO 通信的聚合方式与 bucket 粒度。它们共同决定“通信开始得多早”“每次通信包有多大”。

示例

JSON
1
2
3
4
5
6
7
8
{
  "zero_optimization": {
    "stage": 2,
    "reduce_scatter": true,
    "allgather_bucket_size": 5e8,
    "reduce_bucket_size": 5e8
  }
}

命令/API/函数
offload_param / offload_optimizer

说明
把参数或优化器状态搬到 CPU/NVMe。它换来更低显存占用,也把瓶颈转移到 PCIe、内存或磁盘带宽。

示例

JSON
1
2
3
4
5
6
7
{
  "zero_optimization": {
    "stage": 3,
    "offload_param": {"device": "cpu", "pin_memory": true},
    "offload_optimizer": {"device": "cpu", "pin_memory": true}
  }
}

调这些参数时,判断标准包括单步速度、峰值显存、吞吐稳定性、step time 抖动,以及 checkpoint 保存/恢复是否开始变脆弱。

ZeRO 配置的第二层语义:粒度、通信 dtype 与 offload 管线

真实项目里,决定成败的往往是下面这些“看起来像小旋钮,实际定义运行时行为”的配置:

配置项 它控制什么 什么时候需要认真看
stage3_module_granularity_threshold ZeRO-3 按模块做 gather/repartition 时的粒度阈值 模块层级复杂、host 开销高,或动态模块很多时
communication_data_type 通信路径采用什么 dtype 传输梯度/参数 多机多卡上 fp16/bf16 数值稳定性与通信带宽要一起平衡时
gradient_predivide_factor all-reduce 前的梯度预除因子 大规模并行训练里需要缓和梯度归约数值路径时
offload_param.buffer_count / buffer_size 参数 offload 时的缓冲池数量与块大小 CPU/NVMe offload 已经启用,但 GPU 在等 IO 或 host 内存抖动明显时
offload_optimizer.pipeline_read / pipeline_write 优化器状态的读写是否做流水化 NVMe offload 已经成主路径,希望减少读写阻塞时
aio.block_size / queue_depth / thread_count 异步 IO 管线的提交粒度与并发深度 NVMe offload 变成主要瓶颈,且你已经确认并非模型本身算力吃满时

这类参数不建议在“训练第一天”就满天飞地改。更稳的顺序是:先把 stage、micro-batch、梯度累积和混合精度跑稳,再针对真实瓶颈决定是去调通信、调 gather 粒度,还是调 offload 管线。

Checkpoint:保存、恢复与导出
Engine 级保存/恢复

DeepSpeedEngine 提供 save_checkpoint / load_checkpoint,用于保存与恢复模型、优化器、scheduler 以及自定义 client_state。工程要点是:所有 ranks 都必须调用 save_checkpoint,否则会在同步点 hang。ZeRO-3 下,保存后立刻在同一 engine 上 load(不重新初始化)是已知的不兼容用法。

Python
1
2
3
4
5
6
# 保存(所有进程都会参与)
model_engine.save_checkpoint("ckpt_dir", tag=f"global_step{global_step}", client_state={"step": global_step})
 
# 恢复(通常在初始化后尽早执行)
load_path, client_state = model_engine.load_checkpoint("ckpt_dir", tag=None)
global_step = client_state.get("step", 0)
ZeRO-3 恢复约束:先重建 engine,再 load

ZeRO-3 下,参数本来就是分片状态,因此 checkpoint 恢复并非“随手把文件再读回来”这么简单。工程上更稳的顺序是:

  1. 重新构造模型对象。
  2. 重新调用 deepspeed.initialize 得到新的 engine。
  3. 在这个新 engine 上调用 load_checkpoint。

不要把“刚 save 完的老 engine”直接拿来马上 load 同一路径,尤其在 ZeRO-3 下,这会把参数分片与内部状态管理搞得非常脆弱。恢复语义应被理解为“用 checkpoint 重新装配一套训练状态”,而非“对当前 engine 就地回滚”。

ZeRO checkpoint 权重导出(fp32 合并)

ZeRO-2/3 的 checkpoint 是分片形态。需要“脱离 DeepSpeed 继续使用/分享权重”时,常用做法是把 ZeRO checkpoint 转为合并后的 fp32 state_dict。

Python
1
2
3
4
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
 
fp32_state_dict = get_fp32_state_dict_from_zero_checkpoint("ckpt_dir", tag=None)
torch.save(fp32_state_dict, "pytorch_model_fp32.bin")
Universal Checkpoint:把分片 checkpoint 变成可迁移制品

当训练拓扑、并行策略或下游恢复环境会变化时,只导出一份 fp32 权重往往不够,因为你还可能需要恢复优化器状态、调度器状态以及更完整的训练上下文。DeepSpeed 近年的 Universal Checkpoint 路线,目标就是把原本强依赖当前 ZeRO/并行拓扑的 checkpoint 转成更容易跨环境迁移的格式。

Shell
1
2
3
python -m deepspeed.checkpoint.ds_to_universal \
  --input_folder /path/to/ds_ckpt/global_step1000 \
  --output_folder /path/to/universal_ckpt/global_step1000_uni

它更像“交付格式转换”而非训练时的主存储格式:训练阶段继续保存原生 DeepSpeed checkpoint,真正需要迁移、共享或给别的恢复流程消费时,再做 Universal 转换会更稳。

latest、latest_universal 与 load_universal 的关系

这组概念如果不分清,恢复逻辑几乎一定会写错。

  • 普通 save_checkpoint(..., save_latest=True) 写出的“最新 tag 指针”是 latest。
  • 启用 Universal Checkpoint 恢复链路后,DeepSpeed 会去看 latest_universal。这个文件通常来自转换流程,而非普通保存时自动生成。
  • checkpoint.load_universal=true 的含义,是“恢复时按 universal 语义查目录与 tag”,并非“保存时自动帮你多产一份 universal”。

因此,训练主路径通常仍保存原生 DeepSpeed checkpoint;真正要跨拓扑迁移、跨环境恢复,才做 ds_to_universal 转换,并补上 latest_universal 这一层索引。

16-bit 导出与 tag 校验的两个常见坑
  • 在 ZeRO-3 下, save_16bit_model() 只有在相应 gather 保存开关打开时,才有机会产出真正可用的 16-bit 单体权重。否则你以为拿到了导出,实际只得到不完整状态。
  • checkpoint.tag_validation 决定 DeepSpeed 在各 rank 的 checkpoint tag 不一致时是忽略、告警还是直接失败。多阶段脚本、手工拼装 tag、或并行保存逻辑复杂时,建议把这层校验看成“帮你提前暴露一致性错误”的安全带,而非烦人的额外检查。
与 Transformers 的集成
Trainer / TrainingArguments 接入

Transformers 的 Trainer 通过 TrainingArguments.deepspeed(或 CLI 的 --deepspeed)接入 DeepSpeed。工程上更稳定的做法是:把与 Trainer 重复的值在 ds_config 中写成 "auto",由 Trainer 统一灌入,避免“两边都写但不一致”。

ds_hf_auto.json
JSON
1
2
3
4
5
6
7
8
9
10
{
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto",
  "optimizer": {
    "type": "AdamW",
    "params": { "lr": "auto" }
  },
  "fp16": { "enabled": "auto" },
  "zero_optimization": { "stage": 2 }
}

Python
1
2
3
4
5
6
7
8
9
from transformers import TrainingArguments
args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=2e-5,
    fp16=True,
    deepspeed="ds_hf_auto.json",
)
与 Accelerate 的集成
用 DeepSpeedPlugin(代码内指定)

Accelerate 提供 DeepSpeedPlugin,把 ZeRO stage、梯度累积等关键项绑定到 Accelerator 的生命周期里。工程要点是:DeepSpeed 需要提前知道 gradient_accumulation_steps,因此插件与训练循环要对齐,梯度累积本身仍需要按常规方式在代码里实现。

Python
1
2
3
4
from accelerate import Accelerator, DeepSpeedPlugin
 
ds_plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=2)
accelerator = Accelerator(deepspeed_plugin=ds_plugin)
用 DeepSpeed 配置文件(更强的可控性)

当需要 ZeRO-3/offload/更细的 ZeRO knobs 时,Accelerate 通常通过“指定 DeepSpeed 配置文件”的方式接入。此时 ds_config.json 才是事实来源,代码侧只保留必要的 accelerator.prepare 与 accelerator.backward 语义,避免重复配置。

Shell
1
2
3
# 通过 accelerate config 生成运行配置后,再用 accelerate launch 运行训练脚本
accelerate config
accelerate launch train.py
RLHF 脚本链路:SFT / RM / PPO

DeepSpeed 在 RLHF 系统里很少单独出现,它通常与“多阶段脚本 + 多角色资源编排”一起工作。以 OpenRLHF / DeepSpeedExamples 这类工程为代表,最常见的是三段式链路:先做监督微调(SFT),再训练奖励模型(RM),最后进入 PPO 或近似 PPO 的在线策略优化。

阶段 1:SFT
Shell
1
2
3
4
5
6
7
deepspeed --module openrlhf.cli.train_sft \
  --model.model_name_or_path meta-llama/Meta-Llama-3-8B \
  --train.batch_size 256 \
  --train.micro_batch_size 2 \
  --ds.zero_stage 2 \
  --ds.param_dtype bf16 \
  --model.gradient_checkpointing_enable

SFT 阶段的目标是得到一个“能听懂指令、输出格式已基本对齐”的初始策略模型。这里的 DeepSpeed 角色主要是把单卡难以承受的 batch/序列长度压回可训练范围。

阶段 2:奖励模型(RM)
Shell
1
2
3
4
5
6
7
8
deepspeed --module openrlhf.cli.train_rm \
  --model.model_name_or_path OpenRLHF/Llama-3-8b-sft-mixture \
  --train.batch_size 256 \
  --train.micro_batch_size 2 \
  --ds.zero_stage 3 \
  --ds.param_dtype bf16 \
  --ds.packing_samples \
  --model.gradient_checkpointing_enable

RM 阶段的关键是打分。它要求 reward model 的 checkpoint 之后能被 PPO 阶段稳定加载,因此“checkpoint 目录结构、ZeRO stage、是否做 Universal 转换”最好在这一阶段就固定下来。

阶段 3:PPO / 在线策略优化
Shell
1
2
3
4
5
6
7
8
9
10
11
python3 -m openrlhf.cli.train_ppo_ray \
  --actor.num_gpus_per_node 8 \
  --critic.num_gpus_per_node 8 \
  --ref.num_gpus_per_node 8 \
  --reward.num_gpus_per_node 8 \
  --vllm.num_engines 4 \
  --vllm.tensor_parallel_size 2 \
  --train.colocate_all \
  --ds.zero_stage 3 \
  --ds.packing_samples \
  --train.dynamic_batch_enable

PPO 阶段里,DeepSpeed 已经从“包一个模型训练”扩展到 actor、critic、reference policy、reward model 这些角色各自带着自己的 ZeRO/显存策略运行,再通过 Ray 与 vLLM 协同。这就是为什么在线 RLHF 系统的复杂度远高于普通 SFT:训练本体、生成服务和奖励打分已经是三类不同的运行时。

PPO checkpoint 的后处理

PPO 往往会同时产出 actor 与 critic 的 DeepSpeed checkpoint 目录。若后续要跨环境迁移、给别的脚本恢复或归档,实践里常把 ZeRO checkpoint 进一步转换为 Universal 格式,而非把分片目录原样交给别的系统猜。

Shell
1
2
# 这类脚本通常会同时处理 actor / critic 两棵目录。
bash examples/scripts/ckpt_ds_zero_to_universal.sh /path/to/ppo_ckpt_root
vLLM 详解

vLLM 是面向服务化推理的运行时:围绕高吞吐调度、KV cache 管理、continuous batching、分布式并行与 OpenAI-compatible API 提供一体化推理栈。工程落地时可以把 vLLM 当作三条“入口路径”:离线推理的 LLM,可嵌入自建服务的 Engine,以及直接上线的 vllm serve。

安装路径与环境兼容

vLLM 的 wheel 包含大量编译好的 C++/GPU kernels。性能与兼容性高度依赖“vLLM wheel、PyTorch、驱动/运行时”三者的组合,工程上优先使用官方提供的预构建 wheel 或官方 Docker 镜像。

GPU 安装(推荐路径)

官方文档建议在新环境中安装 vLLM,并优先使用 wheel 自带的 PyTorch/依赖组合以减少二进制不兼容问题;此外,conda 安装的 PyTorch 可能静态链接 NCCL,容易在分布式/多进程场景引发问题。

Install vLLM (GPU): create env & install (pattern)
Shell
1
2
3
4
5
6
# create a clean env (example with uv)
uv venv --python 3.12 --seed
source .venv/bin/activate
 
# install vLLM
uv pip install vllm

基本自检可以用“导入 + 小模型离线生成”验证:

Sanity check: offline generate
Python
1
2
3
4
5
6
7
from vllm import LLM, SamplingParams
 
# enforce_eager=True 牺牲一部分吞吐,换取“先确认环境能跑通”的更稳定起点。
llm = LLM(model="facebook/opt-125m", enforce_eager=True)
params = SamplingParams(max_tokens=16, temperature=0.0)
out = llm.generate(["Hello, my name is"], params)
print(out[0].outputs[0].text)
Docker 安装(生产最常用)

生产系统更常直接使用官方镜像运行 OpenAI-compatible server。多进程与张量并行依赖共享内存,容器启动通常需要 --ipc=host 或显式配置 --shm-size。

vLLM official Docker image: run OpenAI-compatible server
Shell
1
2
3
4
5
6
7
8
9
10
11
# 把受限模型访问令牌注入容器环境,避免服务首次拉权重时出现 401。
export HF_TOKEN="<secret>"
 
# 共享 Hugging Face 缓存目录,减少容器重启后的重复下载。
docker run --gpus all \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HF_TOKEN=$HF_TOKEN" \
  -p 8000:8000 \
  --ipc=host \
  vllm/vllm-openai:latest \
  --model Qwen/Qwen3-0.6B

当宿主机驱动较旧时,官方镜像提供 CUDA compatibility 模式(只覆盖部分专业/数据中心 GPU 的兼容场景):

Docker: enable CUDA compatibility libraries (pattern)
Shell
1
2
3
4
5
6
7
8
9
# 只在驱动偏旧、且官方文档明确支持兼容库的机器上启用这条路径。
docker run --gpus all \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  -p 8000:8000 \
  --env "HF_TOKEN=$HF_TOKEN" \
  --env "VLLM_ENABLE_CUDA_COMPATIBILITY=1" \
  --ipc=host \
  vllm/vllm-openai:latest \
  --model Qwen/Qwen3-0.6B
何时需要从源码构建

当你的 CUDA/ROCm 版本、PyTorch 构建配置或硬件平台与官方 wheel 不匹配时,需要从源码构建。官方提供了以 VLLM_USE_PRECOMPILED=1 作为起点的可编辑安装,以及基于 CMake 的增量编译工作流用于迭代 kernels。

Build-from-source (editable) + incremental build toolchain (pattern)
Shell
1
2
3
4
5
6
7
8
9
10
11
git clone https://github.com/vllm-project/vllm.git
cd vllm
 
uv venv --python 3.12 --seed
source .venv/bin/activate
 
# 先做 editable 安装,让 Python 侧改动可以直接生效。
VLLM_USE_PRECOMPILED=1 uv pip install -U -e . --torch-backend=auto
 
# 只有需要改 C++/CUDA kernel 时才补这组构建依赖。
uv pip install -r requirements/build.txt --torch-backend=auto
三条接口:LLM / Engine / vllm serve

vLLM 的 API 结构可以用“离线批处理推理”“可嵌入的引擎”“生产服务端”三条路径来理解。三者底层共享同一套 Engine 配置(EngineArgs),差异在于请求进入方式与生命周期管理。

接口 1:LLM(离线批处理推理)

vllm.LLM 适合离线批处理与数据集推理。它接受 prompts 列表并返回结构化输出,常用于离线评测、数据合成与批量生成。

LLM: batched offline inference
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
from vllm import LLM, SamplingParams
prompts = [
    "Hello, my name is",
    "The capital of France is",
]
params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)
 
# 离线批处理仍然通过 LLM 对象统一接入;底层 engine 和调度器由 vLLM 内部管理。
llm = LLM(model="facebook/opt-125m")
outputs = llm.generate(prompts, params)
for o in outputs:
    print(o.prompt)
    print(o.outputs[0].text)

采样参数的默认来源有两套:模型仓库里的 generation_config 与 vLLM 自己的默认值。若业务希望显式使用 vLLM 的默认采样参数,可以在创建 LLM 时设置:

LLM: control generation_config source (pattern)
Python
1
2
3
4
5
6
7
from vllm import LLM
 
llm = LLM(
    model="facebook/opt-125m",
    # 显式忽略模型仓库里的 generation_config,回到 vLLM 自己的默认采样语义。
    generation_config="vllm",
)
SamplingParams:离线与服务共用的采样控制面

SamplingParams 是 vLLM 里最常被反复创建的对象之一。它对应的是“单次请求想怎么解码”,而非整个 engine 的资源配置;因此它更接近业务请求参数,而非部署参数。

命令/API/函数
temperature / top_p / top_k

说明
控制随机性与候选截断范围。适合把“输出多样性”从服务默认值里拆成请求级旋钮。

示例

Python
1
2
3
4
5
6
7
8
from vllm import SamplingParams
 
params = SamplingParams(
    temperature=0.7,  # 降低随机性,但保留一定表达变化
    top_p=0.9,        # 切掉长尾 token,减少离谱采样
    # 再给候选集合一个显式上界,防止 nucleus 后候选仍过宽。
    top_k=50,
)

命令/API/函数
max_tokens / stop / stop_token_ids

说明
限制回复长度并声明何时停止。对工具调用、结构化输出和 Web 对话都很常见。

示例

Python
1
2
3
4
5
6
params = SamplingParams(
    max_tokens=256,                  # 给单次回答设置上限,避免尾部失控
    stop=["\nUser:"],                # 遇到特定分隔符就截断
    # 当上游协议或模板已经约定了特殊 token 时,按 token id 截断更稳。
    stop_token_ids=[151645],
)

命令/API/函数
n / best_of

说明
控制一次请求要生成多少个候选,以及内部要采多少条再返回最好的一批。离线数据合成与 reranking 场景很常用。

示例

Python
1
2
3
4
5
params = SamplingParams(
    n=4,          # 返回 4 个候选,供后续 rerank 或规则过滤
    # 先在内部采 8 条,再把得分更好的 4 条返回给业务侧。
    best_of=8,
)
接口 2:Engine(嵌入式引擎,用于自建服务)

Engine 路线用于把 vLLM 嵌入到自建服务/作业系统中,获得“请求级 streaming + 细粒度生命周期控制”。当前主线接口是 V1 Engine( AsyncLLM),通过 AsyncEngineArgs 构建。

Engine: AsyncLLM streaming generate (pattern)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import asyncio
 
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM
 
async def main() -> None:
    engine_args = AsyncEngineArgs(
        model="meta-llama/Llama-3.2-1B-Instruct",
        # 先用 eager 路线做服务集成,减少 graph capture 或编译问题带来的噪声。
        enforce_eager=True,
    )
    engine = AsyncLLM.from_engine_args(engine_args)
    try:
        params = SamplingParams(
            max_tokens=64,
            temperature=0.2,
            # DELTA 模式每轮只返回新增 token,最适合直连 SSE / WebSocket 流。
            output_kind=RequestOutputKind.DELTA,
        )
        async for out in engine.generate(
            request_id="req-1",
            prompt="Write a haiku about caching.",
            sampling_params=params,
        ):
            for c in out.outputs:
                if c.text:
                    print(c.text, end="", flush=True)
            if out.finished:
                break
    finally:
        # engine 内部持有调度线程和 GPU 资源;嵌入式服务退出前要显式 shutdown。
        engine.shutdown()
if __name__ == "__main__":
    asyncio.run(main())

在 Engine 路线里,“并发/显存预算”通常通过 EngineArgs 控制,应用侧需要自行处理:请求队列、超时/取消(abort)、重试、以及与外部网关的对接。

AsyncLLM 的生命周期控制

把 vLLM 当嵌入式引擎使用时,真正的工程难点是请求取消、更新窗口和进程退出是否可控。 AsyncLLM 这一层已经把这些动作做成显式接口。

AsyncLLM lifecycle-safe pattern
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import asyncio
from contextlib import suppress
 
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM
 
 
async def run_one(engine: AsyncLLM, request_id: str, prompt: str) -> str | None:
    params = SamplingParams(
        # 这里只让模型最多生成 128 个 token,目的是把例子收敛到“请求控制”而非采样调参。
        max_tokens=128,
        # temperature=0.0 把例子固定在确定性更强的解码路径上,便于观察请求生命周期。
        temperature=0.0,
    )
    try:
        async for out in engine.generate(
            request_id=request_id,   # request_id 是后续 abort / 追踪 / 日志关联的主键。
            prompt=prompt,           # 这里直接传 prompt 字符串;真实系统也可以改成 messages 路径。
            sampling_params=params,  # 采样策略对象和请求一起提交给 engine。
        ):
            if out.finished:
                return out.outputs[0].text  # 只取第一条候选,保持示例焦点在生命周期控制。
    except asyncio.CancelledError:
        # 请求级取消要显式通知 engine 回收对应状态。
        await engine.abort(request_id)
        raise
    return None
 
 
async def main() -> None:
    engine = AsyncLLM.from_engine_args(
        AsyncEngineArgs(
            model="facebook/opt-125m",  # 选小模型只是为了让生命周期示例更容易本地复现。
            enforce_eager=True,         # 关闭更激进的图优化路径,减少集成期额外变量。
        )
    )
    try:
        # create_task 把一个生成请求交给 event loop;真实服务里这里通常对应一个用户请求。
        task = asyncio.create_task(run_one(engine, "req-1", "Explain request schedulers in one paragraph."))
 
        # 超时并不同于 engine 已经自动丢弃请求;超时处理要和 abort 配合思考。
        with suppress(asyncio.TimeoutError):
            print(await asyncio.wait_for(task, timeout=5))
 
        # 在线更新或切换窗口前,可以先暂停新生成。
        await engine.pause_generation(mode="keep", clear_cache=True)
        await engine.resume_generation()
    finally:
        # 服务退出前要显式 shutdown,让后台线程和 GPU 资源有序释放。
        engine.shutdown()
 
 
if __name__ == "__main__":
    asyncio.run(main())

abort() 控制请求级取消, pause_generation() / resume_generation() 控制更新窗口, shutdown() 控制进程级收尾。对自建服务而言,这些接口的价值远高于“再包一层 HTTP 就能上线”。

接口 3:vllm serve(生产服务端)

vllm serve 直接启动 OpenAI-compatible server,是最接近“拿来就用”的生产入口。典型端点包括 /v1/chat/completions 与 /v1/embeddings,并支持流式输出(SSE)。

vllm serve: start OpenAI-compatible server
Shell
1
2
3
4
5
6
# 这是最常见的生产入口:直接暴露 OpenAI-compatible HTTP 服务。
vllm serve meta-llama/Meta-Llama-3-8B-Instruct \
  --host 0.0.0.0 \
  --port 8000 \
  --dtype auto \
  --api-key token-abc123

OpenAI-compatible request (generic curl)
Shell
1
2
3
4
5
6
7
8
9
# curl 只验证协议层是否通,不代表服务已经调到最佳吞吐。
curl http://localhost:8000/v1/chat/completions \
  -H 'Content-Type: application/json' \
  -H 'Authorization: Bearer token-abc123' \
  -d '{
    "model": "meta-llama/Meta-Llama-3-8B-Instruct",
    "messages": [{"role":"user","content":"Hello!"}],
    "stream": true
  }'
OpenAI Python SDK(base_url 与 extra_body)

应用侧最常见的接入方式是 OpenAI Python SDK,把 base_url 指向自托管服务。部分参数在 OpenAI API 中不存在,但 vLLM 支持;这类扩展字段通常通过 extra_body 传入。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
vLLM (pattern)">
from openai import OpenAI
 
# 通过 base_url 把官方 OpenAI SDK 指向自建 vLLM 服务。
client = OpenAI(base_url="http://localhost:8000/v1", api_key="token-abc123")
 
# 普通 OpenAI 字段直接按标准传;vLLM 扩展字段通过 extra_body 透传。
resp = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    messages=[{"role": "user", "content": "Hello!"}],
    temperature=0.2,
    max_tokens=128,
    extra_body={"top_k": 50},
)
# 这里返回的是完整的 ChatCompletion 对象,业务代码可以继续读 usage、finish_reason 等字段。
print(resp.choices[0].message)
OpenAI-compatible 端点矩阵

生产里最容易出错的一点,是把“兼容 OpenAI”理解成“只支持 chat completions”。vLLM 的 HTTP 面远比这宽,是否暴露对应能力,要看你启动的模型与服务参数。

端点 工程用途 什么时候优先用它
/v1/chat/completions 最经典的聊天式生成接口 应用已经按 Chat Completions 组织 prompt,或要兼容大量现有 SDK/中间件
/v1/responses 更统一的新式接口,便于承载结构化输出、多模态与工具调用扩展 新系统直接建设,且希望减少未来从 Chat Completions 迁移的成本
/v1/embeddings / /v2/embed 向量化入口 做检索、重排前召回、聚类或语义缓存
/v1/rerank / /v2/rerank / /v1/score 重排与打分 检索系统里要把粗召回结果重新排序,或需要 pairwise/listwise 相关性分数
/tokenize / /detokenize token 级观测与调试 排查 prompt 模板、上下文预算、停词边界与计费口径
服务配置中最容易忽视的兼容开关
vLLM serve: compatibility-focused launch
Shell
1
2
3
4
5
6
7
8
9
vllm serve meta-llama/Meta-Llama-3-8B-Instruct \
  --host 0.0.0.0 \
  --port 8000 \
  --api-key token-abc123 \
  --generation-config vllm \
  --chat-template-content-format auto \
  --enable-request-id-headers \
  --max-num-batched-tokens 8192 \
  --max-num-seqs 256
  • --generation-config vllm 用来避免模型仓库里的 generation_config.json 静默覆盖服务端解码默认值。线上同一服务接多个业务方时,这个开关能显著减少“为什么同样 temperature,行为却不一致”的排障时间。
  • --chat-template-content-format 控制请求消息内容如何映射到 chat template。多模态或复杂 content 格式下,它直接决定模板渲染是否和客户端预期一致。
  • --enable-request-id-headers 让 X-Request-Id 能沿 HTTP 边界传递。服务一旦接入网关、APM 或异步任务系统,这个 request id 往往比 prompt 本身更关键,因为它是跨层排障的唯一稳定主键。
Tool calling 与 structured outputs 的生产语义

vLLM 支持工具调用,但不同模式的可靠性差异很大。生产里需要把“方便演示”和“可验收约束”分开。

  • --enable-auto-tool-choice 并非独立开关,它需要同时提供 --tool-call-parser。前者允许模型自动决定是否调用工具,后者负责把模型输出解析回工具调用结构。
  • tool_choice="auto" 更接近“模型自由输出 + 解析器尽量提取”。它适合探索式系统,但并不天然保证一定满足 schema。
  • tool_choice="required" 或显式指定工具名,更接近“必须产出一个符合工具调用壳子的结果”。这类模式更适合工作流系统与生产链路。
  • 请求里的 strict 字段常见于 OpenAI 风格客户端,但在不同版本组合下,它更多承担兼容入口角色,而非单独决定解码行为的神奇开关。真正约束输出的,还是 structured outputs 或明确的工具 schema。
Structured outputs via OpenAI SDK
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from openai import OpenAI
 
client = OpenAI(base_url="http://localhost:8000/v1", api_key="token-abc123")
 
resp = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    messages=[{"role": "user", "content": "Classify: vLLM is production-ready."}],
    response_format={
        "type": "json_schema",
        "json_schema": {
            "name": "sentiment",
            "schema": {
                "type": "object",
                "properties": {
                    "label": {"type": "string", "enum": ["positive", "negative"]},
                },
                "required": ["label"],
                "additionalProperties": False,
            },
        },
    },
)
 
print(resp.choices[0].message.content)

新系统应优先围绕 structured_outputs 与标准 response_format 建设,而非继续押注旧的 guided_json、 guided_regex 之类历史字段。后者更多是兼容路径,前者才是现在的主语义接口。

服务端配置文件(YAML)

vllm serve 支持从 YAML 配置文件加载参数。参数名使用长参数形式(long form)。CLI 与配置文件同时提供时,优先级为 CLI > config > defaults。

vLLM serve config.yaml (pattern)
YAML
1
2
3
4
5
6
7
8
9
10
model: meta-llama/Llama-3.1-8B-Instruct
host: "0.0.0.0"
port: 8000
uvicorn-log-level: "info"
api-key: "token-abc123"
dtype: "auto"
max-model-len: 8192
gpu-memory-utilization: 0.90
max-num-seqs: 64
enable-prefix-caching: true

vLLM serve: launch with config
Shell
1
vllm serve --config config.yaml
在线 RLHF:权重热更新与 Prefill/Decode 解耦

一旦把 vLLM 用到在线 RLHF 或异步后训练中,服务端就不再只是“提供推理 API”,还需要和训练进程交换新权重、暂停生成、完成热切换。vLLM 已经把这类能力做成显式参数和 HTTP 接口,而非要求用户每轮都重启整个服务。

weight-transfer-config 与热更新端点

--weight-transfer-config 负责打开 trainer ↔ serving 的权重同步通道。典型流程是:训练侧先请求初始化权重传输引擎,再开始更新、逐块推送新权重,最后通知服务端切换到新版本。服务端通常还会配合 /pause、 /resume 暂停和恢复新请求生成。

vLLM: serve with weight transfer (pattern)
Shell
1
2
3
4
5
VLLM_SERVER_DEV_MODE=1 \
vllm serve meta-llama/Meta-Llama-3-8B-Instruct \
  --host 0.0.0.0 \
  --port 8000 \
  --weight-transfer-config '{"backend":"nccl","engine":"v1"}'

Weight transfer endpoints (concept)
1
2
3
4
5
6
POST /init_weight_transfer_engine
POST /start_weight_update
POST /update_weights
POST /finish_weight_update
POST /pause
POST /resume

Trainer side: drive vLLM weight update over HTTP + NCCL
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import requests
from vllm.distributed.weight_transfer.nccl_engine import (
    NCCLTrainerSendWeightsArgs,
    NCCLWeightTransferEngine,
)
 
# base 指向 vLLM 服务的控制面地址;后面所有 pause / resume / update 都走这里。
base = "http://127.0.0.1:8000"
 
# 第一步是控制面初始化;HTTP 负载里传的是 init_info,而非裸 backend 字段。
requests.post(
    f"{base}/init_weight_transfer_engine",
    json={
        "init_info": {
            "master_address": "10.0.0.1",  # NCCL 通信主节点地址;trainer 和 serving 需要都能访问到。
            "master_port": 29501,          # NCCL 建链端口;要和训练侧后续 trainer_init 保持一致。
            "rank_offset": 1,              # 让 serving ranks 与 trainer ranks 的编号区间不互相撞车。
            "world_size": 3,               # 整个 weight-transfer 通信域里的总 rank 数。
        }
    },
    timeout=60,
).raise_for_status()
 
# 进入更新窗口前先暂停生成,让新老权重切换有明确边界。
requests.post(f"{base}/pause", params={"mode": "keep"}, timeout=60).raise_for_status()
requests.post(
    f"{base}/start_weight_update",
    # is_checkpoint_format=True 表示后续发送的是“按 checkpoint 语义组织”的权重块元信息。
    json={"is_checkpoint_format": True},
    timeout=60,
).raise_for_status()
 
# HTTP 这里传的是本轮权重块的元信息,而非直接把大张量塞进 JSON。
meta = {
    "names": names,              # 张量名列表;服务端靠它知道接下来写回哪些参数。
    "dtype_names": dtype_names,  # 每个张量对应的数据类型字符串,供接收端恢复 tensor 解释方式。
    "shapes": shapes,            # 每个张量的形状;否则服务端无法重建参数布局。
    "packed": True,              # 声明 trainer 发送的是打包权重流,而非逐 tensor 独立发送。
}
requests.post(
    f"{base}/update_weights",
    json={"update_info": meta},
    timeout=300,
).raise_for_status()
 
# 真正的权重数据平面走 NCCL;HTTP 只负责控制顺序和元信息。
group = NCCLWeightTransferEngine.trainer_init(
    {
        "master_address": "10.0.0.1",  # 与 init_info 保持一致,双方才能进入同一 NCCL 通信域。
        "master_port": 29501,
        "world_size": 3,
    }
)
NCCLWeightTransferEngine.trainer_send_weights(
    iterator=model.named_parameters(),  # 把训练中当前模型参数按名字迭代出来,作为实际发送的数据源。
    # packed=True 要和上面的 update_info 保持一致,否则服务端解析权重流的方式会错位。
    trainer_args=NCCLTrainerSendWeightsArgs(group=group, packed=True),
)
 
# finish_weight_update 表示本轮新权重已经完整送达,服务端可以切换到新版本。
requests.post(f"{base}/finish_weight_update", json={}, timeout=60).raise_for_status()
 
# 最后恢复生成,让后续新请求开始吃到新权重。
requests.post(f"{base}/resume", timeout=60).raise_for_status()

这里有两个边界必须说清楚。第一, init_weight_transfer_engine 与 update_weights 的 JSON 负载结构分别是 {"init_info": ...} 与 {"update_info": ...},并非任意自造键名。第二,HTTP 是控制面,NCCL 或 IPC 才是数据面;把大张量直接塞进 HTTP JSON,不符合这条接口的设计方式。

这类端点通常只在开发模式下开放,因此示例命令把 VLLM_SERVER_DEV_MODE=1 显式写在前面。线上暴露时必须额外套网关与访问控制,因为很多非 /v1* 端点并不属于普通 API 消费面。

kv-transfer-config 与 Prefill/Decode 解耦

长上下文服务里,prefill 和 decode 的负载形态差异很大。vLLM 的 --kv-transfer-config 允许把 prefill 产出的 KV 传给 decode 侧实例,从而把两类负载拆到不同进程甚至不同机器上。这类拓扑特别适合“长 prompt + 短回复”的系统,因为 prefill 往往比 decode 更吃带宽和显存。

vLLM: disaggregated prefill/decode (pattern)
Shell
1
2
3
4
5
6
7
8
9
# prefill producer
vllm serve meta-llama/Meta-Llama-3-8B-Instruct \
  --port 8100 \
  --kv-transfer-config '{"role":"producer","connector":"shared_storage"}'
 
# decode consumer
vllm serve meta-llama/Meta-Llama-3-8B-Instruct \
  --port 8200 \
  --kv-transfer-config '{"role":"consumer","connector":"shared_storage"}'

Prefill/Decode 解耦改善的通常是资源形态匹配与尾延迟,而非“任何负载下都绝对更快”。如果请求大多是短 prompt、短回复,额外的 KV 传输与系统复杂度可能抵消收益。

EngineArgs(核心配置面)

Engine arguments 控制 vLLM 的运行行为:离线推理时它们是 LLM(...) 的一部分参数;在线服务时它们是 vllm serve 的参数子集。工程上可以把 EngineArgs 按职责分为四类:模型与 tokenizer、并行与执行器、KV cache 与调度、以及安全/可观测性。

常用 EngineArgs(服务端视角)
参数 含义 工程后果
--max-model-len 最大上下文长度 直接决定 KV cache 的 token 预算;过大常导致并发下降或 OOM
--gpu-memory-utilization 显存预算比例 留出系统/碎片空间;过高会提升 OOM 风险
--max-num-batched-tokens 每步调度的 token 预算 影响吞吐与尾延迟,常与并发/显存一起调
--max-num-seqs 并发序列数上限 决定同卡并发,过高会导致排队抖动与尾延迟上升
--kv-cache-dtype KV cache 存储精度 影响显存/带宽;更激进精度需要评估质量与稳定性
--trust-remote-code 允许加载模型仓库的自定义代码 改变执行边界;仅在可信模型源启用
--download-dir 权重/缓存下载目录 容器化时用于挂载共享缓存,减少冷启动成本
EngineArgs 在 Python 侧的用法

当你希望把“服务端配置”复用到离线任务中,可以在 Python 里用 EngineArgs 组装配置,再把字段展开给 LLM:

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
LLM (pattern)">
from dataclasses import asdict
 
from vllm import LLM
from vllm.engine.arg_utils import EngineArgs
 
# 先把服务端同款参数集中到 EngineArgs,便于在离线脚本和在线服务之间复用配置。
engine_args = EngineArgs(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    max_model_len=8192,
    gpu_memory_utilization=0.90,
    enable_prefix_caching=True,
)
 
# asdict 会把 dataclass 展开成 LLM 构造函数可接受的关键字参数。
llm = LLM(**asdict(engine_args))
吞吐预算:max_num_batched_tokens 与 chunked prefill

vLLM 服务调优里,最常被忽视的核心是调度器预算。 max_num_batched_tokens 控制单个调度步最多同时处理多少 token; enable_chunked_prefill 控制长 prompt 的 prefill 是否允许被切块处理。二者决定的是服务端如何在“长提示词请求”和“短请求延迟”之间做平衡。

EngineArgs: scheduler budget (pattern)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from dataclasses import asdict
 
from vllm import LLM
from vllm.engine.arg_utils import EngineArgs
 
engine_args = EngineArgs(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    max_model_len=16384,
    # GPU 显存预算不要一上来顶满;先留出 KV cache 波动和碎片化余量。
    gpu_memory_utilization=0.90,
    # 单个调度步最多处理 4096 个 token;
    # 它太小会压吞吐,太大又会让长请求拖慢短请求。
    max_num_batched_tokens=4096,
    # 打开后,超长 prompt 的 prefill 可以分块进入调度器,
    # 更适合“长输入 + 高频并发”的服务。
    enable_chunked_prefill=True,
    # max_num_seqs 限制同一时刻允许并发挂在调度器上的请求条数。
    max_num_seqs=64,
)
 
llm = LLM(**asdict(engine_args))

经验上, enable_chunked_prefill=True 更适合长上下文服务;而 max_num_batched_tokens 往往需要结合真实流量压测来找平衡点。它并非“越大越快”的旋钮,因为过大的 batch token 预算会拉高单步时延,并放大长请求对短请求的阻塞。

调度器的第二层旋钮:partial prefill、stream interval 与调度策略

只调 max_num_batched_tokens 往往不够。长上下文、高并发和在线流式三者同时存在时,调度器的第二层参数才是真正决定尾延迟的部分。

参数 控制什么 什么时候值得调
--scheduling-policy 请求在调度器里的优先策略 不同业务混跑,且你确实需要在公平性和吞吐之间做取舍时
--stream-interval 流式输出向客户端刷新的频率 前端强调实时感知,或日志系统希望减少碎片化 token 事件时
--max-num-partial-prefills 允许多少个长 prompt 以切块形式并行进入 prefill 长 prompt 请求经常把短请求拖住时
--max-long-partial-prefills / --long-prefill-token-threshold 定义“多长算长请求”,以及这类请求最多允许多少并发切块 流量同时包含极长上下文任务与常规问答请求时

这一层参数没有固定最优值。它们依赖模型大小、GPU 代际、上下文长度分布,以及你更关心吞吐、TTFT 还是尾延迟。文档里给的是旋钮,真正落地仍然要靠业务流量压测。

批量推理系统里的 vLLM:Processor 配置而非手搓循环

在真实离线批量推理系统里,vLLM 往往会挂进更高层的数据处理框架,由框架负责分片、重试、结果回写和集群扩缩。Ray Data 的 vLLM Processor 就是一个典型模式:把 prompt 构造、采样参数和结果抽取分成三段函数,vLLM 只负责把 GPU 算力吃满。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig
 
config = vLLMEngineProcessorConfig(
    model_source="unsloth/Llama-3.1-8B-Instruct",
    engine_kwargs={
        # 这两项直接下发给 vLLM engine,用来约束长 prompt 的调度方式。
        "enable_chunked_prefill": True,
        "max_num_batched_tokens": 4096,
        "max_model_len": 16384,
    },
    concurrency=1,   # 每个 Processor 副本背后一份 vLLM engine;并发副本数由这里控制
    batch_size=64,   # 上游数据框架把多少条记录聚成一批交给 vLLM
)
 
processor = build_llm_processor(
    config,
    # preprocess 负责把业务行转成 messages + sampling_params。
    preprocess=lambda row: dict(
        messages=[{"role": "user", "content": row["text"]}],
        sampling_params={"temperature": 0.3, "max_tokens": 250},
    ),
    # postprocess 负责从生成结果里抽回业务字段,便于直接写回 Parquet/对象存储。
    postprocess=lambda row: dict(answer=row["generated_text"], **row),
)

这条路线的意义在于职责分离:数据系统负责大规模分片与回写,vLLM 负责单副本高吞吐推理。两者分开后,离线推理任务就不需要把“数据调度”和“GPU 推理调度”揉在同一个脚本里。

并行:Tensor Parallel / Pipeline Parallel / Data Parallel

分布式推理的目标是“把单模型副本放进足够多的 GPU 里,并把负载分摊出去”。vLLM 的并行策略可以分为三类:单副本的张量并行/流水并行,以及多副本的 data parallel(权重复制)。

Tensor Parallel(单机多卡)

当模型无法放进单卡但能放进单机多卡时,设置 tensor_parallel_size 为“每节点 GPU 数”是最常见策略。

vllm serve: tensor parallel (single node)
Shell
1
2
3
vllm serve meta-llama/Meta-Llama-3-8B-Instruct \
  --tensor-parallel-size 4 \
  --host 0.0.0.0 --port 8000
Pipeline Parallel(多机或单机不均匀切分)

当模型超过单机容量,需要组合张量并行与流水并行:把 tensor_parallel_size 设为“每节点 GPU 数”,把 pipeline_parallel_size 设为“节点数”。如果模型在单机可容纳,但 GPU 数无法均匀切分模型,也可以用 pipeline parallel 做不均匀切分:此时常见设置是 tensor_parallel_size=1, pipeline_parallel_size=GPU 数。

vllm serve: tensor+pipeline parallel (pattern)
Shell
1
2
3
4
vllm serve meta-llama/Meta-Llama-3-70B-Instruct \
  --tensor-parallel-size 8 \
  --pipeline-parallel-size 2 \
  --host 0.0.0.0 --port 8000
Data Parallel(多副本 + 负载均衡)

data parallel 复制权重,让多个 GPU/进程独立处理请求,适合吞吐扩展。vLLM 支持“自包含 DP(一个对外端点,内部做 rank 级负载均衡)”与“外部负载均衡(每 rank 单独对外,外部 LB 路由)”。

vllm serve: data parallel (self-contained, pattern)
Shell
1
2
3
vllm serve meta-llama/Meta-Llama-3-8B-Instruct \
  --data-parallel-size 4 \
  --host 0.0.0.0 --port 8000
前缀缓存(Prefix Caching)

前缀缓存缓存“已 prefill 的前缀 KV blocks”,新请求与历史请求共享前缀时,可以复用缓存并跳过重复的 prefill 计算。它对“系统 prompt 固定、RAG 模板固定、长上下文重复”的场景收益很大。

开启方式与哈希策略

服务端通过 --enable-prefix-caching 开启前缀缓存。为多租户隔离与碰撞风险控制,前缀缓存提供可配置的哈希策略(例如使用 SHA256 族)。

vllm serve: enable prefix caching (pattern)
Shell
1
2
3
4
vllm serve meta-llama/Meta-Llama-3-8B-Instruct \
  --enable-prefix-caching \
  --prefix-caching-hash-algo sha256 \
  --host 0.0.0.0 --port 8000
Speculative Decoding

speculative decoding 用“草稿模型提出多个候选 token + 目标模型验证并接收其中一部分”的方式减少目标模型 decode 步数,从而降低解码延迟。服务端需要同时加载目标模型与草稿模型,并为草稿模型设置独立的资源预算。

vllm serve: speculative decoding (CLI pattern)
Shell
1
2
3
4
5
6
7
# 启动 vLLM 服务。
vllm serve <target-model> \
  --speculative-config '{
    "method": "draft_model",
    "model": "<draft-model>",
    "num_speculative_tokens": 5
  }'

speculative decoding 与某些并行策略(例如 pipeline parallel)可能存在兼容性限制,上线前需要在目标版本组合上做压测与回归。

监控、日志与部署注意项
/metrics 与 Prometheus

vLLM 的 OpenAI-compatible server 默认暴露 /metrics,可用于 Prometheus 抓取与容量规划。

Query /metrics
Shell
1
curl http://0.0.0.0:8000/metrics

容量规划时需要重点关注两类信息:KV cache 的 token 容量与“最大并发估计”。vLLM 启动日志通常会输出类似的估算信息(示例格式如下):

Example: concurrency estimate lines (concept)
1
2
GPU KV cache size: 643,232 tokens
Maximum concurrency for 40,960 tokens per request: 15.70x
健康检查与日志降噪

服务端通常提供健康检查端点(例如 /health、 /ping)。生产环境里这些端点会被 LB 高频调用,建议通过 --disable-access-log-for-endpoints 关闭对应 access logs,避免淹没有效日志:

Disable access logs for noisy endpoints
Shell
1
2
3
vllm serve meta-llama/Meta-Llama-3-8B-Instruct \
  --disable-access-log-for-endpoints "/health,/metrics,/ping" \
  --host 0.0.0.0 --port 8000
日志配置(环境变量)

vLLM 使用 Python 的 logging 配置体系,并提供环境变量控制默认日志行为。最常见的两类控制是:关闭 vLLM 的默认日志配置,以及提供自定义 JSON logging 配置文件路径。

Logging controls (pattern)
Shell
1
2
3
4
5
6
# 完全交给宿主应用自己的 logging 配置,适合已有统一日志体系的服务。
export VLLM_CONFIGURE_LOGGING=0
 
# 也可以继续让 vLLM 初始化 logging,但显式指定 JSON 配置文件。
export VLLM_CONFIGURE_LOGGING=1
export VLLM_LOGGING_CONFIG_PATH=/path/to/logging_config.json
部署前的稳定性检查清单
  • 显存预算:用 --max-model-len 与 --gpu-memory-utilization 先跑通,再逐步提高 --max-num-seqs 与 --max-num-batched-tokens 做压测。
  • 默认行为:明确 chat template、tokenizer 与 generation_config 的来源与优先级,避免升级后默认采样参数变化。
  • 权限边界: --trust-remote-code 只在可信模型源启用;容器中通过只读挂载、最小权限与镜像固化降低风险。
  • 日志与指标:确保 /metrics 可被抓取,健康检查端点与 access log 策略不会引发噪声或误报警。
运维 CLI:bench 与 run-batch

vLLM 近年的 CLI 已经不只剩 serve。做容量评估、回归压测与离线批任务时, vllm bench 和 vllm run-batch 往往比自己写临时脚本更稳,因为它们直接复用了官方参数面与统计口径。

vLLM: operational CLIs
Shell
1
2
3
4
5
6
# 吞吐/延迟基准
vllm bench throughput --model meta-llama/Meta-Llama-3-8B-Instruct
vllm bench latency --model meta-llama/Meta-Llama-3-8B-Instruct
 
# 离线批任务
vllm run-batch --input-file prompts.jsonl --output-file outputs.jsonl
知名代码精读

AI 训练与推理编程的学习不能只停在 API 表面。成熟项目的源码会暴露更真实的问题:张量形状如何流动,工具接口如何约束模型,权限和上下文如何管理,NER 模型如何把标签、边界和 span 组织成可训练目标。本章选取三类代码做精读:手写 Transformer、Claude Code 类 AI 编程 agent、以及 CRF / GlobalPointer / GLiNER 等 NER 知名算法。

手写Transformers

手写 Transformer 精读应覆盖一条完整工程链路:配置如何落到模型形状,权重如何初始化和共享,训练数据如何切成输入/标签,训练循环如何处理 AMP 与梯度累积,学习率如何变化,推理如何自回归生成,checkpoint 如何恢复,以及 minGPT 到 nanoGPT 的演进到底补齐了哪些工程能力。下面的代码按 nanoGPT / minGPT 的公开实现抽象改写,保留核心结构,并把每一行的工程意义写在旁边或上一行。

配置

配置层负责把“模型多大、上下文多长、训练多久、用什么 dtype”变成可记录、可覆盖、可写入 checkpoint 的结构。nanoGPT 的配置方式很朴素:默认值写在脚本顶部,再允许命令行或配置文件覆盖。这种写法适合教学和小型实验,也能清楚展示每个超参进入哪条路径。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# dataclass 用来保存模型结构参数,checkpoint 里也会写入这组值。
@dataclass
class GPTConfig:
    # block_size 是模型一次最多能看的 token 数,也就是上下文窗口长度。
    block_size: int = 1024
    # vocab_size 是输出分类数;50304 是把 GPT-2 50257 词表补齐到更利于硬件对齐的尺寸。
    vocab_size: int = 50304
    # n_layer 控制 Transformer block 堆叠层数,主要决定深度。
    n_layer: int = 12
    # n_head 控制多头注意力的头数,必须能整除 n_embd。
    n_head: int = 12
    # n_embd 是每个 token 的隐藏维度,主要决定模型宽度。
    n_embd: int = 768
    # dropout 训练时随机丢弃部分激活,预训练常设 0,微调可适当增大。
    dropout: float = 0.0
    # bias 控制 Linear 和 LayerNorm 是否带偏置;关掉可略微省参数和算力。
    bias: bool = True
 
 
# out_dir 是训练产物目录,checkpoint、日志和采样脚本都会依赖它。
out_dir = "out"
# eval_interval 决定多少个 optimizer step 做一次评估。
eval_interval = 2000
# eval_iters 决定评估时抽多少个 batch 来估计 train/val loss。
eval_iters = 200
# init_from 决定初始化来源:scratch、resume 或 GPT-2 权重。
init_from = "scratch"
# dataset 指向 data/<dataset>/train.bin 与 val.bin。
dataset = "openwebtext"
# gradient_accumulation_steps 用多个 micro batch 模拟更大的全局 batch。
gradient_accumulation_steps = 40
# batch_size 是单次前向的 micro batch size。
batch_size = 12
# block_size 要和模型 config 对齐,也决定每个样本的 token 长度。
block_size = 1024
# learning_rate 是 AdamW 的峰值学习率。
learning_rate = 6e-4
# max_iters 是训练主循环的最大 step 数。
max_iters = 600000
# weight_decay 只作用在需要衰减的大矩阵参数上。
weight_decay = 1e-1
# dtype 决定 autocast 和 GradScaler 的行为。
dtype = "bfloat16" if torch.cuda.is_bf16_supported() else "float16"
# compile 控制是否启用 torch.compile,通常训练前几步会有编译开销。
compile = True
 
 
# config_keys 从当前脚本全局变量里收集可覆盖项。
config_keys = [
    # 只收集公开名字,避免把导入模块、私有临时变量写入配置。
    k for k, v in globals().items()
    # 简单标量最适合命令行覆盖,也便于写入日志。
    if not k.startswith("_") and isinstance(v, (int, float, bool, str))
]
# configurator.py 会把命令行或配置文件中的值写回 globals()。
exec(open("configurator.py").read())
# config 是训练配置快照,后面会写入 checkpoint,便于复现实验。
config = {k: globals()[k] for k in config_keys}
权重初始化/共享

模型初始化会同时决定结构、参数尺度和推理路径。GPT 实现里至少有四个关键点:LayerNorm 可选 bias,token embedding 与输出头权重共享,残差投影做缩放初始化,推理时只计算最后一个位置的 logits。这些细节共同影响显存、训练稳定性和推理速度。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# LayerNorm 单独包一层,是为了支持 bias=False。
class LayerNorm(nn.Module):
    # ndim 是归一化的最后一维大小,bias 控制是否学习偏置。
    def __init__(self, ndim: int, bias: bool) -> None:
        # 初始化 nn.Module 基类,注册参数前必须调用。
        super().__init__()
        # weight 是缩放参数,初值为 1 表示先不改变归一化后的尺度。
        self.weight = nn.Parameter(torch.ones(ndim))
        # bias 为 False 时不创建偏置参数,节省参数并贴近部分现代 LLM 设置。
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
 
    # input 形状通常是 [B, T, C],LayerNorm 只归一化最后一维 C。
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # F.layer_norm 直接接受可选 bias,比 nn.LayerNorm 更容易控制参数存在性。
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
 
 
# GPT 主类负责把 embedding、block 堆叠、最终归一化和输出头组装起来。
class GPT(nn.Module):
    # config 里包含 block_size、vocab_size、层数、头数、宽度等结构参数。
    def __init__(self, config: GPTConfig) -> None:
        # 初始化 nn.Module 基类,后续 ModuleDict / Linear 才会被注册。
        super().__init__()
        # 保存 config,forward、generate、checkpoint 都要读取它。
        self.config = config
        # ModuleDict 让子模块按名字组织,state_dict 里也会保留清晰前缀。
        self.transformer = nn.ModuleDict(dict(
            # wte 把 token id 映射成向量,输入 [B, T] 变成 [B, T, C]。
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            # wpe 是可学习位置向量,长度上限由 block_size 决定。
            wpe=nn.Embedding(config.block_size, config.n_embd),
            # drop 对 embedding 和 block 输出做正则化。
            drop=nn.Dropout(config.dropout),
            # h 是 Transformer block 列表,n_layer 决定重复次数。
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            # ln_f 是进入 lm_head 前的最终归一化。
            ln_f=LayerNorm(config.n_embd, bias=config.bias),
        ))
        # lm_head 把隐藏向量映射回词表 logits。
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # 权重共享让输入 embedding 与输出分类头使用同一矩阵。
        self.transformer.wte.weight = self.lm_head.weight
        # apply 会递归访问所有子模块,统一调用 _init_weights。
        self.apply(self._init_weights)
        # 残差投影层做 GPT-2 风格缩放初始化,层数越深,残差分支初始方差越小。
        for name, param in self.named_parameters():
            # c_proj.weight 是 attention/MLP 残差分支回到主干前的投影矩阵。
            if name.endswith("c_proj.weight"):
                # sqrt(2 * n_layer) 来自残差路径数量随深度增长的方差控制。
                std = 0.02 / math.sqrt(2 * config.n_layer)
                # normal_ 原地写入初始化值,不创建新的 Parameter。
                torch.nn.init.normal_(param, mean=0.0, std=std)
 
    # _init_weights 定义 Linear 和 Embedding 的默认初始化。
    def _init_weights(self, module: nn.Module) -> None:
        # Linear 权重用小方差正态分布初始化,避免训练初期 logits 过大。
        if isinstance(module, nn.Linear):
            # GPT 系列常见 std=0.02,足够小且经验稳定。
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            # 只有存在 bias 时才清零,bias=False 的层没有这个参数。
            if module.bias is not None:
                # bias 初始为 0,让模型起点不带额外偏移。
                torch.nn.init.zeros_(module.bias)
        # Embedding 也用同样尺度初始化,保持输入表示和线性层尺度一致。
        elif isinstance(module, nn.Embedding):
            # embedding 表越大,初始化尺度仍由 hidden 训练稳定性决定。
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
 
    # idx 是 token id 张量,targets 是右移一位后的训练标签。
    def forward(self, idx: torch.Tensor, targets: torch.Tensor | None = None):
        # idx.device 决定位置张量也创建在哪个设备上。
        device = idx.device
        # b 是 batch size,t 是当前序列长度。
        b, t = idx.size()
        # t 不能超过模型创建时的 block_size,否则位置 embedding 和 mask 都不够长。
        assert t <= self.config.block_size
        # pos 是 [0, 1, ..., t-1],形状为 [T]。
        pos = torch.arange(0, t, dtype=torch.long, device=device)
        # tok_emb 形状是 [B, T, C]。
        tok_emb = self.transformer.wte(idx)
        # pos_emb 形状是 [T, C],会广播到 batch 维。
        pos_emb = self.transformer.wpe(pos)
        # token 表示与位置表示相加,再进入 dropout。
        x = self.transformer.drop(tok_emb + pos_emb)
        # 每个 block 都包含 pre-norm attention 与 pre-norm MLP。
        for block in self.transformer.h:
            # block 不改变形状,始终保持 [B, T, C]。
            x = block(x)
        # 最终归一化让进入词表头的表示更稳定。
        x = self.transformer.ln_f(x)
        # 训练时需要所有位置 logits,因为每个位置都有 next-token 标签。
        if targets is not None:
            # logits 形状是 [B, T, vocab_size]。
            logits = self.lm_head(x)
            # view(-1, vocab) 把所有 token 位置压成分类样本。
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        # 推理时只关心最后一个位置的下一个 token 分布。
        else:
            # x[:, [-1], :] 保留时间维,输出形状仍是 [B, 1, vocab_size]。
            logits = self.lm_head(x[:, [-1], :])
            # 推理路径不计算 loss,避免无意义开销。
            loss = None
        # 返回 logits 和可选 loss,训练/推理共用同一个 forward。
        return logits, loss
数据切 batch

nanoGPT 的数据读取没有引入复杂 DataLoader,训练样本直接来自二进制 token 文件中的随机切片。这段代码值得精读,因为它把 causal LM 的输入/标签关系说得很清楚: x 是当前位置 token, y 是整体右移一位的 next-token 标签。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# data_dir 约定为 data/<dataset>,里面放 train.bin 和 val.bin。
data_dir = os.path.join("data", dataset)
 
 
# split 只能是 train 或 val,用来选择不同的二进制 token 文件。
def get_batch(split: str) -> tuple[torch.Tensor, torch.Tensor]:
    # 每次重新创建 memmap,可以避免部分平台上长时间迭代的内存持有问题。
    if split == "train":
        # train.bin 存的是连续 token id,uint16 足够容纳 GPT-2 词表。
        data = np.memmap(os.path.join(data_dir, "train.bin"), dtype=np.uint16, mode="r")
    else:
        # val.bin 用同样格式保存验证集 token id。
        data = np.memmap(os.path.join(data_dir, "val.bin"), dtype=np.uint16, mode="r")
 
    # ix 随机抽 batch_size 个起点,每个起点后面要能切出 block_size+1 个 token。
    ix = torch.randint(len(data) - block_size, (batch_size,))
    # x 取 [i, i+block_size),作为模型输入。
    x = torch.stack([
        # memmap 切片先转 int64,因为 Embedding 需要 long 类型 token id。
        torch.from_numpy((data[i:i + block_size]).astype(np.int64))
        # 对 batch 里的每个随机起点都做同样切片。
        for i in ix
    ])
    # y 取 [i+1, i+1+block_size),作为每个位置的 next-token 标签。
    y = torch.stack([
        # y 和 x 等长,只是整体右移 1 个 token。
        torch.from_numpy((data[i + 1:i + 1 + block_size]).astype(np.int64))
        # batch 内每条样本都有自己的随机起点。
        for i in ix
    ])
 
    # CUDA 路径使用 pinned memory + non_blocking,把 CPU 到 GPU 拷贝尽量异步化。
    if device_type == "cuda":
        # pin_memory 让主机内存页锁定,GPU DMA 拷贝更高效。
        x = x.pin_memory().to(device, non_blocking=True)
        # 标签同样提前搬到 GPU,避免 forward 时再阻塞。
        y = y.pin_memory().to(device, non_blocking=True)
    else:
        # CPU 或 MPS 路径不使用 pinned memory。
        x = x.to(device)
        # 标签设备必须和输入一致,否则 loss 计算会报错。
        y = y.to(device)
 
    # 返回 [B, T] 输入和 [B, T] 标签。
    return x, y
训练循环

训练主循环把评估、保存、前向、反向、更新、日志和退出条件串起来。这里先看不展开 AMP 细节的骨架,重点是训练状态如何流动。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# 先取第一个 batch,让后续循环里可以边算当前 batch 边预取下一个 batch。
X, Y = get_batch("train")
# t0 用于统计每个 iteration 的墙钟时间。
t0 = time.time()
# iter_num 是全局训练 step,会写入 checkpoint。
iter_num = 0
# local_iter_num 是当前进程生命周期内的 step,用于跳过刚启动的性能抖动。
local_iter_num = 0
# best_val_loss 用来判断当前 checkpoint 是否刷新最优验证集表现。
best_val_loss = 1e9
# raw_model 在 DDP 下指向内部原始模型,保存 state_dict 时要用它。
raw_model = model.module if ddp else model
 
 
# 主循环一直运行到 max_iters 或外部中断。
while True:
    # 当前 step 的学习率可能来自 warmup + cosine decay。
    lr = get_lr(iter_num) if decay_lr else learning_rate
    # AdamW 的每个参数组都要同步更新 lr。
    for param_group in optimizer.param_groups:
        # 直接写 param_group["lr"] 是 PyTorch optimizer 的标准动态调参方式。
        param_group["lr"] = lr
 
    # 到评估间隔时,只让主进程评估和保存,避免多卡重复写文件。
    if iter_num % eval_interval == 0 and master_process:
        # estimate_loss 会切到 eval(),抽多个 batch 求平均 loss。
        losses = estimate_loss()
        # 日志同时看 train 和 val,方便判断欠拟合或过拟合。
        print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        # 验证集 loss 刷新或配置要求总是保存时,写 checkpoint。
        if losses["val"] < best_val_loss or always_save_checkpoint:
            # 更新 best_val_loss,后续 checkpoint 以它作为比较基线。
            best_val_loss = losses["val"]
            # 第 0 步通常只做 sanity eval,不急着保存。
            if iter_num > 0:
                # save_training_checkpoint 负责保存模型、优化器、配置和步数。
                save_training_checkpoint(raw_model, optimizer, model_args, iter_num, best_val_loss)
 
    # eval_only 用于只跑一次评估,不进入训练更新。
    if iter_num == 0 and eval_only:
        # break 会跳出 while True,脚本结束。
        break
 
    # train_one_iteration 内部执行梯度累积、AMP、反向、更新,并返回预取好的下一批数据。
    loss, X, Y = train_one_iteration(model, optimizer, scaler, X, Y)
    # 统计当前 iteration 耗时。
    t1 = time.time()
    # dt 是一个完整训练 step 的耗时。
    dt = t1 - t0
    # 更新 t0,供下一轮计时。
    t0 = t1
    # 按 log_interval 打印训练 loss 与耗时。
    if iter_num % log_interval == 0 and master_process:
        # loss.item() 会触发 CPU/GPU 同步,因此只在日志步做。
        print(f"iter {iter_num}: loss {loss.item():.4f}, time {dt * 1000:.2f}ms")
    # 全局 step 加 1。
    iter_num += 1
    # 本进程本地 step 加 1。
    local_iter_num += 1
    # 达到最大 step 后退出。
    if iter_num > max_iters:
        # 训练正常结束。
        break
AMP/梯度累积

AMP 与梯度累积是 nanoGPT 比 minGPT 更接近真实训练脚本的关键部分。AMP 降低显存和提升吞吐;梯度累积把多个 micro batch 的梯度加起来,模拟更大的全局 batch;DDP 下只在最后一个 micro step 同步梯度,避免无谓通信。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# ptdtype 把字符串配置映射成 PyTorch dtype。
ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[dtype]
# CPU 上不启用 autocast,CUDA 上按 dtype 进入混合精度上下文。
ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
# fp16 需要 GradScaler;bf16 动态范围更大,通常不需要 loss scaling。
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == "float16"))
 
 
# 训练一个 optimizer step,内部可能包含多个 micro step。
def train_one_iteration(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scaler: torch.cuda.amp.GradScaler,
    X: torch.Tensor,
    Y: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # loss_for_log 保存最后一个 micro step 的 loss,用于日志。
    loss_for_log = None
    # micro step 循环把多个小 batch 的梯度累积到同一组参数上。
    for micro_step in range(gradient_accumulation_steps):
        # DDP 模式下,前几个 micro step 不做 all-reduce,同步留到最后一步。
        if ddp:
            # require_backward_grad_sync 是 DDP 的内部开关,nanoGPT 直接设置它来减少代码嵌套。
            model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
        # autocast 让矩阵乘法等算子使用 fp16/bf16,LayerNorm 等敏感算子可保持更安全路径。
        with ctx:
            # logits 是当前 batch 的词表分布,loss 是 next-token 交叉熵。
            logits, loss = model(X, Y)
            # 累积梯度时要把 loss 除以累积步数,使总梯度尺度等价于大 batch 平均。
            loss = loss / gradient_accumulation_steps
        # 记录缩放后的 loss;日志时再乘回去近似原始 loss。
        loss_for_log = loss
        # 当前 batch 已经进入 GPU 计算后,立即预取下一批数据,隐藏一部分 CPU/IO 开销。
        X, Y = get_batch("train")
        # scaler.scale(loss).backward() 会在 fp16 下先放大 loss,降低梯度下溢概率。
        scaler.scale(loss).backward()
 
    # 梯度裁剪前必须先 unscale,否则裁剪的是被放大的梯度。
    if grad_clip != 0.0:
        # unscale_ 把 optimizer 管理的梯度恢复到真实尺度。
        scaler.unscale_(optimizer)
        # clip_grad_norm_ 控制梯度范数,缓和偶发梯度爆炸。
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
 
    # scaler.step 内部会在发现 inf/nan 梯度时跳过 optimizer.step。
    scaler.step(optimizer)
    # update 调整下一轮 loss scale,稳定 fp16 训练。
    scaler.update()
    # set_to_none=True 直接释放 grad tensor,通常比清零更省内存。
    optimizer.zero_grad(set_to_none=True)
    # 返回乘回累积步数后的 loss,以及循环中预取好的下一批训练数据。
    return loss_for_log * gradient_accumulation_steps, X, Y
学习率调度

nanoGPT 使用线性 warmup + cosine decay。warmup 避免训练初期大步长破坏尚未稳定的表示;cosine decay 在训练后期逐渐降低更新幅度;min_lr 给学习率留一个下限。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# get_lr 输入当前 step,输出这一 step 应设置给 optimizer 的学习率。
def get_lr(it: int) -> float:
    # warmup 阶段学习率从接近 0 线性升到 learning_rate。
    if it < warmup_iters:
        # it+1 避免第 0 步学习率为 0,warmup_iters+1 让末端不过冲。
        return learning_rate * (it + 1) / (warmup_iters + 1)
    # 训练超过衰减区间后,固定使用最小学习率。
    if it > lr_decay_iters:
        # min_lr 通常取峰值 learning_rate 的十分之一量级。
        return min_lr
    # decay_ratio 把当前 step 映射到 [0, 1] 区间。
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    # 这个 assert 能及时发现 warmup/lr_decay 配置不一致。
    assert 0 <= decay_ratio <= 1
    # cosine 系数从 1 平滑降到 0。
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    # 在 min_lr 和 learning_rate 之间插值。
    return min_lr + coeff * (learning_rate - min_lr)
generate

生成函数是自回归解码的最小实现:每次把当前上下文喂给模型,取最后一个位置的 logits,按 temperature 与 top-k 变成采样分布,采样出下一个 token,再拼回序列。它解释了大模型推理为什么是逐 token 串行的。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# no_grad 关闭梯度记录,减少推理显存和 autograd 开销。
@torch.no_grad()
# idx 是提示词 token,形状是 [B, T]。
def generate(self, idx: torch.Tensor, max_new_tokens: int, temperature: float = 1.0, top_k: int | None = None):
    # 每轮循环生成 1 个新 token。
    for _ in range(max_new_tokens):
        # 上下文超过 block_size 时,只保留最后 block_size 个 token。
        idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
        # forward 在 targets=None 时只计算最后一个位置的 logits。
        logits, _ = self(idx_cond)
        # 取最后一个位置的词表 logits,并用 temperature 控制分布尖锐程度。
        logits = logits[:, -1, :] / temperature
        # top_k 不为空时,只保留概率最高的 k 个 token 候选。
        if top_k is not None:
            # v[:, [-1]] 是第 k 大 logits,用作截断阈值。
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            # 阈值以下设为 -inf,softmax 后概率为 0。
            logits[logits < v[:, [-1]]] = -float("Inf")
        # softmax 把 logits 转成概率分布。
        probs = F.softmax(logits, dim=-1)
        # multinomial 按概率采样 1 个 token;temperature=0 通常要改成 argmax。
        idx_next = torch.multinomial(probs, num_samples=1)
        # 把新 token 拼到序列末尾,下一轮作为上下文继续生成。
        idx = torch.cat((idx, idx_next), dim=1)
    # 返回包含原 prompt 和新增 token 的完整序列。
    return idx
checkpoint

checkpoint 保存的是一组可恢复训练状态:模型参数、优化器状态、模型结构参数、当前 step、最优验证 loss 和训练配置。nanoGPT 还处理了 torch.compile 可能引入的 _orig_mod. 前缀。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# 保存训练状态,通常只在 master_process 上执行。
def save_training_checkpoint(
    raw_model: nn.Module,
    optimizer: torch.optim.Optimizer,
    model_args: dict,
    iter_num: int,
    best_val_loss: float,
) -> None:
    # checkpoint 是一个普通 dict,torch.save 会序列化其中的张量和标量。
    checkpoint = {
        # raw_model.state_dict() 保存模型参数和 buffer。
        "model": raw_model.state_dict(),
        # optimizer.state_dict() 保存 AdamW 的动量、方差等状态。
        "optimizer": optimizer.state_dict(),
        # model_args 保存结构参数,恢复时先重建同构模型。
        "model_args": model_args,
        # iter_num 让恢复训练从正确 step 继续。
        "iter_num": iter_num,
        # best_val_loss 让 early stopping 或 best checkpoint 逻辑不断档。
        "best_val_loss": best_val_loss,
        # config 保存完整训练配置,便于复现实验。
        "config": config,
    }
    # ckpt.pt 是 nanoGPT 默认的单文件训练状态。
    ckpt_path = os.path.join(out_dir, "ckpt.pt")
    # torch.save 写出 Python 对象和张量,适合 PyTorch 内部恢复。
    torch.save(checkpoint, ckpt_path)
 
 
# 恢复训练时,先读 checkpoint,再按保存的结构参数重建模型。
def load_training_checkpoint(out_dir: str, device: str):
    # map_location 确保 checkpoint 能加载到当前机器的目标设备。
    checkpoint = torch.load(os.path.join(out_dir, "ckpt.pt"), map_location=device)
    # checkpoint 里的 model_args 是恢复模型结构的事实来源。
    gptconf = GPTConfig(**checkpoint["model_args"])
    # 先创建同构模型,再加载权重。
    model = GPT(gptconf)
    # state_dict 是保存下来的模型参数表。
    state_dict = checkpoint["model"]
    # torch.compile 包装后保存的权重名有时会带 _orig_mod. 前缀。
    unwanted_prefix = "_orig_mod."
    # list(...) 避免遍历时修改 dict 导致迭代器失效。
    for key, value in list(state_dict.items()):
        # 只处理确实带该前缀的权重名。
        if key.startswith(unwanted_prefix):
            # 去掉前缀后再写回,匹配未 compile 的模型参数名。
            state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)
    # strict load 确保参数名和形状全部对齐。
    model.load_state_dict(state_dict)
    # optimizer 恢复需要先创建同类型 optimizer,再 load_state_dict。
    optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
    # 读回 AdamW 的动量与方差,否则恢复后相当于换了优化状态。
    optimizer.load_state_dict(checkpoint["optimizer"])
    # 返回模型、优化器和训练游标。
    return model, optimizer, checkpoint["iter_num"], checkpoint["best_val_loss"]
从 minGPT 到 nanoGPT

minGPT 更像“把 GPT 结构写清楚”的教学版;nanoGPT 更像“保留最小代码量的真实训练脚本”。两者差异体现了从模型理解到训练工程的迁移。

维度 minGPT nanoGPT
主要目标 教学化展示 GPT 模型结构、Trainer 和简单任务 用很少文件跑通 decoder-only LLM 预训练/微调/采样
模型代码 结构更直接,适合第一次读 attention、block、generate 加入 Flash Attention 路径、bias 开关、权重共享、GPT-2 权重导入
训练数据 常见小数据集或任务封装,更强调 API 清晰 直接读取二进制 token memmap,更贴近大语料预训练
训练工程 Trainer 抽象较轻,适合理解基本循环 包含 DDP、AMP、梯度累积、warmup+cosine、checkpoint、torch.compile
适合读者 刚开始理解 Transformer/GPT 结构的人 已经理解结构,想掌握 LLM 训练脚本闭环的人

精读顺序建议从 minGPT 的模型结构开始,确认 QKV、mask、block、loss 的形状;随后读 nanoGPT,把注意力放在训练脚本的状态管理:配置、数据切片、AMP、梯度累积、学习率、checkpoint 和采样。这样能把“模型怎么计算”和“训练怎么长期稳定运行”连成一条线。

Claude Code

Claude Code 的官方文档、Anthropic 工程文章、事故报道和多方公开逆向分析共同指向同一个结论:工业级 AI 编程 agent 的核心不只在模型,还在模型外部的执行控制面。真正值得学习的是工具契约、权限链、上下文预算、流式执行、任务拆分、文件编辑约束、错误恢复、沙箱和可观测性。本节只抽取适合自研 AI 编程系统借鉴的工程模式,不提供可还原 proprietary 实现的细节。

阅读边界:读设计,不复刻源码

Claude Code 相关材料可以分成四类:官方文档公开的产品机制,Anthropic 工程与产品文章,媒体和安全社区对 sourcemap 外溢事件的报道,社区对泄漏材料的逆向分析。本文把本地整理稿只作为线索,不作为唯一证据。官方文档适合确认当前产品能力;工程文章适合确认沙箱、checkpoint、SDK 等正式路线;事故报道适合确认泄漏边界;社区逆向分析适合观察工业实现的取舍。

这个边界对代码精读很重要。AI 编程 agent 的可迁移价值主要来自控制面设计,而非某段具体业务源码。本文采用“模式抽象”的写法:保留系统设计、数据流、约束点和失败恢复思路,删除专有字符串、内部代号、完整 prompt、真实文件名与可还原实现。可以用于自研系统的是方法:如何拆 prompt,如何保证工具可验证,如何在长会话中保住上下文,如何在后台任务里隔离权限。

观察对象 可学习内容 落到自研系统的形式
官方 Claude Code 文档 hooks、permissions、memory、MCP、subagents、settings 层级 产品层 API 与配置模型
Anthropic 工程与产品文章 沙箱、checkpoint、IDE 集成、Agent SDK、自治任务边界 正式产品路线与安全边界
媒体和安全社区报道 sourcemap 外溢、npm 包版本、是否涉及客户数据、供应链教训 事故边界与发布安全 checklist
社区逆向分析 prompt cache、工具池排序、compaction、AsyncGenerator loop、权限链、feature flag 线索 工程架构与运行时策略
本地整理稿 模块划分、失败注释、上下文压缩细节、可疑功能名 只作为待复核线索,不能单独定论
外部复核:事实、强证据与线索

重新检索公开网络材料后,需要把“确定事实”和“社区热议说法”分开。Claude Code 的官方文档已经足够确认 permissions、memory、hooks、subagents、MCP 等控制面;事故报道和安全分析支持 sourcemap 外溢这个事件本身;大量更戏剧化的说法,例如隐藏代号、内部 roadmap、具体 feature flag 数量、Bash 安全检查行数或某个内部功能名,只能作为社区逆向线索处理。

可信度 可写入正文的内容 写作处理
官方已证实 Claude Code 有细粒度权限规则;deny、ask、allow 有明确优先级;CLAUDE.md 和 auto memory 会进入会话上下文;hooks 可在生命周期事件上执行;subagent 有独立上下文、工具限制和权限;prompt caching 按工具、system、messages 的前缀缓存工作。 作为架构基线写入正文,可直接抽象成自研设计。
多源交叉支持 2026-03-31 前后,Claude Code npm 包的 sourcemap / mapping artifact 被公开报道可用于重构相当规模的 TypeScript 源码;公开材料普遍把它归类为发布制品外溢,而非模型权重、客户仓库或云端凭证泄露。 作为发布安全案例写入,但避免复述泄漏源码细节。
社区逆向线索 具体代码行数、文件数、feature flag 数量、内部代号、反蒸馏策略、Undercover Mode、KAIROS、Bash 检查数量等说法在不同文章中口径不完全一致。 不作为确定事实;只在“可能的设计线索”层面讨论。
本地整理稿内容 更细的模块命名、失败注释、内部流程和具体实现细节。 除非能被官方文档或多方公开材料支撑,否则不进入正文定论。
复核后的架构图:模型之外的控制面

从官方文档和公开分析交叉看,Claude Code 的可迁移架构可以拆成六个平面。模型推理只是中间一层;上下文、工具、权限、记忆、扩展和发布安全共同决定系统是否可靠。

平面 核心问题 Claude Code 给出的工程启发
上下文平面 哪些信息进入模型,哪些信息压缩,哪些信息沉淀为记忆。 稳定前缀、动态后缀、CLAUDE.md、auto memory、subagent 独立上下文共同服务 token 预算和缓存命中。
工具平面 模型如何把意图变成可执行动作。 工具需要 schema、权限检查、执行函数、结果预算和审计事件;文件编辑和 Bash 需要比普通工具更强的约束。
权限平面 哪些动作可以自动执行,哪些必须询问,哪些永远阻断。 deny 优先、ask 次之、allow 最后;高风险模式只能放在容器、虚拟机或受控沙箱里。
扩展平面 团队如何把内部工具、安全规则和工作流接进 agent。 MCP、hooks、skills、plugins、custom subagents 分别改变工具来源、生命周期控制、任务知识和执行角色。
运行时平面 长任务如何流式输出、恢复失败、并行执行和回滚。 事件流、checkpoint、工具批处理、上下文 compaction、PromptTooLong 恢复链是核心结构。
发布安全平面 闭源 CLI 如何避免把调试制品、内部字符串或 roadmap 一起发布。 sourcemap 策略、sourcesContent 裁剪、敏感字符串扫描、feature flag 裁剪和制品审计必须进入 CI。
核心形态:一个受约束的 agent loop

AI 编程 agent 的主循环可以抽象为四步:收集上下文、调用模型、执行工具、把结果写回上下文。工业实现的复杂度主要来自“每一步都必须有边界”:上下文不能无限增长,工具不能无条件执行,写文件必须可校验,失败必须可恢复。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
async def run_agent(session: SessionState, user_input: str) -> None:
    # 用户消息先进入会话状态;后续 prompt 装配会从这个状态读取。
    session.messages.append({"role": "user", "content": user_input})
 
    while True:
        # build_prompt 会装配系统规则、项目约束、工具 schema 和压缩摘要,而非只拼接字符串。
        prompt = build_prompt(session)
 
        # call_model 返回的可能是自然语言,也可能包含一个或多个工具调用。
        response = await call_model(prompt)
        session.messages.append(response.message)
 
        if not response.tool_calls:
            # 没有工具调用时,当前轮结束,输出可以交给 UI 渲染。
            return
 
        # 工具调用先进入权限与并发调度层,不直接执行。
        batches = plan_tool_batches(response.tool_calls, session.permissions)
        for batch in batches:
            results = await execute_checked_tools(batch, session)
 
            # 工具结果进入上下文前先做预算控制,避免一次 grep 或 cat 塞爆上下文。
            session.messages.extend(apply_result_budget(results))
 
        # 每轮工具执行后检查上下文预算,必要时做 micro compact 或 auto compact。
        session = compact_if_needed(session)

这段抽象代码的重点是职责分离。模型负责提出下一步动作,工具层负责把动作落到真实系统,权限层负责阻止越界,压缩层负责维持上下文可用。把这些职责混在一个巨大 prompt 里,系统很快会变成不可调试的黑箱。

Prompt 装配:稳定前缀与动态后缀

Claude Code 类产品的 prompt 往往远大于用户输入。固定系统规则、工具说明、代码风格、权限提示和已知失败模式约束都会进入请求。Anthropic prompt caching 文档明确说明,缓存对象是请求前缀,顺序覆盖 tools、system、messages,直到 cache breakpoint。因此,agent harness 的 prompt 装配应拆成稳定前缀和动态后缀:稳定前缀尽量保持逐字节一致,动态后缀承载当前目录、会话状态、最近文件、CLAUDE.md、临时规则和用户输入。

这种拆分同时服务三件事:成本控制、可维护性和行为稳定。稳定前缀越稳定,prompt cache 命中率越高;动态段越靠后,某次会话的时间戳、工作目录或临时文件列表就越不会破坏前面的缓存。官方文档还给出两个重要工程参数:默认缓存寿命是 5 分钟,另有 1 小时缓存选项;这意味着长任务循环要尽量保持热路径稳定,避免频繁重排工具和 system block。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def build_system_prompt(config: PromptConfig, session: SessionState) -> list[PromptPart]:
    # 固定规则放在最前面,版本不变时可以跨会话复用缓存。
    stable_parts = [
        load_core_instructions(config.version),
        load_code_editing_rules(config.version),
        load_tool_usage_rules(config.version),
    ]
 
    # 工具描述必须按确定性顺序注入。
    # 同一组工具如果顺序漂移,prompt cache key 会被无意义地打碎。
    stable_parts.extend(render_tools(sorted(session.visible_tools, key=lambda t: t.name)))
 
    # 显式边界让后续维护者知道:边界之后的内容会随会话变化。
    boundary = PromptPart(name="dynamic_boundary", content="--- dynamic session state ---")
 
    # 动态段承载当前会话状态;它应该尽量短,并且永远放在稳定段之后。
    dynamic_parts = [
        render_runtime_environment(session.env),
        render_project_memory(session.loaded_memory),
        render_recent_files(session.recent_files),
        render_user_message(session.pending_user_input),
    ]
 
    return stable_parts + [boundary] + dynamic_parts

这里的关键约束是“动态信息后移”。工具列表排序、MCP 工具分组、模式化 prompt 的注入位置,都应围绕缓存边界设计。一个看似普通的排序函数,在线上规模下会直接影响 API 输入成本。

查询引擎:用 AsyncGenerator 表达流式控制流

AI 编程 agent 的前端需要实时显示模型输出、工具调用、权限等待、测试日志和中断状态。普通的同步函数很难描述这种流式生命周期。社区逆向分析中反复提到的一个亮点,是把主查询过程写成可迭代事件流:模型 token、工具开始、工具完成、权限阻塞、压缩重试、最终回答都变成事件。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
async def query_stream(session: SessionState, prompt: str):
    # QueryConfig 在入口处快照,避免长循环中途读取到变化的全局配置。
    config = snapshot_query_config(session)
 
    # 先产出用户可见的开始事件,UI 层可以立即进入 loading 状态。
    yield AgentEvent(type="start", session_id=session.id)
 
    while True:
        try:
            # 模型响应采用流式读取;上层可以边收 token 边渲染。
            async for delta in call_model_stream(config, session.messages):
                if delta.type == "text":
                    yield AgentEvent(type="assistant_delta", text=delta.text)
 
                if delta.type == "tool_use":
                    # 工具调用先进入执行器,执行器再决定并行、串行或等待权限。
                    async for event in execute_tool_stream(delta.tool_call, session):
                        yield event
 
            yield AgentEvent(type="done")
            return
 
        except PromptTooLongError:
            # prompt 超限是可恢复错误,先压缩上下文再重试当前轮。
            session = await reactive_compact(session)
            yield AgentEvent(type="retry", reason="reactive_compact")
 
        except MaxOutputTokensError:
            # 输出被截断时不应直接丢弃任务;可以追加 resume 指令继续生成。
            session.messages.append(make_resume_message())
            yield AgentEvent(type="retry", reason="resume_after_output_limit")

这种写法的优点是背压清晰。终端 UI、日志系统、权限弹窗、工具执行器都消费同一条事件流;中断、重试和恢复也能作为事件明确表达。对自研系统而言,AsyncGenerator 比回调嵌套更容易测试,也更适合把长任务拆成可观察的阶段。

工具契约:用 schema 限制模型的错误空间

Claude Code 类系统通常把工具定义成“模型可见说明 + 输入 schema + 权限检查 + 执行函数”的组合。这里最有价值的设计,是在工具入口用结构化参数约束模型,压缩自由生成命令文本带来的错误空间。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
@dataclass
class Tool:
    name: str
    description: str
    schema: dict
    read_only: bool
 
    async def check_permissions(self, args: dict, context: ToolContext) -> PermissionResult:
        raise NotImplementedError
 
    async def call(self, args: dict, context: ToolContext) -> ToolResult:
        raise NotImplementedError
 
 
class FileEditTool(Tool):
    name = "file_edit"
    read_only = False
    schema = {
        "file_path": "absolute path",
        "old_str": "exact unique text to replace",
        "new_str": "replacement text",
    }
 
    async def call(self, args: dict, context: ToolContext) -> ToolResult:
        path = normalize_path(args["file_path"])
        old = args["old_str"]
        new = args["new_str"]
 
        # 精确唯一匹配是文件编辑工具的关键约束。
        # 模型必须证明自己知道要替换的原文,不能只给一个模糊位置。
        content = await read_text(path)
        if content.count(old) != 1:
            return ToolResult.error("old_str must match exactly once")
 
        # 真正写入前可以再做权限、路径、格式化或 diff 预览。
        await write_text(path, content.replace(old, new))
        return ToolResult.ok("file updated")

old_str 的精确唯一匹配非常关键。它把“模型觉得自己知道位置”变成“工具可以验证模型知道位置”。这个设计直接降低误改文件的概率,也让失败可以自然反馈给模型重新定位。

工具池:确定性排序、延迟加载与 MCP 分区

工具越多,prompt 越长,模型的选择空间也越大。Anthropic 的 Tool Search 文档把这个问题抽象成“动态发现工具”:系统先让模型搜索工具目录,再按需加载具体 schema,减少一次性展开所有工具定义的上下文成本。Claude Code 与 MCP 场景的公开分析也反复指向同一工程目标:内建工具、MCP 工具和低频扩展工具应分区排序和延迟加载,避免外部工具增删扰动核心工具的缓存前缀。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def assemble_tool_pool(
    builtins: list[Tool],
    mcp_tools: list[Tool],
    permission: PermissionContext,
) -> list[Tool]:
    # deny 规则命中的工具在装配阶段直接移除。
    # 模型看不到被禁止工具,就不会尝试绕过或反复请求。
    allowed_builtins = [t for t in builtins if not permission.denies(t.name)]
    allowed_mcp = [t for t in mcp_tools if not permission.denies(t.name)]
 
    # 内建工具先排序,形成稳定的核心前缀。
    # MCP 工具放在后面,外部扩展变化时只影响后缀。
    return sorted(allowed_builtins, key=lambda t: t.name) + sorted(allowed_mcp, key=lambda t: t.name)
 
 
def visible_tools_for_prompt(pool: list[Tool], task: str) -> list[Tool]:
    always_load = [t for t in pool if t.always_visible]
    deferred = [t for t in pool if not t.always_visible]
 
    # ToolSearch 是一个元工具:模型先描述自己需要什么能力,再展开匹配工具。
    search_tool = make_tool_search_index(deferred)
 
    # 初始 prompt 只暴露常用工具和搜索入口,减少 token 与选择噪声。
    return always_load + [search_tool]

延迟加载工具有两个收益。第一,低频工具不占初始 prompt;第二,工具集变化对缓存前缀的影响变小。对于接入大量内部 MCP 服务的团队,这个设计比单纯增加上下文窗口更实际。

权限链:静态规则、动态判断与人工确认

代码 agent 的危险动作集中在 Bash、文件写入、网络访问、MCP 外部工具和 Git 操作。Claude Code 官方权限文档给出的核心规则是 deny、ask、allow 分层决策,deny 规则优先级最高。工业系统应把权限做成运行时控制链,避免把安全边界完全交给 prompt 自律。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def check_permission(call: ToolCall, context: PermissionContext) -> PermissionDecision:
    # 第一层:静态规则最快,适合处理明确 allow / deny / ask 的场景。
    static = match_static_policy(call, context.policy)
    if static.is_final:
        return static
 
    # 第二层:把工具调用压缩成紧凑表示,再让分类器判断风险。
    # 输入里不应包含助手长篇解释,避免模型自我合理化污染风险判断。
    risk = classify_tool_risk(
        user_intent=context.last_user_message,
        tool_name=call.name,
        tool_args=project_safe_args(call.args),
    )
    if risk.should_block:
        return PermissionDecision.deny(risk.reason)
 
    if risk.requires_confirmation:
        # 第三层:有真实影响面的动作交给用户或外部审批系统确认。
        return ask_human(call, reason=risk.reason)
 
    return PermissionDecision.allow()

权限链的工程目标是把“是否执行”从模型输出里剥离出来。模型可以建议执行,系统负责判断能否执行。对 Bash 这类影响面极大的工具,静态规则、风险分类、交互确认和审计日志都应同时存在。

Sandboxing:把权限提示升级成执行边界

权限系统解决的是“能不能执行这个动作”,沙箱解决的是“即使执行了,动作能影响到哪里”。Claude Code 官方 sandboxing 文档和 Anthropic 工程文章都强调两个边界:filesystem isolation 和 network isolation。文件系统隔离缺少网络隔离时,恶意命令仍可能把敏感文件发出去;网络隔离缺少文件系统隔离时,恶意命令仍可能读取或破坏宿主机文件。企业级 AI 编程系统应把沙箱当成默认执行边界,减少对逐条 Bash 人工确认的依赖。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def choose_execution_boundary(call: ToolCall, context: RuntimeContext) -> ExecutionBoundary:
    # 只读文件工具通常不需要进入完整沙箱,但仍要走路径权限检查。
    if call.name in {"read_file", "grep", "glob"}:
        return ExecutionBoundary.host_readonly()
 
    # Bash 是最高风险工具,默认进入受控沙箱。
    if call.name == "bash":
        risk = classify_shell_command(call.args["command"])
 
        # 会访问网络或写文件的命令需要同时约束文件系统和网络。
        if risk.writes_files or risk.uses_network:
            return ExecutionBoundary.sandbox(
                # workspace 只暴露当前项目目录,避免读取用户 home 下的密钥。
                filesystem_root=context.workspace_root,
                # 网络默认关闭,只对白名单域名或内部代理开放。
                network_policy=context.network_policy,
                # 环境变量按 allowlist 传入,避免泄露 token。
                env=context.safe_environment,
            )
 
        # 纯只读命令也可以进入轻量沙箱,减少权限弹窗。
        return ExecutionBoundary.sandbox_readonly(context.workspace_root)
 
    # MCP 外部工具按服务级别决定边界;远程工具尤其需要审计。
    return ExecutionBoundary.external_service(policy=context.mcp_policy)

这个抽象的关键点是:permission prompt 只是一层交互确认。用户批准一次危险命令后,真正限制损害半径的是沙箱、网络策略、凭证裁剪和审计。自研系统如果支持 unattended mode 或后台 agent,沙箱优先级应高于“跳过权限提示”。

Checkpoint:把大改动变成可回滚事务

Claude Code 官方 checkpointing 文档说明,系统会跟踪文件编辑工具产生的修改,每个用户 prompt 形成新的 checkpoint,checkpoint 可跨恢复会话保存并按配置清理。这个能力对 AI 编程系统很关键:模型可以一次修改多个文件,用户也需要在错误方向上快速回滚,避免手工从 Git diff 里一点点撤销。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
async def run_turn_with_checkpoint(session: SessionState, user_input: str) -> None:
    # 每个用户回合开始前创建 checkpoint,记录当前文件状态和会话元数据。
    checkpoint = await session.checkpoints.create(
        # prompt_id 用来把文件快照和用户意图绑定起来。
        prompt_id=session.next_prompt_id(),
        # 只追踪 agent 可能修改的 workspace,避免扫描整个磁盘。
        root=session.workspace_root,
    )
 
    try:
        # agent 可能执行多次 Edit、Write、Bash 和测试命令。
        await run_agent_turn(session, user_input)
 
        # 回合成功后把 checkpoint 标记为可回滚保存点。
        await checkpoint.mark_success()
 
    except Exception as exc:
        # 失败时不立即自动回滚,先把失败原因和 diff 交给用户或上层策略。
        await checkpoint.mark_failed(reason=str(exc))
        raise
 
 
async def rewind_to_checkpoint(session: SessionState, checkpoint_id: str) -> None:
    # restore 前先展示 diff,避免用户不知道会撤销哪些文件。
    diff = await session.checkpoints.diff(checkpoint_id)
    await present_rewind_diff(diff)
    # 用户确认后恢复文件系统状态和相关会话状态。
    await session.checkpoints.restore(checkpoint_id)

Checkpoint 和 Git 的职责不同。Git 负责长期版本控制和协作历史;checkpoint 负责 agent 会话内部的细粒度撤销。对于探索性重构、批量格式化、自动修复测试这类任务,checkpoint 是让模型敢于行动、用户敢于授权的基础设施。

Hooks:把软提示变成硬控制

Hooks 是 Claude Code 设计里最适合企业落地的一层。Prompt 和 CLAUDE.md 都属于模型可见指令,模型可能因为上下文拥挤或任务压力而漏遵守;hook 是运行时控制点,可以在工具执行前阻断、在工具执行后修正、在会话结束时归档,行为更接近 Git hooks 或 CI gate。官方 hooks 文档把事件分成 session、turn、tool-call 等节奏,包含 SessionStart、UserPromptSubmit、PreToolUse、PostToolUse、PostToolBatch、Notification、SubagentStart、Stop、StopFailure 等事件。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
async def run_pre_tool_hooks(call: ToolCall, context: HookContext) -> HookDecision:
    for hook in context.hooks.for_event("PreToolUse"):
        # hook 输入只包含结构化工具信息,避免把整段对话暴露给外部脚本。
        payload = {
            "tool_name": call.name,
            "tool_input": redact_sensitive_fields(call.args),
            "cwd": context.cwd,
        }
 
        result = await hook.invoke(payload)
 
        if result.action == "block":
            # 安全 hook 可以直接阻断危险命令,例如删除根目录、泄露密钥或绕过测试。
            return HookDecision.block(reason=result.reason)
 
        if result.action == "rewrite_input":
            # rewrite 适合做路径规范化、补充默认参数或改写 MCP 工具输入。
            call.args = result.new_input
 
    return HookDecision.allow(call)

Hooks 的实用模式包括:Bash 执行前做命令白名单,Edit 之后自动格式化,Stop 阶段运行测试摘要,Notification 阶段推送桌面提醒,SessionStart 阶段注入环境说明。它们把“希望模型遵守”改成“系统一定执行”。

上下文管理:预算、压缩与长期记忆

长上下文并不同于可无限堆历史。AI 编程场景里,工具结果是上下文膨胀的主要来源:读文件、grep、运行测试、打印日志都会产生大量内容。Claude Code 类系统的经验是分级处理:先做低成本清理,再做结构化摘要,最后把跨会话信息沉淀到独立记忆系统。

层级 处理对象 工程意义
工具结果预算 单次 Bash / grep / read 的长输出 保留首尾和摘要,把完整内容落盘或省略,避免单个工具结果占满上下文
Micro Compact 已经被模型消化过的旧工具结果 用占位符替换低价值原文,尽量保持 prompt 前缀稳定
Auto Compact 接近上下文上限的完整会话 调用模型生成结构化摘要,保留任务、决策、文件状态和待办事项
长期记忆 跨会话仍有价值的项目规范、用户偏好、架构约束 从上下文窗口里移出,按需召回,避免每轮都把历史全塞回 prompt
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def compact_if_needed(session: SessionState) -> SessionState:
    budget = estimate_tokens(session.messages)
 
    if budget < session.soft_limit:
        return session
 
    # 第一阶段只清理低价值工具原文,不改变关键用户意图和任务状态。
    session.messages = replace_old_tool_outputs(session.messages)
    if estimate_tokens(session.messages) < session.soft_limit:
        return session
 
    # 第二阶段才调用模型做结构化摘要,因为这一步成本更高,也可能失败。
    summary = summarize_session(
        messages=session.messages,
        required_sections=[
            "current_goal",
            "files_touched",
            "decisions_made",
            "pending_tasks",
            "known_failures",
        ],
    )
 
    # 摘要替换旧上下文后,还要把最近正在编辑的文件重新读回,避免模型失去局部细节。
    return rebuild_session_from_summary(summary, recent_files=session.recent_files)

上下文压缩要保留“继续工作的充分条件”。当前目标、已改文件、失败原因、未完成任务、用户约束和最近文件内容,比完整历史对话更重要。

错误恢复:PromptTooLong、输出截断与熔断

长会话里最常见的失败来自运行时预算耗尽:prompt 太长、输出 token 不够、工具结果过大、自动压缩失败、权限请求循环。工业级 agent 要把这些情况当成正常分支处理,避免把恢复成本转嫁给用户重开会话。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
async def call_with_recovery(session: SessionState) -> ModelResponse:
    attempts = 0
 
    while attempts < 3:
        attempts += 1
 
        try:
            return await call_model(session.messages)
 
        except PromptTooLongError:
            # 第一次先做轻量裁剪,尽量保住 prompt cache 和最近上下文。
            session = microcompact_messages(session)
 
            if still_too_large(session):
                # 第二层再做结构化摘要;摘要失败时必须有硬上限,防止无限递归。
                session = await auto_compact_with_timeout(session, seconds=30)
 
        except MaxOutputTokensError:
            # 输出截断后追加 resume 指令,保留已有回答并继续完成任务。
            session.messages.append({
                "role": "user",
                "content": "Continue from the exact point where the previous answer stopped.",
            })
 
    raise AgentRuntimeError("model call failed after bounded recovery attempts")

这里的硬上限很重要。自动压缩本身也会调用模型;如果压缩失败后继续无界重试,就会形成成本和延迟的放大器。可恢复错误需要恢复链,可恢复链也需要熔断器。

记忆系统:CLAUDE.md、规则文件与自动记忆

Claude Code 的记忆设计可以按生命周期拆开:当前会话上下文负责短期状态,CLAUDE.md 和规则文件负责团队显式知识,auto memory 负责模型从用户纠正中沉淀出的个人或项目经验。官方 memory 文档明确区分 CLAUDE.md 和 auto memory:前者由用户维护,适合编码规范、工作流和项目架构;后者由 Claude 根据纠正和偏好自动积累。三者不要混用。构建自研系统时,最容易犯的错误是把所有东西都塞进一条 system prompt,导致规则难维护、缓存难命中、过期信息难清理。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def load_memory_for_session(project: ProjectState, request: UserRequest) -> list[MemoryBlock]:
    blocks: list[MemoryBlock] = []
 
    # 团队显式规则优先加载;它们通常来自仓库里的 CLAUDE.md 或类似文件。
    blocks.extend(read_markdown_rules(project.root / "CLAUDE.md"))
 
    # 路径规则只在相关文件命中时加载,适合 monorepo。
    for rule in project.path_rules:
        if rule.matches(request.touched_paths):
            blocks.append(read_markdown(rule.path))
 
    # 自动记忆先读轻量索引,再选择少量相关正文。
    # 这样比每次注入全部历史记忆更可控。
    candidates = read_memory_headers(project.auto_memory_dir)
    selected = select_relevant_memories(candidates, request.goal, limit=5)
    blocks.extend(read_memory_body(item.path) for item in selected)
 
    return blocks

自动记忆要有类型约束和过期提示。适合写入的内容包括用户偏好、调试结论、项目约定、反复出现的坑;不适合写入的内容包括一次性日志、临时 token、错误猜测和敏感数据。长期记忆进入下一次会话时,应标注来源和时间,避免模型把旧经验当成当前事实。

并发执行:读可以并行,写必须保守

代码 agent 需要并发,否则读多个文件、跑多个只读搜索会很慢。并发策略的关键是按工具调用的实际输入判断安全性,避免给工具贴一个脱离输入的永久标签。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def partition_tool_calls(calls: list[ToolCall]) -> list[list[ToolCall]]:
    batches: list[list[ToolCall]] = []
    current_read_batch: list[ToolCall] = []
 
    for call in calls:
        if is_concurrency_safe(call):
            # 只读文件读取、glob、grep 通常可以进入同一批并行执行。
            current_read_batch.append(call)
            continue
 
        if current_read_batch:
            batches.append(current_read_batch)
            current_read_batch = []
 
        # 写文件、运行可能改状态的 shell 命令必须独占一个 batch。
        batches.append([call])
 
    if current_read_batch:
        batches.append(current_read_batch)
    return batches
 
 
def is_concurrency_safe(call: ToolCall) -> bool:
    if call.name in {"read_file", "glob", "grep"}:
        return True
    if call.name == "bash":
        # 同一个 Bash 工具要根据具体命令判断。
        # ls、pwd、git status 可并行;rm、mv、npm install、git commit 不应并行。
        return classify_shell_command(call.args["command"]).is_read_only
    return False

这个设计让系统既能利用并行,又不会把写操作打乱。更细的实现还会把 context modifier 延迟到所有并发结果结束后按原始顺序应用,避免并发工具同时修改会话状态。

Subagent:隔离上下文与隔离文件系统

复杂代码任务常常需要拆分:一个 agent 负责探索,一个 agent 负责实现,一个 agent 负责验证。Claude Code 官方 subagents 文档强调,每个 subagent 有自己的上下文窗口、custom system prompt、特定工具访问权限和独立权限。Subagent 的工程价值在于隔离上下文、工具权限、任务状态和文件修改范围,使并行协作不会互相污染。

模式 上下文关系 适合任务
fork 复制父上下文,适合复用 prompt cache 后台摘要、记忆提取、短探索任务
teammate 独立上下文,只返回结论 并行读代码、独立调研、互不污染判断的分析任务
worktree 独立上下文 + 独立文件系统视图 多个 worker 并行改代码,最后由主 agent 或用户合并

自研系统可以先实现最小可用的 teammate 模式:主 agent 只派发具体问题,subagent 只返回结构化结论。等任务开始涉及并行改文件,再引入 Git worktree 级隔离。

Feature Flag 与发布安全

多篇社区逆向材料提到 Claude Code 中存在大量 feature flag。具体数量和内部名称属于中等可信线索,正文不把它们当成确定事实;但 feature flag 对 AI 编程工具的工程意义很明确。它不仅服务灰度发布,还承担构建裁剪、内部功能隔离、实验开关、后台 agent 模式控制和风险回滚。尤其是 CLI 产品,发布包里包含什么、source map 是否包含源文本、内部字符串是否被剥离,都会变成安全边界的一部分。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
def build_release_bundle(source: SourceTree, profile: BuildProfile) -> Bundle:
    # 编译期 feature gate 负责把内部能力从发布产物中物理移除。
    tree = strip_disabled_features(source, enabled=profile.public_features)
 
    # sourcemap 可以保留定位信息,但发布包不应内联完整 sourcesContent。
    maps = generate_source_maps(tree, include_source_content=False)
 
    # 发布前扫描内部代号、私有 URL、测试 token、员工专用配置等敏感字符串。
    violations = scan_for_forbidden_strings(tree, profile.forbidden_patterns)
    if violations:
        raise ReleaseBlocked(violations)
 
    # 最后再生成可发布制品,避免扫描对象和最终发布对象不一致。
    return package_for_distribution(tree, maps)

这类发布安全不属于“模型能力”,但直接决定 AI 工具是否能进入企业环境。自研 agent 只要包含内部工具、私有 MCP、实验模型代号或客户环境信息,就应把构建裁剪、敏感字符串扫描和 sourcemap 策略写进发布流水线。

对自研 AI 编程系统的启发
  • 文件编辑工具要强制可验证定位,例如精确字符串匹配、AST 范围、patch 预览,不能让模型直接自由写整段文件。
  • Bash 工具要有独立安全体系,包含命令解析、静态规则、动态风险判断、用户确认、审计和输出预算。
  • 上下文压缩要分级,先清工具结果,再做摘要,最后沉淀长期记忆。不要把所有历史长期塞进 prompt。
  • 工具列表、prompt 片段和 MCP 工具排序要稳定,因为稳定前缀直接影响 prompt cache 成本。
  • Subagent 的价值在于隔离和并行。主 agent 应该拿结构化结论,避免继承所有中间噪声。
  • Hooks 和权限系统要进入运行时控制面。企业场景不能只靠 prompt 自律来阻止危险动作。
  • 可恢复错误要有恢复链,恢复链也要有熔断器。PromptTooLong、输出截断和压缩失败都应成为显式分支。
  • 发布包要按安全产品处理:feature flag、sourcemap、敏感字符串扫描和内部工具剥离都应纳入 CI。
NER知名算法

这一节按官方源码精读三条 NER 主线: pytorch-crf 的 CRF 序列解码、Efficient GlobalPointer 的 span 矩阵打分、GLiNER 的 label-conditioned span classification。源码参照 torchcrf/__init__.py、 models/GlobalPointer.py、 gliner/modeling/span_rep.py、 gliner/modeling/scorers.py 和 gliner/decoding/decoder.py。阅读顺序是:先看模型输出张量是什么,再看 loss 如何定义,最后看 decode 如何把分数还原成实体。

CRF:torchcrf 的 log-likelihood 与 Viterbi

torchcrf.CRF 的源码核心很集中:参数包含起始转移、结束转移和标签间转移;训练时计算“真实路径分数 - 所有路径 logsumexp 归一化”;推理时用 Viterbi 动态规划找最高分标签路径。

\[\log p(\mathbf{y}\mid\mathbf{x})=\mathrm{score}(\mathbf{x},\mathbf{y})-\log\sum_{\mathbf{y}'}\exp(\mathrm{score}(\mathbf{x},\mathbf{y}'))\]

其中 \(\mathrm{score}(\mathbf{x},\mathbf{y})\) 是某一条标签路径的总分,包含每个 token 的 emission score、相邻标签的 transition score、路径起点分数和路径终点分数。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# CRF 是 nn.Module,所以转移矩阵会跟随模型一起训练和保存。
class CRF(nn.Module):
    # num_tags 是 BIO/BIOES 标签总数,batch_first 控制输入维度顺序。
    def __init__(self, num_tags: int, batch_first: bool = False) -> None:
        # 标签数必须大于 0,否则转移矩阵没有定义。
        if num_tags <= 0:
            # 这里提前失败,比后面张量 shape 出错更容易定位。
            raise ValueError(f"invalid number of tags: {num_tags}")
        # 初始化 nn.Module 基类,后面 Parameter 才会被注册。
        super().__init__()
        # 保存标签数,后续校验 emissions 最后一维必须等于它。
        self.num_tags = num_tags
        # 保存输入布局约定;Transformers 通常用 batch_first=True。
        self.batch_first = batch_first
        # start_transitions[tag] 表示序列第一个标签为 tag 的起始分数。
        self.start_transitions = nn.Parameter(torch.empty(num_tags))
        # end_transitions[tag] 表示序列最后一个标签为 tag 的结束分数。
        self.end_transitions = nn.Parameter(torch.empty(num_tags))
        # transitions[i, j] 表示从标签 i 转移到标签 j 的分数。
        self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))
        # 官方实现把三组转移参数初始化到 [-0.1, 0.1]。
        self.reset_parameters()
 
    # reset_parameters 只初始化 CRF 自己的转移参数。
    def reset_parameters(self) -> None:
        # 起始转移用小范围均匀分布,避免初始路径偏置过强。
        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
        # 结束转移同样小范围初始化。
        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
        # 标签间转移矩阵小范围初始化,训练中再学习合法转移偏好。
        nn.init.uniform_(self.transitions, -0.1, 0.1)

接入 BERT 时,BERT 负责产生每个 token 对每个标签的 emission score;CRF 负责把 token 级局部分数变成序列级路径分数。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# BERT + CRF 是序列标注里最常见的组合之一。
class BertCrfForNer(nn.Module):
    # model_name 是 Hugging Face encoder 名称,num_tags 是 BIO/BIOES 标签数。
    def __init__(self, model_name: str, num_tags: int) -> None:
        # 初始化 nn.Module 基类。
        super().__init__()
        # encoder 输出每个 token 的 contextual hidden state。
        self.encoder = AutoModel.from_pretrained(model_name)
        # hidden_size 决定分类头输入维度。
        hidden_size = self.encoder.config.hidden_size
        # classifier 把 [B, L, H] 映射成 [B, L, num_tags]。
        self.classifier = nn.Linear(hidden_size, num_tags)
        # batch_first=True 后,CRF 接收 [B, L, C],与 Transformers 输出一致。
        self.crf = CRF(num_tags=num_tags, batch_first=True)
 
    # labels 为空时走解码路径,不为空时走训练 loss 路径。
    def forward(self, input_ids, attention_mask, labels=None):
        # encoder_out.last_hidden_state 形状是 [B, L, H]。
        hidden = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        # emissions 是未归一化分数,不需要先 softmax。
        emissions = self.classifier(hidden)
        # mask 标识有效 token;padding 位置不会进入路径分数。
        mask = attention_mask.bool()
        # 训练路径:最大化真实标签路径的条件 log-likelihood。
        if labels is not None:
            # torchcrf.forward 返回 log-likelihood,训练要最小化负值。
            llh = self.crf(emissions=emissions, tags=labels, mask=mask, reduction="mean")
            # 返回 dict 便于接 Trainer 或自定义训练循环。
            return {"loss": -llh}
        # 推理路径:Viterbi 返回每个样本的最高分标签 id 序列。
        paths = self.crf.decode(emissions, mask=mask)
        # paths 是 Python list,长度会按 mask 去掉 padding。
        return {"predictions": paths}

torchcrf.forward 内部的核心分成 numerator 和 denominator。numerator 只沿真实标签路径加分;denominator 对所有可能路径做 log-sum-exp,等价于 CRF 的归一化常数。

Python
1
2
3
4
5
6
7
8
9
10
# emissions: [T, B, C],tags: [T, B],mask: [T, B]。
def crf_log_likelihood(emissions, tags, mask):
    # numerator 是真实标签路径的总分。
    numerator = compute_gold_path_score(emissions, tags, mask)
    # denominator 是所有可能标签路径的 logsumexp 总分。
    denominator = compute_all_path_logsumexp(emissions, mask)
    # 条件 log-likelihood 等于真实路径分数减去归一化项。
    llh = numerator - denominator
    # 训练时通常取 -llh.mean()。
    return llh

_compute_score 对真实标签路径逐步加分。这个函数只看 gold path,不枚举其它标签序列,因此它是 CRF 分子的来源。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# 这个函数对应 torchcrf.CRF._compute_score 的主体逻辑。
def compute_gold_path_score(emissions, tags, mask, start_transitions, transitions, end_transitions):
    # emissions: [T, B, C],T 是序列长度,B 是 batch,C 是标签数。
    seq_length, batch_size = tags.shape
    # mask 转成和 emissions 相同 dtype,后面可直接参与乘法。
    mask = mask.type_as(emissions)
    # 第一个 token 的路径分数由起始转移和第一个 emission 组成。
    score = start_transitions[tags[0]]
    # arange(batch_size) 让每个样本取自己的 gold tag emission。
    score = score + emissions[0, torch.arange(batch_size), tags[0]]
 
    # 从第二个 token 开始累加转移分数和 emission 分数。
    for i in range(1, seq_length):
        # tags[i - 1] -> tags[i] 是 gold path 的相邻标签转移。
        transition_score = transitions[tags[i - 1], tags[i]]
        # padding 位置 mask 为 0,不应影响真实路径分数。
        score = score + transition_score * mask[i]
        # 当前 token 的 gold tag emission 也只在有效 token 上累加。
        emission_score = emissions[i, torch.arange(batch_size), tags[i]]
        # 这里不做 softmax;CRF 全程在 score space 里计算。
        score = score + emission_score * mask[i]
 
    # 每个样本的有效长度可能不同,最后一个有效 token 要单独定位。
    seq_ends = mask.long().sum(dim=0) - 1
    # last_tags 是每个样本最后一个有效 token 的 gold tag。
    last_tags = tags[seq_ends, torch.arange(batch_size)]
    # 结束转移补上路径终点分数。
    score = score + end_transitions[last_tags]
    # 返回 [B],每个样本一个真实路径总分。
    return score

_compute_normalizer 是分母。它把所有可能路径的分数做 log-sum-exp,得到条件概率里的归一化常数。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# 这个函数对应 torchcrf.CRF._compute_normalizer 的主体逻辑。
def compute_all_path_logsumexp(emissions, mask, start_transitions, transitions, end_transitions):
    # T 是 token 数,B 是 batch,C 是标签数。
    seq_length = emissions.size(0)
    # score[b, c] 表示样本 b 在当前位置以标签 c 结尾的所有路径 logsumexp 分数。
    score = start_transitions + emissions[0]
 
    # 逐 token 向前推进动态规划。
    for i in range(1, seq_length):
        # [B, C, 1]:历史路径以哪个旧标签结尾。
        broadcast_score = score.unsqueeze(2)
        # [B, 1, C]:当前 token 取哪个新标签。
        broadcast_emissions = emissions[i].unsqueeze(1)
        # [B, C, C]:旧标签、新标签两两组合后的路径候选。
        next_score = broadcast_score + transitions + broadcast_emissions
        # 对旧标签维度做 logsumexp,相当于把所有来源路径累加。
        next_score = torch.logsumexp(next_score, dim=1)
        # padding token 不推进动态规划,沿用上一时刻 score。
        score = torch.where(mask[i].unsqueeze(1), next_score, score)
 
    # 所有路径都要补上结束转移。
    score = score + end_transitions
    # 对最后标签再做一次 logsumexp,得到每个样本的总归一化项。
    return torch.logsumexp(score, dim=1)

Viterbi 解码把 denominator 里的 log-sum-exp 换成 max,并保存每一步的 argmax 来源。最后从最高分结束标签向前回溯,得到全局最高分路径。这个全局路径约束是 CRF 相比逐 token softmax 的核心价值。

CRF 的创新空间主要在约束层。业务规则可以转成转移约束、非法标签惩罚、领域词典软约束或后处理检查;在标签体系稳定、误报代价高的 NER 场景里,这条路线仍然有工程价值。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# 这个函数对应 torchcrf.CRF._viterbi_decode 的核心递推。
def viterbi_decode_core(emissions, mask, start_transitions, transitions, end_transitions):
    # score[b, c] 表示样本 b 当前以标签 c 结尾的最佳路径分数。
    score = start_transitions + emissions[0]
    # history 保存每一步“最佳新标签来自哪个旧标签”。
    history = []
 
    # 逐 token 做 max-product 动态规划。
    for i in range(1, emissions.size(0)):
        # [B, C, 1]:旧标签维度。
        broadcast_score = score.unsqueeze(2)
        # [B, 1, C]:新标签 emission。
        broadcast_emission = emissions[i].unsqueeze(1)
        # [B, C, C]:从任意旧标签转到任意新标签的候选分数。
        next_score = broadcast_score + transitions + broadcast_emission
        # 对旧标签取 max,indices 记录最佳来源。
        next_score, indices = next_score.max(dim=1)
        # padding 位置不更新 score。
        score = torch.where(mask[i].unsqueeze(1), next_score, score)
        # indices 用于后面从最后标签反向回溯。
        history.append(indices)
 
    # 加结束转移后,score 最大的标签就是路径终点。
    score = score + end_transitions
    # 官方源码后续按每个样本的有效长度回溯 history。
    return score, history
Efficient GlobalPointer:span 矩阵与 RoPE

Efficient GlobalPointer 的源码把实体识别建模成 token-pair 打分。给定 encoder 输出 \([B,L,H]\),模型产出 \([B,C,L,L]\),其中 \(C\) 是实体类型数,矩阵位置 \((i,j)\) 表示从 token \(i\) 到 token \(j\) 是否构成该类实体。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# Efficient GlobalPointer 保留 encoder,替换掉 BIO 分类头。
class EffiGlobalPointer(nn.Module):
    # encoder 通常是 BERT/RoBERTa,ent_type_size 是实体类型数,inner_dim 是 span 打分维度。
    def __init__(self, encoder, ent_type_size: int, inner_dim: int, RoPE: bool = True):
        # 初始化 nn.Module 基类。
        super().__init__()
        # 保存文本 encoder,它负责输出 [B, L, H]。
        self.encoder = encoder
        # 实体类型数决定最终 logits 的 C 维。
        self.ent_type_size = ent_type_size
        # inner_dim 是 query/key 的内部维度。
        self.inner_dim = inner_dim
        # hidden_size 从 encoder config 读取,避免手工写死。
        self.hidden_size = encoder.config.hidden_size
        # RoPE 控制是否把位置信息旋转进 query/key。
        self.RoPE = RoPE
        # dense_1 只产生一组共享 qw/kw,参数量少于每个实体类型单独投影。
        self.dense_1 = nn.Linear(self.hidden_size, self.inner_dim * 2)
        # dense_2 给每个实体类型产生 start/end bias。
        self.dense_2 = nn.Linear(self.hidden_size, self.ent_type_size * 2)

Efficient 版本的关键优化是共享 token-pair 主打分,再用每个实体类型的 start/end bias 区分类型。这样比原始 GlobalPointer 的 \(C\) 组 query/key 投影更省参数。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# input_ids、attention_mask、token_type_ids 直接来自 tokenizer batch。
def forward(self, input_ids, attention_mask, token_type_ids):
    # 记录当前设备,位置编码和 mask 都要放在同一设备。
    self.device = input_ids.device
    # encoder 输出包含 last_hidden_state。
    context_outputs = self.encoder(input_ids, attention_mask, token_type_ids)
    # last_hidden_state 形状是 [B, L, H]。
    last_hidden_state = context_outputs.last_hidden_state
    # dense_1 输出 [B, L, 2D],交错拆成 qw 和 kw。
    outputs = self.dense_1(last_hidden_state)
    # qw 取偶数位,形状 [B, L, D]。
    qw = outputs[..., ::2]
    # kw 取奇数位,形状 [B, L, D]。
    kw = outputs[..., 1::2]
    # RoPE 把相对位置信息注入 query/key,使 span 边界对位置敏感。
    if self.RoPE:
        # pos 形状是 [B, L, D],由 sin/cos 位置编码生成。
        pos = SinusoidalPositionEmbedding(self.inner_dim, "zero")(outputs)
        # cos_pos 取 cos 通道并扩展到偶奇维对齐。
        cos_pos = pos[..., 1::2].repeat_interleave(2, dim=-1)
        # sin_pos 取 sin 通道并扩展到偶奇维对齐。
        sin_pos = pos[..., ::2].repeat_interleave(2, dim=-1)
        # qw2 是 qw 的旋转副本,对应二维旋转中的 [-y, x]。
        qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], 3)
        # reshape 回 [B, L, D],便于和 qw 逐元素组合。
        qw2 = torch.reshape(qw2, qw.shape)
        # RoPE 旋转后的 qw。
        qw = qw * cos_pos + qw2 * sin_pos
        # kw2 是 kw 的旋转副本。
        kw2 = torch.stack([-kw[..., 1::2], kw[..., ::2]], 3)
        # reshape 回 [B, L, D]。
        kw2 = torch.reshape(kw2, kw.shape)
        # RoPE 旋转后的 kw。
        kw = kw * cos_pos + kw2 * sin_pos
    # 主 token-pair 打分,输出 [B, L, L]。
    logits = torch.einsum("bmd,bnd->bmn", qw, kw)
    # 除以 sqrt(D) 控制点积方差,避免 logits 过大。
    logits = logits / self.inner_dim ** 0.5
    # dense_2 输出 [B, L, 2C],再转成 [B, 2C, L]。
    bias = torch.einsum("bnh->bhn", self.dense_2(last_hidden_state)) / 2
    # 偶数通道作为 start bias,奇数通道作为 end bias。
    logits = logits[:, None] + bias[:, ::2, None] + bias[:, 1::2, :, None]
    # padding 和下三角区域都要屏蔽掉。
    logits = self.add_mask_tril(logits, mask=attention_mask)
    # 返回 [B, C, L, L]。
    return logits

add_mask_tril 是 GlobalPointer 工程实现里最容易漏掉的部分。padding token 不能参与 span,且实体的终点不能早于起点;这两个约束都通过大负数 mask 直接压到 logits 上。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# logits: [B, C, L, L],attention_mask: [B, L]。
def add_mask_tril(logits, attention_mask):
    # mask dtype 要和 logits 一致,避免混合精度下隐式类型问题。
    if attention_mask.dtype != logits.dtype:
        # attention_mask 从 bool/int 转成 float/bfloat16。
        attention_mask = attention_mask.type(logits.dtype)
    # start 维度 mask:起点落在 padding 上的 span 直接无效。
    start_mask = attention_mask[:, None, :, None]
    # end 维度 mask:终点落在 padding 上的 span 直接无效。
    end_mask = attention_mask[:, None, None, :]
    # 两个维度任一无效,logit 减去极大值,sigmoid 后近似 0。
    logits = logits * start_mask * end_mask - (1 - start_mask * end_mask) * 1e12
    # 下三角表示 end < start,这类区间没有实体语义。
    lower_triangle = torch.tril(torch.ones_like(logits), diagonal=-1)
    # 将非法反向 span 的分数压到极小。
    logits = logits - lower_triangle * 1e12
    # 返回已经带结构约束的 span logits。
    return logits

GlobalPointer 的标签也是 \([B,C,L,L]\)。正例 span 置为 1,所有其它合法 span 是负例。训练常用多标签交叉熵;softmax 只适合互斥类别,而同一文本里可以同时存在多个实体、多个类型、甚至嵌套实体。它的创新空间主要在 span 层:实体长度先验、类别相关阈值、边界对比学习、span hard negative mining、嵌套实体优化,都可以直接落到 \([B,C,L,L]\) 分数矩阵上。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# y_pred 和 y_true 都会被展平成 [B*C, L*L]。
def multilabel_categorical_crossentropy(y_true, y_pred):
    # 正例位置乘 -1,负例位置乘 1,统一成“希望越小越好”的形式。
    y_pred = (1 - 2 * y_true) * y_pred
    # 正例位置从负类 logsumexp 里屏蔽掉。
    y_pred_neg = y_pred - y_true * 1e12
    # 负例位置从正类 logsumexp 里屏蔽掉。
    y_pred_pos = y_pred - (1 - y_true) * 1e12
    # 拼一个 0,保证没有正例或没有负例时 logsumexp 仍稳定。
    zeros = torch.zeros_like(y_pred[..., :1])
    # 负类候选集合追加稳定项。
    y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
    # 正类候选集合追加稳定项。
    y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
    # 对所有负类候选做 logsumexp。
    neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
    # 对所有正类候选做 logsumexp。
    pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
    # 正负两部分相加,再对 batch/entity 类型求平均。
    return (neg_loss + pos_loss).mean()
GlobalPointer 解码:从矩阵位置还原实体

解码阶段很直接:在 \([B,C,L,L]\) 中找大于阈值的位置,每个位置就是一个实体候选。工程上要把 token span 映射回原文字符区间,并处理特殊 token、子词边界和阈值。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# logits 是模型输出,threshold 通常从验证集调出来。
def decode_global_pointer(logits: torch.Tensor, threshold: float = 0.0):
    # torch.where 返回所有大于阈值的位置索引。
    batch_ids, label_ids, starts, ends = torch.where(logits > threshold)
    # entities 保存结构化候选,后续再映射到字符级 span。
    entities = []
    # zip 前先转 list,减少后续逐个 tensor.item() 的写法。
    for b, c, s, e in zip(batch_ids.tolist(), label_ids.tolist(), starts.tolist(), ends.tolist()):
        # start/end 是 token 级闭区间,是否转成字符区间取决于 tokenizer offset_mapping。
        entities.append({
            # batch 内第几个样本。
            "batch": b,
            # 预测到的实体类型 id。
            "label_id": c,
            # 实体起点 token index。
            "start": s,
            # 实体终点 token index。
            "end": e,
        })
    # 返回候选列表,业务侧可继续做阈值、去重或映射。
    return entities
GLiNER:标签文本条件下的 span 分类

GLiNER 官方源码是一组可切换架构,覆盖 uni-encoder、bi-encoder、span/token、decoder、relation extraction、ONNX 路线。NER 精读最重要的路径是 span 模型:文本 encoder 产出 token/word 表示,span representation layer 产出候选 span 表示,label encoder 或 prompt representation 产出标签表示,最后用 \(\mathrm{einsum}\) 做 span-label 匹配。

GLiNER 的创新空间主要在标签语义层。标签描述生成、同义标签融合、跨语言 label 对齐、领域标签 prompt search,都能在不改固定分类头的情况下尝试。它也适合作为开放类别召回器,再接 GlobalPointer 或 CRF 做领域精排和约束解码,把开放类别召回能力和垂直领域精度结合起来。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# predict_entities 是单条文本入口,本质上转发到 inference。
def predict_entities(
    self,
    text: str,
    labels: list[str],
    flat_ner: bool = True,
    threshold: float = 0.5,
    multi_label: bool = False,
    return_class_probs: bool = False,
    **kwargs,
):
    # text 被包装成长度为 1 的 batch。
    batch_texts = [text]
    # labels 是运行时输入,决定这一轮要抽取哪些实体类型。
    batch_labels = labels
    # inference 负责预处理、模型前向、sigmoid、decode 和后处理。
    predictions = self.inference(
        batch_texts,
        batch_labels,
        flat_ner=flat_ner,
        threshold=threshold,
        multi_label=multi_label,
        return_class_probs=return_class_probs,
        **kwargs,
    )
    # 单条输入只返回第 0 个样本的实体列表。
    return predictions[0]

GLiNER 的 span 前向可以压缩成下面这条路径。源码中的 UniEncoderSpanModel.forward 会先拿到文本与标签表示,再构造候选 span 表示,最后计算 \([B,L,K,C]\) 分数,其中 \(K\) 是最大 span 宽度。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# 这是 GLiNER span 模型 forward 的核心路径抽象。
def gliner_span_forward(
    self,
    input_ids,
    attention_mask,
    words_mask,
    text_lengths,
    span_idx,
    span_mask,
    labels=None,
    **kwargs,
):
    # get_representations 返回标签 prompt 表示、文本 word 表示和对应 mask。
    prompts_embedding, prompts_mask, words_embedding, word_mask = self.get_representations(
        input_ids,
        attention_mask,
        text_lengths,
        words_mask,
        **kwargs,
    )
    # span_idx 的第二维通常是 L*K,因此可以反推出当前 batch 的 word 长度。
    target_words = span_idx.size(1) // self.config.max_width
    # 文本表示被 pad/truncate 到和 span_idx 一致的长度。
    words_embedding, word_mask = self._fit_length(words_embedding, word_mask, target_words)
    # 无效 span 的 start/end index 置 0,避免 gather 越界或读到无意义位置。
    span_idx = span_idx * span_mask.unsqueeze(-1)
    # span_rep 形状是 [B, L, K, H]。
    span_rep = self.span_rep_layer(words_embedding, span_idx)
    # 标签数以 prompt embedding 为准;训练时还要兼容 labels 的最后一维。
    target_classes = prompts_embedding.size(1)
    # 如果 labels 更宽,说明 batch 内 label padding 更长。
    if labels is not None:
        # 取两者最大值,保证 logits 和 labels 能对齐。
        target_classes = max(target_classes, labels.size(-1))
    # prompt 表示被 pad/truncate 到 target_classes。
    prompts_embedding, prompts_mask = self._fit_length(prompts_embedding, prompts_mask, target_classes)
    # prompt_rep_layer 对标签表示再投影,得到用于匹配的 label embedding。
    prompts_embedding = self.prompt_rep_layer(prompts_embedding)
    # scores 形状是 [B, L, K, C],每个 span 对每个 label 都有一个分数。
    scores = torch.einsum("BLKD,BCD->BLKC", span_rep, prompts_embedding)
    # 没有 labels 时就是纯推理前向。
    loss = None
    # 有 labels 时计算多标签 span classification loss。
    if labels is not None:
        # loss 内部会结合 prompts_mask 和 span_mask 屏蔽无效标签/无效 span。
        loss = self.loss(scores, labels, prompts_mask, span_mask, **kwargs)
    # 返回 logits/loss 以及中间表示,便于训练和推理复用。
    return GLiNERBaseOutput(logits=scores, loss=loss)

GLiNER 的 scorer 路线还可以用 token-label 交互理解:token 表示和 label 表示分别投影,再拼接 token、label、逐元素乘积,最后用 MLP 输出 start/end/score 等分数。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# Scorer 计算 token 与 label 的兼容性。
class Scorer(nn.Module):
    # hidden_size 是 token/label 表示维度。
    def __init__(self, hidden_size: int, dropout: float = 0.1):
        # 初始化 nn.Module 基类。
        super().__init__()
        # token 投影到 2H,后面拆成两组特征。
        self.proj_token = nn.Linear(hidden_size, hidden_size * 2)
        # label 同样投影到 2H,保证和 token 可交互。
        self.proj_label = nn.Linear(hidden_size, hidden_size * 2)
        # MLP 输入是 token、label、token*label 三部分拼接。
        self.out_mlp = nn.Sequential(
            # 3H 到 4H,先扩大表达能力。
            nn.Linear(hidden_size * 3, hidden_size * 4),
            # dropout 缓和过拟合。
            nn.Dropout(dropout),
            # ReLU 引入非线性。
            nn.ReLU(),
            # 输出 3 个分数,通常对应 start、end 和整体兼容性。
            nn.Linear(hidden_size * 4, 3),
        )
 
    # token_rep: [B, L, H],label_rep: [B, C, H]。
    def forward(self, token_rep: torch.Tensor, label_rep: torch.Tensor) -> torch.Tensor:
        # 取出 batch、序列长度和隐藏维。
        batch_size, seq_len, hidden_size = token_rep.shape
        # C 是当前 batch 中参与预测的标签数。
        num_classes = label_rep.shape[1]
        # token 投影并 reshape 成 [B, L, 1, 2, H]。
        token_rep = self.proj_token(token_rep).view(batch_size, seq_len, 1, 2, hidden_size)
        # label 投影并 reshape 成 [B, 1, C, 2, H]。
        label_rep = self.proj_label(label_rep).view(batch_size, 1, num_classes, 2, hidden_size)
        # token 扩展到每个 label,变成 [2, B, L, C, H]。
        token_rep = token_rep.expand(-1, -1, num_classes, -1, -1).permute(3, 0, 1, 2, 4)
        # label 扩展到每个 token,变成 [2, B, L, C, H]。
        label_rep = label_rep.expand(-1, seq_len, -1, -1, -1).permute(3, 0, 1, 2, 4)
        # 拼接 token 特征、label 特征和逐元素乘积。
        features = torch.cat([token_rep[0], label_rep[0], token_rep[1] * label_rep[1]], dim=-1)
        # 输出 [B, L, C, 3]。
        return self.out_mlp(features)
GLiNER 解码:阈值、span 合法性与重叠处理

GLiNER 的 decoder 会先对 logits 做概率化,再找超过阈值的 span-label 候选。候选 span 还要通过长度检查,并用 greedy search 去掉不允许的重叠。这个流程对业务非常关键:阈值决定召回/精度, flat_ner 决定是否允许嵌套实体, multi_label 决定同一个 span 是否允许多个类型。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# probs_i 形状是 [L, K, C],表示单个样本所有 span/label 概率。
def decode_one_gliner_item(probs_i, tokens_i, id_to_class, threshold, flat_ner=True):
    # torch.where 找出所有超过阈值的候选。
    start_idx, width_idx, class_idx = torch.where(probs_i > threshold)
    # 当前样本文本被切成多少个 word/token。
    num_tokens = len(tokens_i)
    # span 终点是 start + width + 1,必须不超过文本长度。
    valid = (start_idx + width_idx + 1) <= num_tokens
    # 过滤非法 span 起点。
    start_idx = start_idx[valid]
    # 过滤非法 span 宽度。
    width_idx = width_idx[valid]
    # 过滤非法 label id。
    class_idx = class_idx[valid]
    # 一次性取出候选分数,减少循环里频繁 GPU 到 CPU 同步。
    scores = probs_i[start_idx, width_idx, class_idx].tolist()
    # 保存候选实体。
    spans = []
    # zip 后逐个构造结构化实体。
    for s, w, c, score in zip(start_idx.tolist(), width_idx.tolist(), class_idx.tolist(), scores):
        # end 是开区间,便于和 Python 切片保持一致。
        end = s + w + 1
        # id_to_class 把模型内部 label id 映射回业务标签文本。
        label = id_to_class.get(c + 1, f"class_{c}")
        # 保存 span、标签和置信度。
        spans.append({"start": s, "end": end, "label": label, "score": score})
    # flat_ner=True 时,后处理会按分数贪心去掉重叠 span。
    return greedy_remove_overlaps(spans) if flat_ner else spans

GLiNER 官方 decoder 的 greedy_search 会按置信度从高到低保留 span。这个设计让高置信候选优先占用区间,低置信重叠候选被丢弃;嵌套 NER 和 flat NER 的差异由重叠检测函数控制。

NER 评估必须按实体级看 precision、recall、F1。token accuracy 容易掩盖边界错误,尤其在实体很短、O 标签占比很高的数据集里。对 GLiNER 和 GlobalPointer 这类 span 模型,阈值调优也应围绕实体级指标,token 级准确率只适合作为辅助信号。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# spans 是 Span 对象列表,每个对象包含 start、end、entity_type、score。
def greedy_search(spans, flat_ner=True, multi_label=False):
    # flat NER 不允许任意重叠,nested NER 允许包含式嵌套。
    has_overlap = has_overlapping if flat_ner else has_overlapping_nested
    # 没有候选时直接返回,避免后面排序和循环。
    if not spans:
        return []
 
    # selected 保存最终实体对象。
    selected = []
    # selected_keys 保存已占用 span,用于快速检查重叠。
    selected_keys = []
    # 高分优先,保证冲突时优先保留模型更确信的实体。
    ranked_spans = sorted(spans, key=lambda x: -x.score)
 
    # 逐个尝试把候选加入最终结果。
    for span in ranked_spans:
        # entity_type 放进 key,multi_label=True 时同区间多类型可被允许。
        current = (span.start, span.end, span.entity_type)
        # 默认当前候选没有和已选实体冲突。
        blocked = False
        # 与所有已选实体检查冲突。
        for existing in selected_keys:
            # overlap 函数封装了 flat/nested/multi-label 规则。
            if has_overlap(current, existing, multi_label=multi_label):
                # 一旦冲突,当前低分候选被丢弃。
                blocked = True
                break
        # 没有冲突才写入最终结果。
        if not blocked:
            # 保存实体对象,后续返回给业务层。
            selected.append(span)
            # 保存轻量 key,供下一轮候选做重叠判断。
            selected_keys.append(current)
 
    # 输出按文本顺序排列,比按置信度排列更适合渲染和评估。
    selected.sort(key=lambda x: x.start)
    # 返回过滤后的实体列表。
    return selected
附录
常见陷阱与排障速查

这一节只覆盖 ref-6 正文里已经出现过的栈:PyTorch / Transformers / Accelerate / PEFT / TRL / DeepSpeed / vLLM / RAG 向量组件,以及常见的日志与实验跟踪工具(TensorBoard、W&B、MLflow、Langfuse)。每条都给“现象→快速检查→修复动作”。

环境与依赖(安装层)
现象 快速检查 常见根因 修复动作
ImportError: Using device_map requires Accelerate 看代码是否启用了 device_map="auto" Transformers 需要 accelerate 提供 device map / offload 运行时 pip install -U accelerate,或移除 device_map 并手动 model.to(device)
DeepSpeed 安装成功但首次训练很慢 跑环境报告/查看是否在编译 ops DeepSpeed ops 走 JIT 编译,首次运行会编译 CUDA/C++ 扩展 把编译链(gcc/g++/ninja)与 CUDA toolkit 固定到镜像;或使用预编译/缓存编译产物
bitsandbytes/FlashAttention/xFormers 安装失败 核对 torch.__version__ 与 CUDA/驱动 二进制 wheel 与 CUDA/torch 组合不匹配,退回源码编译 优先选“官方支持矩阵”内的 torch+CUDA 组合;必要时换到对应 wheel 或统一用容器镜像
huggingface_hub 下载慢/缓存爆盘 检查缓存路径与磁盘配额 默认缓存落在 home 盘;大模型/多版本重复下载 设置 HF_HOME/ HF_HUB_CACHE 到大盘;固定模型版本,避免反复下载
训练侧(脚本、分布式与 checkpoint)
现象 快速检查 常见根因 修复动作
训练 loss 正常,但 eval 指标不动或波动异常 检查 eval 集是否固定、是否数据泄漏、是否用错 metric 数据切分不稳定;指标与 loss 不一致(例如二分类用 F1 优先) 固定 eval 集与随机种子;按任务选择 monitor(生成任务常用 token acc/下游指标;分类任务多用 F1/acc)
同样配置多次训练结果差异大 检查是否固定 seed;是否启用非确定性算子 并行归约顺序、不同 kernel、随机种子未固定 固定 seed;能用 deterministic kernel 的框架启用 deterministic;记录环境/依赖版本到 run_meta
断点续训后 learning rate/step 计数异常 检查是否从正确的 checkpoint 恢复 optimizer/scheduler 只恢复了模型权重,没恢复优化器状态;或切换了 batch size/accumulation 统一用框架提供的 resume 机制;恢复后不随意改动 batch/accumulation;把超参固化在 run_meta
LoRA 训练完线上加载无效果 确认线上是否真的挂载了 adapter;是否用对 base base checkpoint 不一致;adapter 没加载或没 set_active 把 base_model_id 写入 adapter 元信息;上线前做“base+adapter”一致性 smoke test
DeepSpeed/多机训练 hang 或极慢 打开 NCCL 日志;检查网卡选择 NCCL 走错网络接口;IB/PCIe 拓扑或防火墙 设置 NCCL 环境变量并固定网卡;必要时禁用 IB/P2P 作为定位手段
NCCL 诊断(最小集)
Shell
1
2
3
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=INIT,NET
export NCCL_SOCKET_IFNAME=eth0
推理侧(vLLM / OpenAI 兼容服务)
现象 快速检查 常见根因 修复动作
服务启动成功但客户端 404 / model not found GET /v1/models 看 model 名 服务端模型名与客户端 model= 不一致 vLLM 启动加 --served-model-name,客户端统一用该名称
服务 QPS 低或频繁 OOM 看 /metrics,关注 KV cache 与并发 context 过长、KV cache 预算过大、GPU mem utilization 过高 降低最大上下文/并发;保留显存余量;把流量打到多副本;按模型大小重新估算 token budget
temperature=0 仍有轻微不一致 确认是否完全禁用采样;是否跨硬件/多实例 服务端并行归约/调度带来非确定性 禁用所有采样相关选项;尽量固定推理硬件与 kernel;把“强确定性”作为服务 SLA 单独约束
端口占用 / 启动失败 lsof -i :8000 端口被旧进程占用 停止旧进程或换端口;把启动/回滚写成脚本并纳入进程管理器
最小健康检查
Shell
1
2
curl -sf http://127.0.0.1:8000/v1/models
curl -sf http://127.0.0.1:8000/metrics | head
RAG 与向量组件(FAISS / pgvector / Qdrant / Milvus / TCVectorDB)
现象 快速检查 常见根因 修复动作
召回结果明显变差(无规律) 检查是否混入不同 embedding 模型版本 同一 collection 混用不同 embedding space 把 embedding_model_id 写入 metadata;版本迁移时重建或双写新集合
cosine 相似度排序不符合预期 检查向量是否归一化;检查 metric 设置 cosine 与 inner product/l2 使用不一致 统一策略:归一化 + inner product(或显式 cosine metric);写入与查询必须一致
pgvector 查询报错:vector 扩展不存在 SELECT extname FROM pg_extension; 未启用扩展 CREATE EXTENSION IF NOT EXISTS vector;
Qdrant 可用但安全风险 检查是否启用鉴权,端口是否暴露公网 默认 Docker QuickStart 无认证 启用 API key/鉴权;仅暴露内网;生产环境用 Helm/Cloud 并加网络策略
Milvus/TCVectorDB 连接失败 检查 endpoint/token/TLS 网络不可达、鉴权不匹配、TLS 配置问题 先用最小 client 读写验证;把 endpoint/token/TLS 作为运行时配置(环境变量/密钥管理)
框架/引擎选型原则(工程视角)
需求 优先选 理由 不适合时的信号
想要最短路径微调 LLM(SFT/DPO/GRPO) TRL + (Transformers + PEFT) + Accelerate 方法流程化、接口稳定、与 HF 生态耦合最深 需要高度自定义训练循环/复杂多任务调度,且团队已有自研训练框架
通用训练循环(分类/NER/CV)且团队多人复用 Transformers Trainer 或 PyTorch Lightning 或 MMEngine 训练工程样板收敛到统一入口;callbacks/loggers/checkpoint 更标准 训练逻辑高度非标准(例如特殊采样/复杂图结构),框架抽象反而增加阻力
大模型训练显存吃紧,需要参数/优化器分片 DeepSpeed ZeRO 或 PyTorch FSDP(常经由 Accelerate) 把显存压力从“模型/优化器/梯度”三个维度拆开 集群/网络不稳定;团队缺少分布式排障经验;checkpoint 迁移频繁
需要高吞吐服务化推理(OpenAI API 兼容) vLLM( vllm serve) continuous batching、KV cache 管理、指标/观测与并行策略更贴近生产 只需离线小批推理且依赖纯 Transformers;或模型结构/算子不被 vLLM 支持
RAG 召回(单机/单租户/延迟极低) FAISS 部署简单、延迟低、索引可控 需要多租户、过滤、持久化、高可用与在线扩缩容
RAG 召回(已有 Postgres 体系) pgvector 事务/权限/JOIN/备份沿用 Postgres,治理成本低 超大规模向量 + 高 QPS;需要专用 ANN 系统能力
RAG 召回(专用向量数据库能力) Qdrant / Milvus / 托管向量库(TCVectorDB) 过滤、持久化、分布式扩展与运维能力更完整 团队没有运维能力(自建风险高);对迁移性要求极强(托管绑定风险)
训练与推理栈术语对照
术语 训练侧含义 推理/服务侧含义 落点(代码/命令)
base / adapter base 是大权重 checkpoint;adapter 是 LoRA 等增量参数 服务端加载 base 后挂载 adapter,或提前 merge 导出 PEFT: PeftModel.from_pretrained / merge_and_unload
checkpoint 训练过程中的可恢复状态(含 optimizer/scheduler 视框架而定) 上线时通常只需要“可加载权重目录”(artifact) Transformers/TRL: output_dir/checkpoint-*
artifact(模型包) 训练导出的可交付目录 推理服务直接加载的路径 save_pretrained 目录结构
device_map 训练/推理加载时的设备放置策略 服务端决定权重落 GPU/CPU/offload 的方式 Transformers: device_map="auto"
torch_dtype 训练计算 dtype(fp16/bf16/fp32) 推理加载 dtype(影响显存与速度) Transformers: torch_dtype="auto"
TTFT / TPOT 训练不直接出现 首 token 延迟 / 每 token 延迟,衡量推理体验 vLLM: /metrics + 业务侧统计
topK / rerank 训练侧用于召回模型/排序模型的训练 RAG 检索阶段:ANN 召回 topK,reranker 取 topN 向量库 search + Cross-Encoder rerank
目录/产物/命令速查表
常见目录与产物
类型 路径模式 说明
HF Datasets / PyTorch 训练输入 data/processed/train.jsonl 离线预处理脚本产出的标准化样本文件,通常交给 datasets.load_dataset 或自定义 Dataset/ DataLoader 消费。
Trainer / TRL checkpoint outputs/runs/<run_id>/checkpoint-*/ 由 Hugging Face Trainer 或 TRL Trainer 自动产出,目录内常见 trainer_state.json、优化器状态和分步权重,用于断点续训、best checkpoint 选择和实验回溯。
DeepSpeed ZeRO 分片 checkpoint outputs/runs/<run_id>/global_step*/mp_rank_*/ 由 DeepSpeed 产出的分片状态目录,内部常见 rank 级别的模型和优化器状态文件,主要用于 ZeRO 恢复和后续权重聚合,不能直接拿给 vLLM 或纯 Transformers 推理。
PEFT adapter 目录 adapters/<task_name>/{adapter_config.json,adapter_model.safetensors} 由 PEFT 的 save_pretrained 产出,只保存 LoRA/adapter 增量参数;推理时需要先加载对应 base model,再用 PeftModel.from_pretrained 挂载。
Transformers 模型包 models/registry/model_vXXXX/{config.json,tokenizer.json,model*.safetensors} 由 Transformers 的 save_pretrained 或 PEFT merge 导出脚本产出,是最通用的可部署目录;可直接被 Transformers、vLLM、TGI 等推理框架加载。
服务加载指针 models/prod -> models/registry/model_vXXXX/ 这核心是服务编排层给 vLLM、TGI、Transformers API 服务提供的稳定加载入口;切换软链接即可完成上线与回滚。
FAISS 本地索引 indexes/faiss/<collection>.faiss 由 faiss.write_index 产出,服务于单机/嵌入式 ANN 检索;通常还要配一份外部 metadata 文件保存 doc_id、chunk_id 与原文偏移。
Qdrant 持久化目录 qdrant_storage/ 由 Qdrant Docker 或 standalone 进程维护,是 collection、payload 与索引文件的持久化卷;开发环境常直接挂到本地目录,生产环境通常挂载独立数据盘或云盘。
常用命令

命令/API/函数
accelerate config / accelerate launch train.py

说明
多卡启动训练。用于 Accelerate 路线的多卡训练启动。第一条命令生成分布式配置,第二条命令按配置拉起同一份 PyTorch 训练脚本。

示例

Shell
1
2
accelerate config
accelerate launch train.py

命令/API/函数
deepspeed --num_gpus=8 train.py --deepspeed ds_config.json

说明
DeepSpeed 启动训练。用于 DeepSpeed 训练入口。 ds_config.json 固化 ZeRO、offload、AMP 与通信策略,适合大模型显存压缩场景。

示例

Shell
1
deepspeed --num_gpus=8 train.py --deepspeed ds_config.json

命令/API/函数
vllm serve /abs/path/to/models/prod

说明
启动 vLLM 服务。用于启动 OpenAI-compatible 的 vLLM 服务。固定 --served-model-name 后,客户端就不需要感知底层真实模型目录。

示例

Shell
1
2
3
4
vllm serve /abs/path/to/models/prod \
  --host 0.0.0.0 --port 8000 \
  --served-model-name prod \
  --api-key token-abc123

命令/API/函数
curl -sf http://127.0.0.1:8000/v1/models / curl -sf http://127.0.0.1:8000/metrics | head

说明
检查服务与指标。最小 smoke test。第一条检查模型是否已注册,第二条确认 Prometheus 指标端点是否可用。

示例

Shell
1
2
curl -sf http://127.0.0.1:8000/v1/models
curl -sf http://127.0.0.1:8000/metrics | head

命令/API/函数
docker run -p 6333:6333 -v "$(pwd)/qdrant_storage:/qdrant/storage:z" qdrant/qdrant

说明
启动 Qdrant(开发)。用于本地开发环境拉起 Qdrant。挂载卷保存索引和 payload;默认没有鉴权,生产环境必须补安全配置。

示例

Shell
1
docker run -p 6333:6333 -v "$(pwd)/qdrant_storage:/qdrant/storage:z" qdrant/qdrant

命令/API/函数
psql -d your_db -c "CREATE EXTENSION IF NOT EXISTS vector;"

说明
启用 pgvector。用于在 PostgreSQL 中启用 pgvector 扩展。只有扩展启用后,后续的向量列、HNSW/IVFFlat 索引和相似度查询语法才可用。

示例

Shell
1
psql -d your_db -c "CREATE EXTENSION IF NOT EXISTS vector;"

命令/API/函数
pip install -U openmim / mim install mmengine

说明
安装 OpenMMLab 训练底座。用于安装 OpenMMLab 生态的统一底座。后续安装 MMDetection、MMPreTrain、MMSegmentation 等仓库时,都会复用这套基础设施。

示例

Shell
1
2
pip install -U openmim
mim install mmengine
← 人工智能知识 - 编程(一)

Leave a Reply Cancel reply

Your email address will not be published. Required fields are marked *

You may use these HTML tags and attributes: <a href="" title=""> <abbr title=""> <acronym title=""> <b> <blockquote cite=""> <cite> <code class="" title="" data-url=""> <del datetime=""> <em> <i> <q cite=""> <strike> <strong> <pre class="" title="" data-url=""> <span class="" title="" data-url="">

Related Posts

  • 吴恩达机器学习笔记
  • LangChain: Architecture, LCEL, Agents, LangGraph, Retrieval, and Production Patterns
  • 人工智能知识 - Transformers和大模型
  • 人工智能知识 - 智能体
  • 人工智能知识 - 简介

Recent Posts

  • 人工智能知识 - 编程(二)
  • 人工智能知识 - 编程(一)
  • 人工智能知识 - 智能体
  • 人工智能知识 - Transformers和大模型
  • 人工智能知识 - 主要应用领域
ABOUT ME

汪震 | Alex Wong

江苏淮安人,现居北京。目前供职于腾讯云,专注国际售后AI落地。

GitHub:gmemcc

Git:git.gmem.cc

Email:gmemjunk@gmem.cc@me.com

ABOUT GMEM

绿色记忆是我的个人网站,域名gmem.cc中G是Green的简写,MEM是Memory的简写,CC则是我的小天使彩彩名字的简写。

我在这里记录自己的工作与生活,同时和大家分享一些编程方面的知识。

GMEM HISTORY
v2.00:微风
v1.03:单车旅行
v1.02:夏日版
v1.01:未完成
v0.10:彩虹天堂
v0.01:阳光海岸
MIRROR INFO
Meta
  • Log in
  • Entries RSS
  • Comments RSS
  • WordPress.org
Recent Posts
  • 人工智能知识 - 编程(二)
    这一篇承接人工智能知识 - 编程(一)。前一篇已经梳理 AI 训练与推理编程的横向工程栈;本篇进入重点框架详解与 ...
  • 人工智能知识 - 编程(一)
    这一篇专门处理 AI 训练、微调、推理与部署中的编程栈问题。前几篇分别讲了机器学习基础、任务版图、Transfo ...
  • 人工智能知识 - 智能体
    这一篇处理模型之外的系统层问题,包括上下文工程、Harness Engineering、检索增强生成(RAG)与 ...
  • 人工智能知识 - Transformers和大模型
    这一篇聚焦现代大模型主线,内容从 Transformer 架构出发,延伸到语言模型、多模态模型、预训练与微调,以 ...
  • 人工智能知识 - 主要应用领域
    这一篇从任务视角进入现代 AI 的几个核心应用方向,重点讨论自然语言处理、计算机视觉、语音和音频处理、搜索/推荐 ...
  • 人工智能知识 - 算法和机器学习
    这一篇从常用算法进入机器学习基础概念、经典机器学习与神经网络,重点讨论“模型如何被构造、训练、评估与正则化”。前 ...
  • 人工智能知识 - 数学基础
    这一篇整理 AI 所需的数学基础,包括基础数学、线性代数、微积分与概率论统计。它回答的核心问题是:模型里的向量、 ...
  • 人工智能知识 - 简介
    这一篇作为整套 AI 总纲的导论,先回答更根本的问题,不急于进入公式和具体模型细节:什么叫智能,人工智能究竟在试 ...
  • 多语言敏感信息检测模型训练日志
    这篇文章记录一个多语言敏感信息识别项目的完整训练日志。它关注的是工程路径本身:原始 AI 合成语料如何被清洗成可 ...
  • DevPod on Kubernetes: turning devcontainer.json into a persistent remote workspace
    DevPod is an open source workspace manager ...
  • OpenClaw: Architecture, Components, and Deployment Notes
    Four Months, 343,000 Stars On November 24, 2025, ...
  • Replacing Docker Desktop with Colima on macOS
    Colima is one of the cleanest ways ...
  • Kubernetes GPU Sharing
    GPU sharing in Kubernetes depends on what ...
  • Investigating and Solving the Issue of Failed Certificate Request with ZeroSSL and Cert-Manager
    In this blog post, I will walk ...
  • A Comprehensive Study of Kotlin for Java Developers
    Introduction Purpose of the Study Understanding the Mo ...
  • LangChain: Architecture, LCEL, Agents, LangGraph, Retrieval, and Production Patterns
    LangChain is no longer best understood as ...
  • Kubernetes Migration
    Migrating a Kubernetes cluster from one cloud ...
  • Terraform: a practical guide to infrastructure as code
    Terraform is an infrastructure-as-code tool. You describ ...
TOPLINKS
  • Zitahli's blue 91 people like this
  • 梦中的婚礼 64 people like this
  • 汪静好 61 people like this
  • 那年我一岁 36 people like this
  • 为了爱 28 people like this
  • 小绿彩 26 people like this
  • 彩虹姐姐的笑脸 24 people like this
  • 杨梅坑 6 people like this
  • 亚龙湾之旅 1 people like this
  • 汪昌博 people like this
  • 2013年11月香山 10 people like this
  • 2013年7月秦皇岛 6 people like this
  • 2013年6月蓟县盘山 5 people like this
  • 2013年2月梅花山 2 people like this
  • 2013年淮阴自贡迎春灯会 3 people like this
  • 2012年镇江金山游 1 people like this
  • 2012年徽杭古道 9 people like this
  • 2011年清明节后扬州行 1 people like this
  • 2008年十一云龙公园 5 people like this
  • 2008年之秋忆 7 people like this
  • 老照片 13 people like this
  • 火一样的六月 16 people like this
  • 发黄的相片 3 people like this
  • Cesium学习笔记 90 people like this
  • IntelliJ IDEA知识集锦 59 people like this
  • Bazel学习笔记 38 people like this
  • 基于Kurento搭建WebRTC服务器 38 people like this
  • PhoneGap学习笔记 32 people like this
  • NaCl学习笔记 32 people like this
  • 使用Oracle Java Mission Control监控JVM运行状态 29 people like this
  • Ceph学习笔记 27 people like this
  • 基于Calico的CNI 27 people like this
Tag Cloud
ActiveMQ AspectJ CDT Ceph Chrome CNI Command Cordova Coroutine CXF Cygwin DNS Docker eBPF Eclipse ExtJS F7 FAQ Groovy Hibernate HTTP IntelliJ IO编程 IPVS JacksonJSON JMS JSON JVM K8S kernel LB libvirt Linux知识 Linux编程 LOG Maven MinGW Mock Monitoring Multimedia MVC MySQL netfs Netty Nginx NIO Node.js NoSQL Oracle PDT PHP Redis RPC Scheduler ServiceMesh SNMP Spring SSL svn Tomcat TSDB Ubuntu WebGL WebRTC WebService WebSocket wxWidgets XDebug XML XPath XRM ZooKeeper 亚龙湾 单元测试 学习笔记 实时处理 并发编程 彩姐 性能剖析 性能调优 文本处理 新特性 架构模式 系统编程 网络编程 视频监控 设计模式 远程调试 配置文件 齐塔莉
Recent Comments
  • 杨松涛 on snmp4j学习笔记
  • kaka on Cilium学习笔记
  • JackZhouMine on Cesium学习笔记
  • 陈黎 on 通过自定义资源扩展Kubernetes
  • qg on Istio中的透明代理问题
  • heao on 基于本地gRPC的Go插件系统
  • 黄豆豆 on Ginkgo学习笔记
  • cloud on OpenStack学习笔记
  • 5dragoncon on Cilium学习笔记
  • Archeb on 重温iptables
  • C/C++编程:WebSocketpp(Linux + Clion + boostAsio) – 源码巴士 on 基于C/C++的WebSocket库
  • jerbin on eBPF学习笔记
  • point on Istio中的透明代理问题
  • G on Istio中的透明代理问题
  • 绿色记忆:Go语言单元测试和仿冒 on Ginkgo学习笔记
  • point on Istio中的透明代理问题
  • 【Maven】maven插件开发实战 – IT汇 on Maven插件开发
  • chenlx on eBPF学习笔记
  • Alex on eBPF学习笔记
  • CFC4N on eBPF学习笔记
  • 李运田 on 念爷爷
  • yongman on 记录一次KeyDB缓慢的定位过程
  • Alex on Istio中的透明代理问题
©2005-2026 Gmem.cc | Powered by WordPress | 京ICP备18007345号-2