Menu

  • Home
  • Work
    • AI
    • Cloud
      • Virtualization
      • IaaS
      • PaaS
    • Architecture
    • BigData
    • Python
    • Java
    • Go
    • C
    • C++
    • JavaScript
    • PHP
    • Others
      • Assembly
      • Ruby
      • Perl
      • Lua
      • Rust
      • XML
      • Network
      • IoT
      • GIS
      • Algorithm
      • Math
      • RE
      • Graphic
    • OS
      • Linux
      • Windows
      • Mac OS X
    • Database
      • MySQL
      • Oracle
    • Mobile
      • Android
      • IOS
    • Web
      • HTML
      • CSS
  • Life
    • Cooking
    • Travel
    • Gardening
  • Gallery
  • Video
  • Music
  • Essay
  • Home
  • Work
    • AI
    • Cloud
      • Virtualization
      • IaaS
      • PaaS
    • Architecture
    • BigData
    • Python
    • Java
    • Go
    • C
    • C++
    • JavaScript
    • PHP
    • Others
      • Assembly
      • Ruby
      • Perl
      • Lua
      • Rust
      • XML
      • Network
      • IoT
      • GIS
      • Algorithm
      • Math
      • RE
      • Graphic
    • OS
      • Linux
      • Windows
      • Mac OS X
    • Database
      • MySQL
      • Oracle
    • Mobile
      • Android
      • IOS
    • Web
      • HTML
      • CSS
  • Life
    • Cooking
    • Travel
    • Gardening
  • Gallery
  • Video
  • Music
  • Essay

人工智能知识 - 编程(一)

17
Apr
2026

人工智能知识 - 编程(一)

By Alex
/ in AI
0 Comments

这一篇专门处理 AI 训练、微调、推理与部署中的编程栈问题。前几篇分别讲了机器学习基础、任务版图、Transformer 与上下文工程;这一篇转向“代码层面的真实系统”:从 NumPy、数据管线、训练框架、分布式组件,到推理引擎、向量检索、服务化接口与工程辅助库,梳理一条从实验脚本到线上推理系统的完整技术链。

语言与数值计算底座

训练与推理系统最终都会落到“把数据变成数组/张量,然后在有限内存和带宽下完成大量数值运算”。这一层底座看起来朴素,但它决定了三个硬指标:吞吐(Throughput)、内存占用(Memory Footprint)与拷贝次数(Copy Count)。NumPy/Arrow/Parquet 一类组件在工程上通常承担训练数据管线、离线特征、评测集加工与推理输入输出的基础角色。

Python 语言层与文本编码底座

AI 工程里的 Python 主要承担编排角色:读取配置、组织数据、调用数值库、封装训练入口、记录日志、连接服务和文件系统。大规模矩阵计算通常下沉到 NumPy、PyTorch、JAX、BLAS、CUDA kernel 或推理引擎;Python 层的职责是让这些组件以可复现、可检查、可恢复的方式组合起来。

这一层最常见的质量问题集中在隐式全局状态、路径不稳定、编码不明确、配置不可追溯、异常被吞掉、日志缺少样本上下文,以及文本字段在 JSON/CSV/Parquet 之间反复损坏。训练数据和推理请求进入数值层之前,先要在语言层把结构、编码和边界处理稳。

Python 在 AI 系统中的分工
层次 Python 负责 下沉组件负责
实验脚本 解析参数、固定随机种子、加载配置、组织训练/评估流程。 PyTorch/JAX 执行张量前向、反向和优化器更新。
数据预处理 读取文件、解析 JSONL、清洗文本、构造样本对象、写入 shard。 NumPy/Arrow/Polars 执行批量转换、列式扫描和向量化计算。
推理服务 请求校验、路由、批处理队列、日志、异常边界和响应封装。 vLLM/TensorRT-LLM/ONNX Runtime 执行模型推理。
评估与分析 读取预测结果、按任务聚合指标、输出报告和失败样本。 NumPy/scikit-learn/SciPy 执行统计、矩阵运算和指标计算。
类型、配置与样本对象

训练脚本里的配置对象应承担两个任务:把输入参数收口成明确 schema,并把会影响实验结果的字段写进日志和 checkpoint 元数据。简单项目可以用 dataclasses;复杂项目再升级到 Hydra、Pydantic 或框架自带配置系统。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from dataclasses import asdict, dataclass
from pathlib import Path
import json
 
 
@dataclass(frozen=True)
class TrainConfig:
    # 模型路径会影响 tokenizer、权重和 chat template,必须进入实验记录。
    model_name_or_path: str
 
    # 数据路径使用 Path,减少字符串拼接导致的跨平台路径问题。
    train_file: Path
 
    # seed 统一控制切分、采样和初始化,便于复现实验。
    seed: int = 42
 
    # batch_size 属于训练语义,不应散落在脚本多个位置。
    batch_size: int = 32
 
 
cfg = TrainConfig(
    model_name_or_path="bert-base-chinese",
    train_file=Path("data/train.jsonl"),
)
 
# asdict 把 dataclass 转成普通 dict,便于写入 JSON 元数据。
metadata = asdict(cfg)
 
# Path 属于 Python 对象,写 JSON 元数据前显式转成字符串。
metadata["train_file"] = str(metadata["train_file"])
 
# ensure_ascii=False 保留中文可读性,避免日志里出现大量 \uXXXX。
Path("runs/exp001").mkdir(parents=True, exist_ok=True)
Path("runs/exp001/config.json").write_text(
    json.dumps(metadata, ensure_ascii=False, indent=2),
    encoding="utf-8",
)
JSONL、Unicode 与文本边界

大规模 NLP/LLM 数据常用 JSONL:一行一个样本,便于流式读取、失败恢复和 shard 合并。文本字段应固定 UTF-8,读写时显式声明编码;清洗阶段只做可解释的规范化,不在数据管线里随意改写语义。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import json
import unicodedata
from pathlib import Path
 
 
def normalize_text(text: str) -> str:
    # NFKC 会把全角英数、兼容字符等规整到更稳定的形式。
    # 它适合搜索、去重和规则匹配;严肃标注任务要先确认不会破坏标签边界。
    text = unicodedata.normalize("NFKC", text)
 
    # 只压缩首尾空白,避免改写正文内部的格式信息。
    return text.strip()
 
 
def iter_jsonl(path: Path):
    # encoding 明确写 utf-8,避免不同机器的 locale 影响读取结果。
    with path.open("r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            # 空行直接跳过,减少人工编辑数据时留下的噪声。
            if not line.strip():
                continue
 
            try:
                row = json.loads(line)
            except json.JSONDecodeError as exc:
                # 报错带上行号,便于定位坏样本所在 shard。
                raise ValueError(f"Bad JSON at {path}:{line_no}") from exc
 
            # 文本字段在入口处统一收口,后续 tokenizer 才能面对稳定输入。
            row["text"] = normalize_text(row["text"])
            yield row
脚本入口、日志与异常边界

训练和评估脚本应把入口逻辑放在 main() 中,并在最外层保留异常栈。长期任务还应把关键配置、数据路径、样本数、依赖版本和输出目录写入日志,便于几天后追查某个 checkpoint 的来源。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import argparse
import logging
from pathlib import Path
 
 
def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
 
    # 输入文件和输出目录放在命令行参数里,便于调度系统覆盖。
    parser.add_argument("--input", type=Path, required=True)
    parser.add_argument("--output", type=Path, required=True)
 
    # seed 和 batch-size 直接影响实验结果,应进入命令行和日志。
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--batch-size", type=int, default=128)
    return parser.parse_args()
 
 
def main() -> None:
    args = parse_args()
    args.output.mkdir(parents=True, exist_ok=True)
 
    # 日志同时写控制台和文件,训练失败后仍能从输出目录恢复上下文。
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(args.output / "run.log", encoding="utf-8"),
        ],
    )
 
    logging.info("input=%s output=%s seed=%d", args.input, args.output, args.seed)
 
    # 真实训练/评估逻辑放在 main 内部,避免 import 文件时直接启动任务。
    rows = list(iter_jsonl(args.input))
    logging.info("loaded_samples=%d", len(rows))
 
 
if __name__ == "__main__":
    main()
数组与科学计算
NumPy

NumPy(Numerical Python)定义了 Python 生态里最通用的数组语义:shape / stride / dtype / broadcasting。训练脚本里大量“数据预处理、采样、拼接、统计、离线特征生成、指标计算”都在直接使用这些概念,即使模型训练本身在 PyTorch/JAX 上完成。

安装通常直接用 pip install numpy。生产/研究环境常用 conda-forge 统一 BLAS 与二进制依赖,以减少 ABI 问题。

常用API

命令/API/函数
np.array

说明
显式创建新数组,适合需要确定 dtype、并允许拷贝一份独立数据的入口。

示例

Python
1
2
3
import numpy as np
 
x = np.array([1, 2, 3], dtype=np.int32)

命令/API/函数
np.asarray

说明
尽量复用已有内存,常用于把 Python 列表或上游数组接进预处理链,同时避免不必要拷贝。

示例

Python
1
x = np.asarray([1, 2, 3], dtype=np.int32)

命令/API/函数
np.frombuffer

说明
直接把 bytes 或共享内存解释成数组视图,适合二进制数据解码和零拷贝接入。

示例

Python
1
2
buf = b"\x01\x00\x00\x00\x02\x00\x00\x00"
y = np.frombuffer(buf, dtype=np.int32)

命令/API/函数
ndarray.reshape

说明
调整视图形状,常用于把扁平 buffer 重组成 batch、序列或通道布局;多数情况下是 O(1) 视图。

示例

Python
1
a = np.arange(12).reshape(3, 4)

命令/API/函数
ndarray.transpose

说明
重排维度顺序,常见于 NHWC/NCHW、batch-first/seq-first 互换;通常只改 stride,不立即拷贝。

示例

Python
1
2
a = np.arange(12).reshape(3, 4)
at = a.transpose(1, 0)

命令/API/函数
np.moveaxis / np.swapaxes

说明
按“挪动某一维”或“交换两维”的方式改布局;在图像和语音预处理中往往比手写完整 transpose 更可读。

示例

Python
1
2
x = np.zeros((224, 224, 3), dtype=np.uint8)  # HWC
y = np.moveaxis(x, -1, 0)                    # CHW

命令/API/函数
np.concatenate

说明
沿既有轴拼接 batch 或分片结果,常见于离线特征合并;会分配新数组并产生拷贝。

示例

Python
1
2
a = np.ones((2, 4))
b = np.concatenate([a, a], axis=0)

命令/API/函数
np.stack

说明
新增一个维度后再拼接,适合把多个样本或多路特征堆成 batch 张量。

示例

Python
1
x = np.stack([np.ones(4), np.zeros(4)], axis=0)

命令/API/函数
np.random.default_rng

说明
创建现代随机数生成器对象。它比直接调用全局 np.random.* 更容易做可复现实验,也更适合把采样逻辑封装进数据预处理函数。

示例

Python
1
2
3
# 把随机状态显式收进 rng,避免函数内部偷偷污染全局随机数流。
rng = np.random.default_rng(seed=42)
batch = rng.integers(low=0, high=1000, size=(8, 128), dtype=np.int32)

命令/API/函数
ndarray.astype

说明
显式转换 dtype,常用于把离线存储格式转成训练框架需要的精度。

示例

Python
1
2
x = np.random.randn(1024).astype(np.float32)
x16 = x.astype(np.float16)

命令/API/函数
np.squeeze / np.expand_dims

说明
删除或插入长度为 1 的维度,常见于模型输入输出的 batch 维、channel 维和时间维修整。

示例

Python
1
2
3
x = np.zeros((1, 80, 300), dtype=np.float32)
y = np.squeeze(x, axis=0)
z = np.expand_dims(y, axis=0)

命令/API/函数
np.ravel / ndarray.flatten

说明
都能把数组摊平; ravel 尽量返回视图, flatten 总是分配新数组。

示例

Python
1
2
3
x = np.arange(12).reshape(3, 4).T
a = x.ravel()    # 尽量复用底层内存
b = x.flatten()  # 一定新分配

命令/API/函数
np.ascontiguousarray

说明
把非连续视图显式转成 C contiguous 布局,减少后续 kernel 的隐式拷贝和性能抖动。

示例

Python
1
2
x = np.arange(12).reshape(3, 4).T
y = np.ascontiguousarray(x)

命令/API/函数
np.asfortranarray / np.isfortran

说明
显式转成列优先(Fortran-order)布局,或检测数组是否按 Fortran contiguous 存放;用于把 “contiguous” 说准确。

示例

Python
1
2
x = np.asfortranarray(np.arange(12).reshape(3, 4))
print(np.isfortran(x))

命令/API/函数
np.copy

说明
主动复制一份独立数据,适合需要切断共享底层 buffer、避免原地修改串扰的场景。

示例

Python
1
y = np.copy(x)

命令/API/函数
ndarray.flags

说明
检查连续性、可写性等底层属性,用于排查性能问题和隐式拷贝来源。

示例

Python
1
print(x.flags["C_CONTIGUOUS"])

命令/API/函数
np.newaxis

说明
通过索引语法插入长度为 1 的维度,常用于广播对齐;和 expand_dims 表达的是同一件事。

示例

Python
1
2
mu = X.mean(axis=0)
X0 = X - mu[np.newaxis, :]

命令/API/函数
np.broadcast_to

说明
把小数组按广播规则视图扩展到目标 shape,适合调试广播布局或生成只读重复视图。

示例

Python
1
mask = np.broadcast_to(np.array([1, 0]), (4, 2))

命令/API/函数
np.pad

说明
在指定维度两侧补边,常见于序列 padding、图像边界补零和卷积前预处理。

示例

Python
1
2
x = np.array([1, 2, 3])
y = np.pad(x, (2, 1), mode="constant")

命令/API/函数
np.clip / np.where

说明
前者做范围裁剪,后者做条件选择;都高频出现于归一化、掩码和后处理。

示例

Python
1
2
score = np.clip(score, 0.0, 1.0)
label = np.where(score > 0.5, 1, 0)

命令/API/函数
np.einsum / np.dot

说明
显式写出张量收缩或矩阵乘法。做离线 attention 对照实验、相似度计算、线性投影验证时,经常比手写多重循环更清楚。

示例

Python
1
2
3
4
5
q = np.random.randn(2, 16, 64).astype(np.float32)
k = np.random.randn(2, 16, 64).astype(np.float32)
 
# "bth,bsh->bts" 表示 batch 内每个 query token 与 key token 做点积。
scores = np.einsum("bth,bsh->bts", q, k)

命令/API/函数
np.flip

说明
翻转指定维度;常用于图像增强,但它返回的往往是负 stride 视图,跨框架时需要特别小心。

示例

Python
1
2
x = np.arange(6).reshape(2, 3)
y = np.flip(x, axis=1)

命令/API/函数
ndarray.base / itemsize / nbytes / ndim

说明
排查视图关系、单元素字节数、总内存占用与维度数的最小调试集合。

示例

Python
1
print(x.base is None, x.itemsize, x.nbytes, x.ndim)

命令/API/函数
np.from_dlpack

说明
通过 DLPack 在 NumPy 与 PyTorch/JAX/CuPy 间交换数组,构建跨框架零拷贝路径。

示例

Python
1
arr = np.from_dlpack(x)
统计、索引、线性代数与 IO API 补充

命令/API/函数
np.zeros / np.ones / np.full / np.empty

说明
预分配数组。训练前处理常用它们创建 mask、缓存、特征矩阵和中间 buffer; empty 不初始化内容,只适合后续会完整覆盖的场景。

示例

Python
1
2
3
4
5
6
7
8
9
10
import numpy as np
 
batch_size = 32
seq_len = 128
 
# attention_mask 默认全 0,后面再把有效 token 位置写成 1。
attention_mask = np.zeros((batch_size, seq_len), dtype=np.int8)
 
# labels 用 -100 填充,和 PyTorch CrossEntropyLoss 的 ignore_index 对齐。
labels = np.full((batch_size, seq_len), fill_value=-100, dtype=np.int64)

命令/API/函数
np.arange / np.linspace

说明
构造索引轴、时间轴、阈值网格和绘图横轴。 arange 适合离散位置, linspace 适合固定数量的连续采样点。

示例

Python
1
2
3
4
5
# position_ids 是 Transformer 输入里常见的位置索引。
position_ids = np.arange(seq_len, dtype=np.int32)
 
# 阈值扫描常用于二分类模型选择最佳 F1 或业务收益点。
thresholds = np.linspace(0.05, 0.95, num=19, dtype=np.float32)

命令/API/函数
np.sum / np.mean / np.std / np.argmax

说明
归约操作的关键参数是 axis 和 keepdims。保留维度能让后续 broadcasting 更明确,减少 shape 靠脑补对齐。

示例

Python
1
2
3
4
5
6
7
8
9
10
# hidden 的形状是 [batch, seq, hidden]。
hidden = np.random.randn(4, 128, 768).astype(np.float32)
 
# 沿 token 维求均值,同时保留长度为 1 的 seq 维。
# keepdims=True 让 pooled 可以直接和 hidden 做广播运算。
pooled = hidden.mean(axis=1, keepdims=True)
 
# logits 的最后一维是类别维,argmax(axis=-1) 得到每个样本的预测类别。
logits = np.random.randn(4, 10).astype(np.float32)
pred = np.argmax(logits, axis=-1)

命令/API/函数
np.take / np.take_along_axis

说明
按索引从数组中取值。分类、检索和 beam search 后处理里,先得到 top-k 索引,再用 take_along_axis 取回对应分数。

示例

Python
1
2
3
4
5
6
7
8
9
10
11
12
logits = np.random.randn(8, 50000).astype(np.float32)
 
# argpartition 只保证 top-k 集合正确,成本低于完整 argsort。
topk_idx = np.argpartition(logits, kth=-5, axis=-1)[:, -5:]
 
# 根据 top-k 索引取回分数,保持 [batch, k] 形状。
topk_score = np.take_along_axis(logits, topk_idx, axis=-1)
 
# 再对 k 个候选内部排序,得到真正从高到低的 top-k。
order = np.argsort(topk_score, axis=-1)[:, ::-1]
topk_idx = np.take_along_axis(topk_idx, order, axis=-1)
topk_score = np.take_along_axis(topk_score, order, axis=-1)

命令/API/函数
np.nonzero / np.argwhere / np.where

说明
三者都常用于 mask,但返回形态不同: nonzero 返回每个轴的索引元组, argwhere 返回坐标矩阵, where 可做条件选择。

示例

Python
1
2
3
4
5
6
7
8
9
10
11
mask = np.array([[True, False, True], [False, True, False]])
 
# nonzero 适合直接用于高级索引。
rows, cols = np.nonzero(mask)
 
# argwhere 适合把坐标当作样本列表继续处理。
coords = np.argwhere(mask)
 
# where 用于按 mask 选择值,常见于 padding 位置填充大负数。
scores = np.random.randn(2, 3).astype(np.float32)
masked_scores = np.where(mask, scores, -1e9)

命令/API/函数
np.shares_memory / np.may_share_memory

说明
检查两个数组是否共享底层内存,比直接看 .base 更适合排查 view/copy。训练前处理里共享内存意味着原地修改可能影响另一个变量。

示例

Python
1
2
3
4
5
6
7
x = np.arange(12).reshape(3, 4)
view = x[:, 1:]
copy = x[[0, 2]]
 
# 切片通常共享内存,fancy indexing 通常分配新数组。
print(np.shares_memory(x, view))
print(np.shares_memory(x, copy))

命令/API/函数
np.isfinite / np.nan_to_num / np.errstate

说明
处理 NaN、Inf 和数值警告。训练前应在数据管线处阻断坏特征,避免异常值进入模型后才表现为 loss 爆炸。

示例

Python
1
2
3
4
5
6
7
8
9
10
11
features = np.array([0.0, np.nan, np.inf, -np.inf], dtype=np.float32)
 
# isfinite 找到可以安全进入模型的数值位置。
valid_mask = np.isfinite(features)
 
# nan_to_num 把异常值收口到明确范围,便于后续记录和告警。
features = np.nan_to_num(features, nan=0.0, posinf=1e6, neginf=-1e6)
 
with np.errstate(divide="ignore", invalid="ignore"):
    # where 避免对分母为 0 的位置执行无意义除法。
    ratio = np.divide(features, 10.0, out=np.zeros_like(features), where=valid_mask)

命令/API/函数
np.linalg.norm / np.matmul / @

说明
向量归一化、余弦相似度和 embedding 检查的最小组合。检索评测、聚类前处理和向量库入库前经常先用 NumPy 做 sanity check。

示例

Python
1
2
3
4
5
6
7
8
9
query = np.random.randn(4, 768).astype(np.float32)
docs = np.random.randn(1000, 768).astype(np.float32)
 
# 按行归一化,避免向量长度主导余弦相似度。
query = query / np.linalg.norm(query, axis=1, keepdims=True)
docs = docs / np.linalg.norm(docs, axis=1, keepdims=True)
 
# [num_query, hidden] @ [hidden, num_docs] -> [num_query, num_docs]
sim = query @ docs.T

命令/API/函数
np.linalg.solve / np.linalg.lstsq

说明
线性方程与最小二乘。工程上优先 solve 或 lstsq,避免显式求逆带来的数值和性能问题。

示例

Python
1
2
3
4
5
6
7
8
A = np.array([[3.0, 1.0], [1.0, 2.0]], dtype=np.float64)
b = np.array([9.0, 8.0], dtype=np.float64)
 
# 求解 A x = b,不显式计算 inv(A)。
x = np.linalg.solve(A, b)
 
# 最小二乘适合离线校准、小规模线性回归或 sanity check。
coef, residuals, rank, singular_values = np.linalg.lstsq(A, b, rcond=None)

命令/API/函数
np.linalg.svd / np.linalg.eigh

说明
SVD 用于低秩近似、PCA 和 embedding/权重矩阵分析; eigh 适合对称矩阵或协方差矩阵。

示例

Python
1
2
3
4
5
6
7
8
X = np.random.randn(1000, 128).astype(np.float64)
 
# centered 后的 SVD 可用于观察特征矩阵的有效秩。
X0 = X - X.mean(axis=0, keepdims=True)
U, S, Vt = np.linalg.svd(X0, full_matrices=False)
 
# 取前 16 个方向构造低维表示。
X16 = X0 @ Vt[:16].T

命令/API/函数
np.save / np.load / np.savez

说明
NumPy 原生数组持久化。 .npy 保存单数组且保留 shape/dtype; .npz 适合小中规模多数组缓存。

示例

Python
1
2
3
4
5
6
7
8
9
10
input_ids = np.zeros((128, 512), dtype=np.int32)
attention_mask = np.ones((128, 512), dtype=np.int8)
labels = np.full((128,), fill_value=0, dtype=np.int64)
 
# npz 适合保存评测缓存或小型样本包。
np.savez("eval_batch.npz", input_ids=input_ids, attention_mask=attention_mask, labels=labels)
 
# mmap_mode="r" 让大数组按需读入,多个进程可共享 OS page cache。
loaded = np.load("eval_batch.npz")
ids = loaded["input_ids"]

命令/API/函数
np.lib.format.open_memmap

说明
创建可增量写入的 .npy memory map。离线抽 embedding 时,可以按 batch 写入磁盘,避免把所有向量留在内存里。

示例

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
num_samples = 1_000_000
hidden = 768
 
# 预先声明完整 shape 和 dtype,后续按切片写入。
emb = np.lib.format.open_memmap(
    "embeddings.npy",
    mode="w+",
    dtype=np.float32,
    shape=(num_samples, hidden),
)
 
start = 0
batch_vec = np.random.randn(1024, hidden).astype(np.float32)
end = start + len(batch_vec)
 
# 每个 batch 写入固定区间,避免不断 concatenate 造成 O(n^2) 拷贝。
emb[start:end] = batch_vec
emb.flush()
NumPy 工程示例:从文本特征到模型输入

下面的例子把 NumPy 的 shape、dtype、padding、mask 和内存布局串起来。它处理的是 token id 已经生成后的阶段:如何把变长序列整理成训练框架可以稳定消费的 batch。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import numpy as np
 
 
def build_token_batch(sequences: list[list[int]], pad_id: int, max_len: int):
    # batch_size 由输入样本数决定,后续所有数组都围绕同一 batch 维构造。
    batch_size = len(sequences)
 
    # input_ids 用 int32 足够承载大多数词表 id,能比 int64 节省一半内存。
    input_ids = np.full((batch_size, max_len), fill_value=pad_id, dtype=np.int32)
 
    # attention_mask 用 int8 表示 0/1,存储体积小,进入框架前可再转目标 dtype。
    attention_mask = np.zeros((batch_size, max_len), dtype=np.int8)
 
    for row, seq in enumerate(sequences):
        # 截断策略必须显式写出,避免长样本偷偷撑爆 batch。
        clipped = np.asarray(seq[:max_len], dtype=np.int32)
 
        # 当前样本有效 token 数,用于同时写 input_ids 和 mask。
        length = clipped.shape[0]
 
        # 只写有效区间;padding 区间保持 pad_id。
        input_ids[row, :length] = clipped
 
        # mask 的 1 表示真实 token,0 表示 padding。
        attention_mask[row, :length] = 1
 
    return input_ids, attention_mask
SciPy

SciPy(Scientific Python)在深度学习训练主循环中出现频率不高,但它在三个位置仍很常见:离线优化与拟合(例如曲线拟合、数值优化)、稀疏矩阵与图算法(构图、归一化、谱方法)、统计分布与检验(评测与数据分析)。工程上,SciPy 更适合被当作“离线数值工具箱”,不适合作为在线训练链路的核心依赖。

常用API

命令/API/函数
scipy.optimize.minimize

说明
通用数值优化入口。做温度缩放、后处理参数拟合、校准曲线估计时,经常直接把目标函数交给它求解。

示例

Python
1
2
3
4
5
6
7
from scipy.optimize import minimize
 
def objective(w):
    # 这里把离线校准误差写成标量目标;minimize 负责外层搜索。
    return ((w[0] * logits + w[1] - labels) ** 2).mean()
 
res = minimize(objective, x0=[1.0, 0.0], method="L-BFGS-B")

命令/API/函数
scipy.sparse.csr_array

说明
压缩稀疏行格式。大规模 one-hot、图邻接矩阵和稀疏特征拼接更适合先停留在 CSR,避免过早转成 dense 阵列。

示例

Python
1
2
3
4
from scipy.sparse import csr_array
 
# 三元组形式先构图,再交给 CSR 做高效存储与乘法。
x = csr_array(([1.0, 1.0, 1.0], ([0, 1, 2], [3, 1, 0])), shape=(3, 4))

命令/API/函数
scipy.sparse.linalg.cg

说明
共轭梯度法求解稀疏线性系统。检索、图正则化和某些二次型问题的离线求解经常会走这一路。

示例

Python
1
2
3
4
from scipy.sparse.linalg import cg
 
# A 通常来自稀疏图 Laplacian 或正规方程;cg 返回近似解与收敛状态。
solution, info = cg(A, b, rtol=1e-6, atol=0.0)

命令/API/函数
scipy.fft.fft / scipy.signal.convolve / scipy.stats

说明
分别对应频域变换、经典信号卷积与统计分布/检验。语音前处理、时间序列特征和实验分析都还会用到这几类接口。

示例

Python
1
2
3
4
5
from scipy import fft, signal, stats
 
spec = fft.fft(waveform)
smoothed = signal.convolve(score, kernel, mode="same")
z = stats.norm.cdf(1.96)
数组元信息与计算语义

训练与推理中常见的性能与正确性问题,经常来自数组元信息被误解:隐式拷贝、错误广播或错误 dtype 会在数据规模上来后被迅速放大。下列四个概念决定了“这块数据在内存里是什么形状、如何被解释、算子如何访问”。

shape

shape 是每一维的长度。训练数据的 shape 规划通常先于模型:batch 维、序列维、通道维的放置会直接影响 broadcasting、拼接策略与 kernel 访问模式。

Python
1
2
3
import numpy as np
 
x = np.zeros((batch, seq_len, hidden), dtype=np.float32)
布局缩写与轴顺序

很多代码不会写“第 0 维是 batch、第 1 维是 channel”,会直接写缩写。视觉、音频和 ONNX/推理框架里最常见的是:

缩写 含义 典型场景
HWC Height / Width / Channel 单张图像在 OpenCV / PIL / NumPy 中的常见布局
CHW Channel / Height / Width 单张图像进入深度学习框架前常转成此布局
NHWC / BHWC Batch / Height / Width / Channel TensorFlow、部分 ONNX 图和前处理流水线常见
NCHW / BCHW Batch / Channel / Height / Width PyTorch 和多数卷积 kernel 的默认语义布局

channels_first 与 channels_last 描述的是通道维放在哪一侧。它们首先是轴顺序约定,其次才会进一步牵涉到底层 stride 和 memory format。

Python
1
2
3
img = np.zeros((224, 224, 3), dtype=np.uint8)    # HWC
x = np.moveaxis(img, -1, 0)                      # CHW
xb = np.expand_dims(x, axis=0).astype(np.float32)  # BCHW
stride

stride 描述“沿每一维移动 1 步,需要在底层 buffer 上跳过多少字节”。它解释了为什么很多 reshape/transpose 是 O(1) 视图,以及为什么某些看似简单的切片会导致后续算子不得不拷贝成 contiguous。stride 也是 DLPack 协议定义的核心之一。

Python
1
2
3
4
5
6
7
import numpy as np
 
a = np.arange(12, dtype=np.int32).reshape(3, 4)
print(a.shape, a.strides)    # (3, 4) (16, 4)  以 int32 为例,步长单位是字节
 
at = a.T
print(at.shape, at.strides)  # (4, 3) (4, 16)  转置后 stride 对调
contiguous、order 与 view/copy

工程里最容易混淆的是:shape 一样,不代表内存布局一样;“没有显式写 copy”,也不代表没有分配新内存。这里至少要区分四件事:

  • view:只改元信息(shape/stride/offset),底层 buffer 仍共享。
  • copy:分配新内存并写入数据,也可称 materialize。
  • C contiguous:按 row-major 方式连续存放, flags["C_CONTIGUOUS"] 为真。
  • F contiguous:按 column-major 方式连续存放, flags["F_CONTIGUOUS"] 为真。

切片通常返回 view,而花式索引(fancy indexing)和布尔索引通常会 materialize。 transpose、 moveaxis 往往只是改 stride; concatenate、 stack、 flatten 则更常真的分配新数组。

Python
1
2
3
4
5
6
x = np.arange(12).reshape(3, 4)
y = x[:, 1:]           # 典型 view
z = x[[0, 2]]          # 典型 copy(fancy indexing)
 
print(y.base is x, z.base is x)
print(x.flags["C_CONTIGUOUS"], x.flags["F_CONTIGUOUS"])

order="C"、 order="F"、 order="K" 等参数,会影响 reshape/ravel/copy 时如何解释或保留现有内存顺序。多数 AI 工程默认围绕 C contiguous 工作;只有明确知道下游需要列优先布局时,才主动引入 Fortran-order。

dtype

dtype 决定了每个元素的解释方式与字节数。训练与推理中,dtype 的作用不止“精度高低”,还包括:IO 体积、缓存命中率、向量化指令路径、以及与下游框架的类型兼容性。实践上常见的约束是:数据管线侧用更紧凑的整型/字节型存储,进入训练前再一次性转换到框架需要的 dtype。

Python
1
2
3
4
import numpy as np
 
# 例:原始 token id 通常用 int32 或更小的无符号整型存储
ids = np.array([1, 2, 3, 4], dtype=np.int32)
broadcasting

broadcasting 是“不同 shape 的数组做逐元素运算时,如何对齐维度并隐式扩展”。它是把 Python 循环消掉的关键机制,但也可能引入隐藏的大中间张量或错误对齐。广播规则的工程实践通常围绕两件事:显式插入维度(None/newaxis)与显式对齐最后几维。

Python
1
2
3
4
5
6
7
8
9
10
import numpy as np
 
# (B, T, H) - (H,) -> (B, T, H)
X = np.random.randn(2, 3, 4).astype(np.float32)
mu = X.mean(axis=(0, 1))           # (H,)
X0 = X - mu                        # broadcasting
 
# 显式插维更直观
mu2 = mu[None, None, :]            # (1, 1, H)
X1 = X - mu2

广播只在“维度相等”或“其中一边等于 1”时成立。写复杂代码时,推荐先把意图写成显式插维,让每个对齐维度都能从代码里直接读出来。

Python
1
2
3
img = np.zeros((2, 224, 224, 3), dtype=np.float32)  # NHWC
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
img = img - mean[None, None, None, :]
表格与列式数据
Pandas

Pandas 在训练/推理工程里的定位更接近“数据分析与小中规模表格处理”。它擅长做数据清洗、统计分析、对齐 join、特征表合并,以及输出可审计的中间结果(CSV/Parquet)。当数据规模接近或超过内存时,工程上通常会转向 Arrow/Polars 的流式与列式路径。

常用API

命令/API/函数
pd.read_csv

说明
训练前最常见的“原始表输入”入口。高频实践是把 dtype、时间列解析和缺失值策略显式写死,避免不同机器自动推断出不同 schema。

示例

Python
1
2
3
4
5
6
7
import pandas as pd
 
df = pd.read_csv(
    "events.csv",
    dtype={"user_id": "int64", "label": "int8"},
    parse_dates=["ts"],
)

命令/API/函数
pd.read_parquet

说明
读取 Parquet 到 DataFrame,适合分析型脚本和中小规模表处理;大规模数据更适合转向 Polars 或 Arrow dataset。

示例

Python
1
2
import pandas as pd
df = pd.read_parquet("train.parquet")

命令/API/函数
DataFrame.merge / merge_asof

说明
把用户表、曝光表、标签表按键或按时间邻近对齐,是特征表组装里的主入口。 merge_asof 特别适合“找某个时间点之前最近一次状态”的时序特征。

示例

Python
1
2
3
4
5
6
7
8
9
10
feat = clicks.merge(users, on="user_id", how="left")
 
# merge_asof 要求时间列先排序,再按最近历史状态对齐。
feat = pd.merge_asof(
    feat.sort_values("ts"),
    profile.sort_values("ts"),
    on="ts",
    by="user_id",
    direction="backward",
)

命令/API/函数
DataFrame.groupby().agg()

说明
按实体聚合统计特征。它是表格特征工程里最常见的“从明细表到样本表”的变换。

示例

Python
1
2
3
4
user_feat = (
    df.groupby("user_id", as_index=False)
      .agg(clicks=("clicked", "sum"), avg_score=("score", "mean"))
)

命令/API/函数
pd.to_datetime / pd.date_range

说明
统一时间列语义。训练集切分、回测窗口和特征对齐如果没有先把时间转成明确 dtype,后面几乎一定会出错。

示例

Python
1
2
df["ts"] = pd.to_datetime(df["ts"], utc=True)
calendar = pd.date_range("2026-01-01", periods=7, freq="D", tz="UTC")

命令/API/函数
DataFrame.to_parquet

说明
把清洗后的中间表、特征表或评测样本落盘为列式文件,便于后续批处理复用。

示例

Python
1
df.to_parquet("out.parquet", index=False)

命令/API/函数
DataFrame.to_numpy

说明
把表格列送进 NumPy/模型前处理链。copy=False 不保证零拷贝,混合 dtype 往往仍会触发类型提升和拷贝。

示例

Python
1
x = df[["a", "b"]].to_numpy(dtype="float32", copy=False)
Polars

Polars 的优势在于其 lazy 执行与 streaming:把计算表达成查询计划,先做优化(projection/predicate pushdown),再以批方式执行。对训练数据预处理而言,这意味着可以在不把全量数据 materialize 到内存的情况下完成筛选、投影、采样、分桶、写回 Parquet。

常用API

命令/API/函数
pl.scan_parquet

说明
lazy 方式扫描 Parquet,不立即 materialize

示例

Python
1
2
import polars as pl
lf = pl.scan_parquet("data/*.parquet")

命令/API/函数
pl.col

说明
Polars 表达式系统的核心入口。列选择、类型转换、条件分支和聚合,几乎都从 pl.col(...) 开始。

示例

Python
1
2
3
4
5
6
expr = (
    pl.col("score")
      .cast(pl.Float32)
      .fill_null(0.0)
      .alias("score_f32")
)

命令/API/函数
LazyFrame.with_columns / group_by().agg()

说明
在 lazy 计划里做表达式变换和聚合。很多“加特征列再按实体汇总”的任务都能在这一层一次写完。

示例

Python
1
2
3
4
5
agg = (
    lf.with_columns(pl.col("score").cast(pl.Float32))
      .group_by("user_id")
      .agg(pl.col("score").mean().alias("avg_score"))
)

命令/API/函数
LazyFrame.collect / sink_parquet

说明
前者把 lazy 计划真正执行成内存结果,后者把结果直接落到 Parquet。大规模离线预处理更常优先 sink_parquet,避免中间结果全落进 Python 进程内存。

示例

Python
1
2
3
4
5
6
7
out = (
    lf.filter(pl.col("lang") == "zh")
      .select(["text", "label"])
      .collect(streaming=True)
)
 
lf.sink_parquet("out.parquet")
PyArrow

PyArrow 是 Python 对 Arrow 内存格式与生态能力的主要入口。对 AI 训练/推理工程而言,Arrow 的核心价值是统一列式内存表示,并在 Pandas、Parquet、HF Datasets 与各种 IPC 路径之间提供高效桥梁。训练数据若最终要落成 Arrow/Parquet,通常建议在预处理阶段就尽量保留 Arrow Table / RecordBatch 语义,避免频繁在 Python 对象列表与 DataFrame 之间来回转换。

常用API

命令/API/函数
pa.table / pa.array

说明
把 Python 容器显式转成 Arrow 列式对象。只有真正进入 Arrow 数组或表,后续 schema、Parquet、IPC 和 dataset 能力才会完整接上。

示例

Python
1
2
3
4
import pyarrow as pa
 
ids = pa.array([1, 2, 3], type=pa.int32())
tbl = pa.table({"x": [1, 2, 3], "y": ["a", "b", "c"]})

命令/API/函数
pa.schema / pa.field

说明
显式声明列名和类型契约。训练数据一旦要跨作业、跨机器和跨语言复用,schema-first 往往比“让库自动猜类型”稳定得多。

示例

Python
1
2
3
4
schema = pa.schema([
    pa.field("text", pa.string()),
    pa.field("label", pa.int8()),
])

命令/API/函数
pyarrow.parquet.read_table

说明
读 Parquet 为 Table

示例

Python
1
2
import pyarrow.parquet as pq
tbl = pq.read_table("train.parquet", columns=["text", "label"])

命令/API/函数
pyarrow.dataset.dataset

说明
把一批分区文件组织成统一 dataset 入口。训练语料若已经按日期、语言、split 分目录存放,这通常比手写 glob 再逐文件读取更干净。

示例

Python
1
2
3
import pyarrow.dataset as ds
 
train_ds = ds.dataset("corpus/", format="parquet", partitioning="hive")

命令/API/函数
pyarrow.compute

说明
对 Arrow 列直接做向量化变换,避免为了一个简单筛选或 cast 把数据先搬回 Pandas。

示例

Python
1
2
3
4
import pyarrow.compute as pc
 
mask = pc.equal(tbl["label"], 1)
pos = tbl.filter(mask)

命令/API/函数
pyarrow.parquet.write_table

说明
把列式结果稳定写回 Parquet,作为下游训练或评估作业的输入产物。

示例

Python
1
2
3
import pyarrow.parquet as pq
 
pq.write_table(tbl, "train.parquet")

命令/API/函数
Table.to_pandas

说明
转 Pandas

示例

Python
1
df = tbl.to_pandas()
Parquet

Parquet 是面向分析与批处理的列式文件格式。训练数据落盘选择 Parquet 的理由通常是:压缩比高、列裁剪成本低、能按列读取并减少 IO、天然支持 row group 作为大文件分块单位。工程上最常见的实践是:把可训练字段放在少数列里,并显式按任务选择 columns 读取,避免把无关字段搬进内存。

Python
1
2
3
4
5
6
import pyarrow.parquet as pq
 
tbl = pq.read_table(
  "train.parquet",
  columns=["text", "label"],   # 列裁剪
)
PyArrow IPC 与 memory_map

IPC(Inter-Process Communication)格式用于把 Arrow 的内存表示序列化为文件或流,并支持高效读写。对训练数据管线而言,一个关键工程点是:若输入源支持零拷贝读取(例如 memory map),则读出来的 batch 可以保持零拷贝路径,从而显著降低 CPU 端的内存分配与拷贝开销。

Python
1
2
3
4
5
6
7
8
9
10
import pyarrow as pa
 
# 以 memory map 方式打开文件,避免额外 read() 复制
source = pa.memory_map("dataset.arrow", mode="r")
reader = pa.ipc.open_file(source)
 
# IPC 文件通常按 RecordBatch 组织;这里只取第 0 个 batch 演示零拷贝读取路径。
batch0 = reader.get_batch(0)
# 需要和下游 Arrow API 对接时,再把若干 batch 重新拼成 Table。
tbl = pa.Table.from_batches([batch0])
序列化与权重格式
通用序列化格式

训练与推理系统里最常见的序列化对象是:配置、元数据、索引与权重。配置层常见 JSON/YAML;权重层需要关注安全性与加载速度;跨进程/跨服务通信则常用 protobuf 这类 IDL 驱动格式。

JSON

JSON 适合可读性强的元数据与小体量配置,常用于数据集 manifest、评测记录与简单索引。

YAML

YAML 常用于训练配置,但它过于通用,本文只把它视为“配置载体”。具体配置系统(Hydra/OmegaConf)在后续章节展开。

pickle

pickle 能序列化 Python 对象,但它不适合用于不可信来源的权重与模型文件,因为反序列化会执行对象构造逻辑。工程上如果需要“安全的张量权重格式”,通常会优先选择 safetensors。

protobuf

protobuf 的优势是 schema 驱动与跨语言:用 .proto 定义消息结构,由 protoc 生成多语言代码。它常用于模型服务、日志/Tracing、任务队列与数据交换协议。

  1. 定义 schema,例如在 .proto 中声明 message Foo { ... }。
  2. 运行 protoc --python_out=. foo.proto 生成 Python 绑定代码。
  3. 在 Python 中导入 foo_pb2,再按生成的消息类读写数据。
safetensors

safetensors 是面向模型权重的安全、快速格式,设计目标是替代基于 pickle 的不安全权重存储。它的工程优势主要体现在三点:加载速度、零拷贝读取路径、以及避免反序列化执行任意代码。若训练产物需要在多环境分发或上线部署,safetensors 往往是默认优先选项。

常用API

命令/API/函数
safetensors.torch.save_file

说明
保存 tensor dict

示例

Python
1
2
3
4
5
from safetensors.torch import save_file
import torch
 
tensors = {"w": torch.randn(2, 3)}
save_file(tensors, "model.safetensors")

命令/API/函数
safetensors.torch.load_file

说明
加载为 CPU tensor dict

示例

Python
1
2
3
from safetensors.torch import load_file
 
tensors = load_file("model.safetensors")

命令/API/函数
safetensors.safe_open

说明
按需读取,支持只取部分 key

示例

Python
1
2
3
4
from safetensors import safe_open
 
with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    w = f.get_tensor("w")
GGUF

GGUF(GGML Universal File)是 llama.cpp 生态的权重文件格式,目标是单文件、可扩展、可 memory-map,并携带足够的 KV 元数据支持推理 runtime 直接加载。若部署路线包含 llama.cpp / Ollama 一类本地推理栈,训练产物通常需要在 Hugging Face 权重与 GGUF 之间做一次转换与量化。

Shell
1
2
# llama.cpp 仓库提供 convert_*.py 脚本把 Hugging Face 权重转换为 GGUF
python convert_hf_to_gguf.py --outfile out.gguf /path/to/hf_model_dir
数据管线与预处理组件

训练与推理系统的性能瓶颈经常不在模型前向,而在输入:数据从磁盘/对象存储进入进程、被解码与清洗、被分词与组 batch、再进入 GPU。一个可用的数据管线需要同时满足三件事:吞吐(喂满设备)、一致性(可复现、可回放)、可运维(能增量、能恢复、能追溯)。

本节把数据管线拆成四层:读取抽象(Dataset/DataLoader)、存储与后端(Arrow/Parquet/WebDataset/LMDB/HDF5/mmap)、文本入口(tokenizer 与中文预处理)、以及离线预处理模式(多进程 + shard 流式写入)。每一层都以“如何写代码把数据喂进训练/推理”为主线。

数据集抽象与读取方式
PyTorch 数据接口

PyTorch 的数据读取围绕两个 Dataset 协议展开:map-style(可随机访问)与 iterable-style(顺序流式)。DataLoader 负责把 Dataset 变成可迭代的 batch 流,并提供多进程 worker、prefetch、pin memory 等机制。实践中,Dataset 负责“怎样得到一个样本”,DataLoader 负责“怎样并行、怎样组批、怎样把样本送进设备”。

安装:

Shell
1
pip install torch
Dataset

map-style Dataset 的约束很简单:实现 __getitem__ 与 __len__。它适合“样本天然有索引”的存储,例如:一个样本一行的 Parquet/Arrow、固定条目 LMDB、按文件名索引的图像文件夹。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from torch.utils.data import Dataset
 
class MyDataset(Dataset):
    def __init__(self, paths):
        # map-style Dataset 通常先把“样本索引”准备好,真正读取留给 __getitem__
        self.paths = paths
 
    def __len__(self):
        # DataLoader 需要长度信息来做 epoch 边界、shuffle 和 sampler 计算
        return len(self.paths)
 
    def __getitem__(self, idx):
        # idx 只是索引入口,真正的样本内容可以来自文件、KV 或远端缓存
        path = self.paths[idx]
        with open(path, "rb") as f:
            # 这里读取原始二进制;解码可以放在这里,也可以延后到 collate_fn
            blob = f.read()
        # 返回 dict/tuple 都可以,关键是下游 collate_fn 知道怎么拼 batch
        return {"path": path, "blob": blob}
IterableDataset

IterableDataset 只要求实现 __iter__,更适合训练数据远大于本地磁盘、需要顺序扫描或在线生成的场景(对象存储流式、Kafka/队列、WebDataset tar 流、动态合成数据)。使用多 worker 时,必须自行做切分(worker shard),否则每个 worker 都会重复遍历同一份流。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch
from torch.utils.data import IterableDataset
 
class LineStream(IterableDataset):
    def __init__(self, filename):
        # iterable-style 数据集更像“流”,通常只保存数据源句柄或路径
        self.filename = filename
 
    def __iter__(self):
        # 多 worker 下需要知道自己是第几个 worker
        worker = torch.utils.data.get_worker_info()
        wid = 0 if worker is None else worker.id
        wnum = 1 if worker is None else worker.num_workers
 
        with open(self.filename, "r", encoding="utf-8") as f:
            for i, line in enumerate(f):
                if (i % wnum) != wid:
                    continue  # 用取模切分行号,保证多个 worker 不会重复消费同一行
                # 逐条产出样本,避免一次性把整份文件读进内存
                yield {"text": line.rstrip("\n")}
DataLoader

DataLoader 的关键价值在于把“批处理策略”和“并行读取策略”显式化。它的典型参数包括: batch_size、 shuffle、 num_workers、 collate_fn、 pin_memory、 prefetch_factor、 persistent_workers。这些参数共同决定吞吐、延迟、内存占用与稳定性。

命令/API/函数
torch.utils.data.DataLoader

说明
把 Dataset/IterableDataset 变成 batch 流。关键参数: batch_size / num_workers / collate_fn / pin_memory

示例

Python
1
2
3
4
5
6
7
8
9
10
11
from torch.utils.data import DataLoader
 
loader = DataLoader(
    dataset,
    batch_size=64,           # 每次交给模型 64 条样本;它直接影响吞吐、显存占用和梯度噪声
    shuffle=True,            # 训练集通常打乱,避免数据顺序相关性影响收敛
    # 8 个 worker 并发准备样本;太小喂不满 GPU,太大又会放大 CPU/内存开销
    num_workers=8,
    pin_memory=True,         # batch 会先落在 page-locked 内存,拷到 GPU 时更快
    persistent_workers=True, # 跨 epoch 复用 worker 进程,减少频繁拉起带来的抖动
)

命令/API/函数
torch.utils.data.get_worker_info

说明
在 IterableDataset 中分片。关键参数:worker.id / worker.num_workers

示例

Python
1
2
3
worker = torch.utils.data.get_worker_info()
wid = 0 if worker is None else worker.id
wnum = 1 if worker is None else worker.num_workers
复现性:worker_init_fn + generator

DataLoader 的随机性并不只来自主进程的 torch.manual_seed。一旦打开多 worker,每个 worker 还会各自使用 Python、NumPy 和 PyTorch 的随机数源。官方文档给出的稳定做法是:主进程显式传入一个 torch.Generator,再在 worker_init_fn 里把 worker 级种子同步到 NumPy 和 Python 随机库。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import random
 
import numpy as np
import torch
from torch.utils.data import DataLoader
 
def seed_worker(worker_id):
    # PyTorch 已经给每个 worker 分配了独立种子;
    # 这里把它同步给 NumPy 和 Python random,避免三套 RNG 彼此漂移。
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
 
# 主进程自己的 DataLoader 随机源;
# shuffle、random_split、带 generator 的 sampler 都可以围绕它保持可复现。
g = torch.Generator()
g.manual_seed(3407)
 
loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,                # 打乱顺序时最需要把随机源显式固定下来
    num_workers=8,
    worker_init_fn=seed_worker,  # 负责把 worker 内部的 NumPy/Python RNG 也种好
    generator=g,                 # 主进程侧的采样顺序由这份 generator 控制
)

分布式训练下,这套做法仍然成立,但还需要配合 DistributedSampler.set_epoch(epoch)。原因是多 rank 不只要“每次都随机”,还要“所有 rank 对同一轮 shuffle 的理解一致”。

collate_fn:把样本级对象真正拼成 batch

DataLoader 默认的拼 batch 逻辑只适合“每个样本都已经是规则张量”的情况。文本、语音、检测框这类变长任务里,真正决定 batch 结构的通常是自定义 collate_fn:padding 多长、保留哪些原始字段、哪些字段进 GPU、哪些字段只保留给评估与对齐。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
from torch.nn.utils.rnn import pad_sequence
 
def collate_batch(samples):
    # 先把每条样本里真正要进模型的 token 序列取出来。
    input_ids = [torch.tensor(s["input_ids"], dtype=torch.long) for s in samples]
    labels = torch.tensor([s["label"] for s in samples], dtype=torch.long)
 
    # 变长序列在这里统一 pad;不要在 Dataset.__getitem__ 里把所有样本都 pad 到全局最大长度。
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    # attention_mask 应和 padding 规则同步生成,否则模型会把 pad token 也当成有效上下文。
    attention_mask = input_ids.ne(0).long()
 
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }
 
loader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=8,
    collate_fn=collate_batch,    # 真正定义“一个 batch 长什么样”
    pin_memory=True,
)

工程边界非常清楚:Dataset 负责拿到单条样本,collate_fn 负责把一批样本拼成模型输入。如果把 padding、随机 mask、目标对齐全部塞进 __getitem__,后续切 batch 策略、切 tokenizer 或切任务头时会很难维护。

Hugging Face Datasets

Datasets 把“数据集”抽象成 Arrow-backed 的 Dataset/ DatasetDict,并提供统一的加载(Hub/本地/通用 builder)、变换(map/filter)、以及落盘(save_to_disk / parquet)的工具链。它在大模型训练管线中的典型用法是:用一次离线 map 把清洗与分词做掉,输出可 memory-map 的 Arrow/Parquet,再用 PyTorch DataLoader 做高吞吐训练。

安装:

Shell
1
pip install datasets

加载:支持 Hub 数据集、目录内 CSV/JSON/Parquet 文件、以及通用 builder(例如 json / parquet / webdataset)。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from datasets import load_dataset
 
# 直接从 Hub 取公开数据集,适合快速验证预处理与训练脚本
ds = load_dataset("allenai/c4", "en", split="train")
 
ds = load_dataset(
    "parquet",
    data_files={"train": ["./data/train-*.parquet"]},
)["train"]  # 本地 parquet 目录是最常见的离线预处理产物之一
 
ds_stream = load_dataset(
    "json",
    data_files="s3://bucket/data.jsonl",
    streaming=True,
)["train"]  # streaming=True 不把全量数据落本地,适合超大规模语料或对象存储直读

命令/API/函数
datasets.load_dataset

说明
加载 Hub / 本地 / 通用 builder。关键参数: data_files / split / streaming / num_proc

示例

Python
1
2
3
4
5
6
ds = load_dataset(
    "json",
    data_files="train.jsonl",  # 告诉 builder 真正要读取哪份原始文件
    split="train",             # 直接拿 train split,省去再从 DatasetDict 里二次索引
    num_proc=8,                # 解析 JSONL 时开多进程,提高大文件导入速度
)

命令/API/函数
Dataset.map

说明
清洗/分词/特征工程。关键参数: batched / num_proc / remove_columns

示例

Python
1
2
3
4
5
6
7
8
9
10
11
def tok(batch):
    # map 阶段把原始文本转成 token id,训练时就不用每 step 再做字符串处理
    return tokenizer(batch["text"])
 
ds2 = ds.map(
    tok,
    batched=True,               # tokenizer 批量跑通常更快,也更接近真实训练吞吐
    num_proc=8,                 # 把 CPU 密集的分词并行化
    # 处理完就删掉原始文本列,减少后续数据体积和 batch 搬运成本
    remove_columns=["text"],
)

命令/API/函数
Dataset.save_to_disk

说明
保存 Arrow 数据集目录。关键参数:输出目录

示例

Python
1
ds2.save_to_disk("./out/ds_tok")

命令/API/函数
datasets.load_from_disk

说明
恢复已保存数据集。关键参数:输入目录

示例

Python
1
2
from datasets import load_from_disk
ds2 = load_from_disk("./out/ds_tok")
Features:把 schema 明确写出来

当数据不再只是“纯文本 + label”时,显式声明 Features 会比依赖自动推断更稳。它的价值在于:列类型固定、序列列有明确嵌套结构、图像/音频列知道该如何延迟解码,后续 cast 与格式转换也更可控。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from datasets import Audio, ClassLabel, Features, Sequence, Value
 
features = Features(
    {
        # 原始文本列;后续仍可保留用于 debug 或重跑 tokenizer
        "text": Value("string"),
        # 把离散标签显式收成受控枚举,避免字符串标签到处漂
        "label": ClassLabel(names=["neg", "pos"]),
        # token id 序列是“变长整型列表”,并非随手塞进 object 列
        "input_ids": Sequence(Value("int32")),
        # 音频列会在访问时自动解码/重采样到 16kHz
        "audio": Audio(sampling_rate=16000),
    }
)
with_format / set_format / with_transform:决定样本怎么交给训练框架

Datasets 的底层存储与“取一条样本时返回什么对象”是两件事。 with_format 返回一个带格式视图的新数据集, set_format 则原地修改; with_transform 更进一步,允许在取样时做惰性张量化或轻量预处理。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
ds_torch = ds.with_format(
    "torch",
    # 只把训练真正需要的列转成 tensor,减少无关列搬运
    columns=["input_ids", "attention_mask", "label"],
)
 
# 原地修改;适合临时实验,不适合在多人脚本里到处传同一个对象
ds.set_format("torch", columns=["label"])
 
def encode_on_the_fly(batch):
    # with_transform 适合“想保持原始文本,又不想提前全量落盘 token”的场景
    return tokenizer(batch["text"], truncation=True)
 
ds_lazy = ds.with_transform(encode_on_the_fly)
列操作与切分:remove / rename / cast / select / filter / shuffle

大多数真实数据清洗都离不开列级操作。这里的关键不仅 API 名字,还理解哪些操作在 schema 演进时最常见:删掉原始大字段减小数据体积、统一字段命名、把 label 从字符串 cast 到枚举、按稳定 seed 打乱并切分训练/验证集。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from datasets import ClassLabel, Value
 
# 先统一字段命名,后续 tokenizer/map 才不用为不同数据源写分支
ds = ds.rename_column("sentence", "text")
# 清洗后不再需要的大字段尽早删掉,减少 Arrow/Parquet 体积
ds = ds.remove_columns(["raw_html"])
# 把字符串或整数标签收敛成受控类别空间
ds = ds.cast_column("label", ClassLabel(names=["neg", "pos"]))
# 显式修正数值精度,避免跨库时 int32/int64 不一致
ds = ds.cast_column("idx", Value("int64"))
# 训练/验证切分前先固定随机种子,保证实验可复现
ds = ds.shuffle(seed=42)
# 把切分动作放进数据管线,避免散落在业务代码里
splits = ds.train_test_split(test_size=0.02, seed=42)
train_ds = splits["train"].select(range(100000))       # 只抽一个稳定子集做快速回归测试
# filter 更适合写成显式数据约束,训练循环只消费已清洗样本
train_ds = train_ds.filter(lambda x: len(x["text"]) > 0)
多源拼接与 cache/fingerprint 语义

预训练、SFT 和数据增强任务经常要把多份语料混到一起。Datasets 已提供 concatenate_datasets 与 interleave_datasets,前者是顺序拼接,后者更适合多语料轮转采样。另一件必须知道的事是 fingerprint:map/filter 的缓存是否复用,取决于数据内容、函数与参数共同生成的指纹,文件名相同不代表缓存一定可复用。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from datasets import concatenate_datasets, interleave_datasets
 
mix = concatenate_datasets([news_ds, qa_ds])           # 适合直接做“先 A 后 B”的顺序拼接
round_robin = interleave_datasets(
    # 更适合多源混训,让不同语料在样本流里交错出现
    [news_ds, qa_ds],
    stopping_strategy="all_exhausted",
)
 
tokenized = mix.map(
    tok,
    batched=True,
    # 默认会尽量复用缓存;函数或参数一变,fingerprint 也会变
    load_from_cache_file=True,
)

排查“为什么 map() 没有重跑”时,先看缓存目录与 fingerprint;排查“为什么无缘无故重跑了”时,也先看函数闭包、参数和依赖对象是否变了。

大规模数据读取后端

当单机磁盘与单个文件格式无法满足吞吐或并行度时,训练数据会落到“更工程化的后端”。最常见的三类:tar shard(WebDataset)、KV store(LMDB)、块存储/层次结构(HDF5)。它们的核心是“更匹配训练读取模式”。

WebDataset

WebDataset 以 tar shard 作为基本载体,强调流式读取与链式 pipeline。它常用于大规模图像/视频/音频/多模态训练:样本被打包成许多 tar 文件(shard),训练时按 shard 流式拉取、解码、组 batch。安装:

Shell
1
pip install webdataset

命令/API/函数
webdataset.WebDataset

说明
构建数据管线(DataPipeline + 流式操作)

示例

Python
1
2
3
4
5
6
7
8
9
10
import webdataset as wds
 
dataset = (
    # shard 模式让数据集按 tar 文件流式展开,避免训练前把全部样本解包到本地。
    wds.WebDataset("shards/data-{000000..000127}.tar")
      .shuffle(10000)
      .decode("pil")
      .to_tuple("jpg", "txt")
      .batched(64)
)

命令/API/函数
FluidInterface.with_epoch

说明
限制一个 epoch 的样本数(类似 islice)

示例

Python
1
dataset = dataset.with_epoch(1_000_000)
ShardWriter:把离线数据写成 tar shard

WebDataset 的写入侧通常比读取侧更值得标准化。常见约定是:同一个样本共享一份 __key__,不同模态或字段作为不同扩展名文件写进 tar。这让读取侧可以按 key 自动把图片、文本、JSON 元信息重新组回一条样本。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import io
import webdataset as wds
 
with wds.ShardWriter("shards/data-%06d.tar", maxcount=10000) as sink:
    for i, sample in enumerate(samples):
        sink.write(
            {
                "__key__": f"{i:08d}",       # 同一条样本的所有字段都靠这个 key 关联
                "txt": sample["text"],       # 文本字段直接写成 txt/json 之类扩展名
                "json": sample["meta"],      # 元数据可单独保留,方便检索和排错
                # 图像/音频/视频通常写原始字节,解码推迟到训练读取阶段
                "jpg": sample["image_bytes"]
            }
        )
分布式读取:split_by_node / split_by_worker

WebDataset 在多机多 worker 训练里最容易犯的错误是“每个 worker 都在读同一批 tar”。解决思路和 IterableDataset 类似:先按节点切 shard,再按 worker 切 shard。否则表面上吞吐很高,实际上样本被重复消费。

Python
1
2
3
4
5
6
7
dataset = (
    wds.WebDataset("shards/data-{000000..000127}.tar")
      .split_by_node()   # 先按节点分 shard,避免多机重复消费同一批 tar
      .split_by_worker() # 再按 DataLoader worker 进一步切分
      .decode("pil")
      .to_tuple("jpg", "txt")
)
LMDB

LMDB 是 memory-mapped 的 KV store,优势是读性能稳定、并发读友好,适合“样本就是 key->value”的训练数据(尤其是大量小对象)。LMDB 的关键约束是事务:读写必须在 transaction 内进行,并且从 LMDB 返回的值可能直接指向 mmap 区域,transaction 结束后不可继续使用该指针。

安装:

Shell
1
pip install lmdb

命令/API/函数
lmdb.open

说明
打开/创建环境(Environment)

示例

Python
1
2
3
4
5
6
7
8
import lmdb
env = lmdb.open(
    "./data.lmdb",
    map_size=1024**4,  # 预留 1TB 虚拟地址空间;LMDB 达到上限后需要手动扩容
    subdir=False,      # 把 ./data.lmdb 当成单文件模式的 LMDB 环境
    readonly=False,    # 写入场景必须关闭只读;离线建库完再切只读更稳
    lock=True,         # 多进程写入需要文件锁保证事务一致性
)

命令/API/函数
Environment.begin

说明
开启事务(Transaction)

示例

Python
1
2
with env.begin(write=True) as txn:
    txn.put(b"k", b"v")

命令/API/函数
Transaction.get / put

说明
读写 KV

示例

Python
1
2
with env.begin(write=False) as txn:
    v = txn.get(b"k")

命令/API/函数
Transaction.cursor

说明
迭代遍历

示例

Python
1
2
3
4
with env.begin() as txn:
    with txn.cursor() as cur:
        for k, v in cur:
            ...
HDF5

HDF5 适合块状数组与层次结构数据。训练里常见于科学计算数据、时序/医疗影像、以及需要 chunked 存储与压缩的场景。Python 侧最常用的是 h5py。安装:

Shell
1
pip install h5py

命令/API/函数
h5py.File

说明
打开/创建文件

示例

Python
1
2
import h5py
f = h5py.File("data.h5", "r")

命令/API/函数
Group.create_dataset

说明
创建 dataset(可设 chunks/compression)

示例

Python
1
2
3
4
5
6
7
d = f.create_dataset(
    "x",
    shape=(0, 4096),        # 初始为空;第一维表示样本数,后续按 batch 追加
    maxshape=(None, 4096),  # 第一维允许无限增长,否则 resize 会失败
    chunks=(1024, 4096),    # 以 1024 行为一个块,兼顾顺序写入与顺序读取
    dtype="int32",          # token id/离散特征常用整数类型,避免默认 int64 浪费空间
)

命令/API/函数
Dataset.resize

说明
追加写入(配合 maxshape)

示例

Python
1
2
3
n = d.shape[0]
d.resize((n + batch.shape[0], 4096))
d[n:] = batch

命令/API/函数
Group.require_dataset

说明
“如果不存在就创建,存在就复用”的 dataset 入口。长流程作业做增量写入或多阶段预处理时,比手写存在性判断更稳。

示例

Python
1
2
3
4
5
6
7
tokens = f.require_dataset(
    "tokens",
    shape=(0, 4096),
    maxshape=(None, 4096),
    chunks=(1024, 4096),
    dtype="int32",
)

命令/API/函数
Dataset.asstr / h5py.string_dtype

说明
把字符串 dataset 明确当作文本处理,并统一文本元数据、路径和标签名的字符串语义,减少跨 Python 版本和跨平台读取问题。

示例

Python
1
2
3
4
5
6
7
meta = f.create_dataset(
    "doc_id",
    data=["a", "b", "c"],
    dtype=h5py.string_dtype(encoding="utf-8"),
)
 
doc_ids = meta.asstr()[:]
内存映射

memory map 的价值在于把“磁盘 IO + 反序列化”变成“按页缺页加载”:进程只在访问到某段数据时才触发读取,并允许多个进程共享同一份文件缓存。它常用于 Arrow IPC、NumPy 的大数组、以及只读数据集的多 worker 读取。

Python
1
2
3
4
5
import numpy as np
 
# 以 memmap 读一个巨大 float32 数组(示例)
arr = np.memmap("x.bin", dtype=np.float32, mode="r")
# arr[i] 的访问才会触发对应页的加载
MosaicML Streaming

MosaicML Streaming 把“超大语料存放在对象存储上、训练时按需拉取”做成专用库。它的定位和 WebDataset 有交集,但更强调多节点训练时的正确性、确定性与 just-in-time 混合采样。对于“数据不想整份预拉到本地 NVMe”的预训练任务,它是值得单独了解的一条路线。

Shell
1
pip install mosaicml-streaming

Python
1
2
3
4
5
6
7
8
9
from streaming import StreamingDataset
 
dataset = StreamingDataset(
    remote="s3://bucket/my-corpus",  # 远端对象存储是真正的数据源,训练时按 shard 增量拉取
    local="/tmp/streaming-cache",    # 本地目录只做工作集缓存,不保存全量镜像
    shuffle=True,                    # 把远端 shard 流按训练需要做确定性 shuffle
    # 某些 streaming 路线会把 batch 语义前移到数据层,便于控制采样顺序
    batch_size=8,
)
文本分词与 tokenizer 组件

分词(Tokenization)有两种工程形态:在线分词(推理时对用户输入分词)与离线分词(训练前把语料转成 token id)。离线分词的目标是把训练阶段的 CPU 开销外移:训练时直接读取 input_ids/ attention_mask 之类张量,避免每步都做字符串处理。

tokenizers

Tokenizers 是 Rust 实现的 tokenizer 库,提供训练、编码、解码以及 padding/truncation 等预处理步骤。它面向生产:同一套 tokenizer 可以被训练脚本、离线预处理作业与线上服务复用。

安装:

Shell
1
pip install tokenizers

命令/API/函数
tokenizers.Tokenizer

说明
tokenizer 管线对象

示例

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
 
tok = Tokenizer(BPE(unk_token="[UNK]"))  # 先定义底层分词模型;未知词会回退到 [UNK]
tok.pre_tokenizer = Whitespace()         # 先按空白切粗粒度片段,再在片段内学习 BPE merge
trainer = BpeTrainer(
    vocab_size=32000,  # 词表大小直接影响 embedding 尺寸、OOV 粒度和序列长度
    # 训练阶段先把特殊 token 固定进词表,避免后面再补导致 id 漂移
    special_tokens=["[UNK]", "[PAD]", "[BOS]", "[EOS]"],
)
# 从语料文件训练 tokenizer;产出的 merge 规则和 vocab 可供训练/推理共用
tok.train(["corpus.txt"], trainer)

命令/API/函数
Tokenizer.encode / encode_batch

说明
编码为 token id

示例

Python
1
2
3
4
5
enc = tok.encode("hello world")
ids = enc.ids
 
encs = tok.encode_batch(["a", "b"])
batch_ids = [e.ids for e in encs]

命令/API/函数
Tokenizer.decode / decode_batch

说明
解码

示例

Python
1
2
text = tok.decode(ids)
texts = tok.decode_batch(batch_ids)
tokenizer pipeline:normalizer / pre-tokenizer / post-processor

现代 tokenizer 并非“一个黑盒 encode()”。它通常由四段组成:normalizer 负责文本标准化,pre-tokenizer 负责粗切分,model 负责真正的子词编码,post-processor 负责补特殊 token 与 sequence pair 结构。把这条 pipeline 拆开理解,排查 token 边界、special token 或 pair 输入错乱时会快很多。

Python
1
2
3
4
5
6
7
8
9
10
11
12
from tokenizers.normalizers import NFKC
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.processors import TemplateProcessing
 
tok.normalizer = NFKC()  # 先做 Unicode 规范化,减少全角/兼容字符造成的词表噪声
tok.pre_tokenizer = ByteLevel()  # ByteLevel 适合需要 byte-level 稳定覆盖的 tokenizer 路线
tok.post_processor = TemplateProcessing(
    # 单句输入如何补特殊 token,由 post-processor 决定
    single="[BOS] $A [EOS]",
    pair="[BOS] $A [EOS] $B:1 [EOS]:1",    # :1 表示 sequence B 的 type_id=1
    special_tokens=[("[BOS]", 1), ("[EOS]", 2)],
)
Encoding.offsets:span 对齐与可解释性排查

做 NER、高亮、引用定位、chunk 边界回溯时,单看 token id 不够,还需要字符级对齐信息。 Encoding.offsets 与相关映射方法,正是把“token 第几个”重新映射回“原文本的哪一段”。

Python
1
2
3
enc = tok.encode("Hello, how are you?")
start, end = enc.offsets[2]  # offsets[i] 给出第 i 个 token 在原始字符串中的字符区间
span = "Hello, how are you?"[start:end]
SentencePiece

SentencePiece 既是一套 tokenizer 算法(BPE / Unigram LM),也是训练与推理工具链。它的工程优势在于“直接在原始文本上训练”,不依赖预分词,对无空格语言更友好。

安装:

Shell
1
pip install sentencepiece

命令/API/函数
sentencepiece.SentencePieceTrainer.Train

说明
训练 spm 模型

示例

Python
1
2
3
4
5
6
7
8
9
import sentencepiece as spm
 
spm.SentencePieceTrainer.Train(
    input="corpus.txt",          # 原始训练语料;SentencePiece 直接在未预分词文本上学习
    model_prefix="spm",          # 会生成 spm.model 和 spm.vocab 两个文件
    vocab_size=32000,            # 词表大小决定分词粒度与 embedding 尺寸
    model_type="bpe",            # 这里选 BPE;也可换成 unigram,适合不同语言与语料分布
    character_coverage=0.9995,   # 尽量覆盖高频字符;中文/日文场景通常会把这个值设得较高
)

命令/API/函数
sentencepiece.SentencePieceProcessor

说明
加载并编码/解码

示例

Python
1
2
3
4
5
6
7
8
9
import sentencepiece as spm
 
# Processor 负责加载训练好的 .model,并提供 encode/decode 接口
sp = spm.SentencePieceProcessor()
sp.load("spm.model")
# out_type=int 直接返回 token id,方便接训练或推理张量化流程
ids = sp.encode("你好世界", out_type=int)
# decode 用于调试分词质量、还原输出或做可读性检查
text = sp.decode(ids)
id、piece 与采样

SentencePiece 的工程排查经常落在两个问题上:某个 token id 到底对应什么 piece,以及训练或数据增强时是否要启用采样。SentencePiece 用 ▁ 表示词边界,这一点在日志或调试输出里经常会出现。

Python
1
2
3
4
5
6
7
8
9
10
11
12
piece = sp.id_to_piece(42)        # 直接查看某个 id 对应的 piece,定位奇怪 token 时很常用
pid = sp.piece_to_id("▁hello")    # ▁ 表示词边界;这类 piece 在英文模型里很常见
bos = sp.bos_id()                 # 特殊 token id 应在训练前就确认,避免和模型配置不一致
eos = sp.eos_id()
 
sampled = sp.encode(
    "你好世界",
    out_type=int,
    enable_sampling=True,         # unigram 路线常用采样做子词正则化;推理阶段通常关闭
    nbest_size=-1,
    alpha=0.1,
)
tiktoken

tiktoken 是 OpenAI 开源的 BPE tokenizer 实现,常用于与 OpenAI 模型兼容的 token 计数与编码。它提供按 encoding 名称或按模型名选择 encoding 的接口。

安装:

Shell
1
pip install tiktoken

命令/API/函数
tiktoken.get_encoding

说明
按 encoding 名称获取

示例

Python
1
2
import tiktoken
enc = tiktoken.get_encoding("o200k_base")

命令/API/函数
tiktoken.encoding_for_model

说明
按模型名获取

示例

Python
1
enc = tiktoken.encoding_for_model("gpt-4o")

命令/API/函数
Encoding.encode / decode

说明
编码/解码

示例

Python
1
2
ids = enc.encode("hello world")
text = enc.decode(ids)
special token 策略与聊天计数

tiktoken 在工程里最常用的核心是做 token 预算与费用估算。这里最容易踩的坑是 special token:有的调用希望严格禁止特殊 token 混进普通文本,有的调用则明确允许它们出现。

Python
1
2
3
4
5
6
7
ids = enc.encode_ordinary("hello world")  # 只按普通文本编码,不去识别特殊 token 片段
 
ids = enc.encode(
    "<|endoftext|>hello",
    # 只有显式允许的特殊 token 才会被当成特殊符号处理
    allowed_special={"<|endoftext|>"},
)

聊天计数时,真正计费的通常是模板化后的整段输入。因此更稳的做法是:先按目标 SDK/服务端的消息模板把 system/user/tool 消息串成最终文本,再统一送进 tokenizer 计数,避免逐条消息单独估算后相加。

中文文本预处理与分词工具

在大模型训练中,中文通常直接走子词/字节级 tokenizer;但在传统 NLP、搜索、实体抽取、以及“数据清洗与规范化”阶段,中文分词与繁简转换仍是高频工程环节。

jieba

安装:

Shell
1
pip install jieba

命令/API/函数
jieba.cut

说明
分词(generator)

示例

Python
1
2
import jieba
tokens = list(jieba.cut("我爱自然语言处理"))

命令/API/函数
jieba.lcut

说明
分词(list)

示例

Python
1
tokens = jieba.lcut("我爱自然语言处理")

命令/API/函数
jieba.cut_for_search

说明
搜索引擎模式(更细粒度)

示例

Python
1
tokens = list(jieba.cut_for_search("南京市长江大桥"))

命令/API/函数
jieba.add_word

说明
动态加入词典

示例

Python
1
jieba.add_word("大语言模型")

命令/API/函数
jieba.load_userdict

说明
加载用户词典

示例

Python
1
jieba.load_userdict("userdict.txt")
opencc-python

opencc-python 是早期 OpenCC 的 Python wrapper,版本较旧。现代工程更常用维护更活跃的 OpenCC 包。这里保留两条安装路径:兼容旧 wrapper 与直接使用 OpenCC。

Shell
1
2
3
4
5
# 旧 wrapper(较旧)
pip install opencc-python
 
# 推荐:OpenCC(维护更活跃)
pip install OpenCC

Python
1
2
3
4
# OpenCC 示例:繁转简
from opencc import OpenCC
cc = OpenCC("t2s")
out = cc.convert("今天天氣不錯")
spaCy:NLP 管线与结构化抽取框架

spaCy 是面向生产的 NLP pipeline 框架。它在训练与推理工程中的价值,是把原始文本稳定转成带 token、span、实体、句子、词性、依存和分类结果的结构化 Doc。离线预处理、实体抽取、搜索字段加工、标注数据转换、规则兜底和小中型 NLP 任务训练,都是它的高频位置。

它和 Hugging Face tokenizer、LLM 推理引擎的职责不同。tokenizer 负责把文本切成模型输入 token;vLLM/Transformers 负责生成或模型前向;spaCy 更像一条可配置的文本加工管线,把 NLP 注释统一挂在 Doc、 Token 和 Span 上,方便下游业务逻辑消费。

安装与 QuickStart
Shell
1
2
3
4
5
6
7
8
9
# 安装 spaCy 主包。
pip install -U spacy
 
# 下载英文小型 trained pipeline,包含 tokenizer、tagger、parser、ner 等组件。
python -m spacy download en_core_web_sm
 
# 需要 transformer pipeline 时安装对应 extra,再下载 _trf pipeline。
pip install -U "spacy[transformers]"
python -m spacy download en_core_web_trf

Python
1
2
3
4
5
6
7
8
9
10
11
import spacy
 
# nlp 是 Language pipeline 对象,生产服务中通常每个进程加载一次并复用。
nlp = spacy.load("en_core_web_sm")
 
# 调用 nlp(text) 会先 tokenize,再按 pipeline 顺序运行组件。
doc = nlp("Apple is looking at buying U.K. startup for $1 billion")
 
# doc.ents 是实体 Span 序列;每个 Span 保留原文区间和实体标签。
for ent in doc.ents:
    print(ent.text, ent.label_, ent.start_char, ent.end_char)
对象模型:nlp / Doc / Token / Span
对象 工程含义 常用入口
nlp Language 管线对象,持有 tokenizer、共享词表、语言数据、组件、权重和配置。 spacy.load、 spacy.blank、 nlp.pipe
Doc 整段文本与注释的容器,保留 token 序列、实体、句子、分类结果和原文对齐。 doc.ents、 doc.sents、 doc.cats
Token 单个 token 的视图,持有词性、依存、实体 IOB、lemma、向量等属性。 token.text、 token.pos_、 token.ent_type_
Span 连续 token 区间,NER 实体、句子、规则匹配结果和候选短语通常都用它表示。 doc[start:end]、 span.label_、 span.start_char
Python
1
2
3
4
5
6
7
8
9
10
11
12
# 空白语言对象只提供 tokenizer 和语言规则,不包含预训练统计组件。
nlp = spacy.blank("en")
 
# make_doc 只做 tokenization,适合构造训练数据或做规则前处理。
doc = nlp.make_doc("Only tokenization runs here.")
 
# Token 是 Doc 上的视图;Span 是连续 token 区间。
first_token = doc[0]
first_span = doc[0:2]
 
print(first_token.text)
print(first_span.text)
Pipeline components

spaCy pipeline 的组件按顺序接收并返回 Doc。常见内置组件包括 tok2vec、 transformer、 tagger、 parser、 ner、 entity_ruler、 entity_linker、 textcat、 sentencizer 和 lemmatizer。

批量处理用 nlp.pipe。推理时只加载和运行需要的组件,能明显降低 CPU/GPU 开销。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import spacy
 
texts = [
    "Net income was $9.4 million.",
    "Revenue exceeded twelve billion dollars.",
]
 
# exclude 表示组件不加载进内存,适合明确用不到 parser/lemmatizer 的抽取任务。
nlp = spacy.load("en_core_web_sm", exclude=["parser", "lemmatizer"])
 
# nlp.pipe 会把多条文本批处理,比 Python 循环逐条 nlp(text) 更适合离线预处理。
for doc in nlp.pipe(texts, batch_size=128):
    # 这里只消费 NER 结果,避免运行无关组件。
    entities = [(ent.text, ent.label_) for ent in doc.ents]
    print(entities)
规则组件与结构化抽取

entity_ruler、 span_ruler 和 matcher 系列组件适合把业务词典、正则模式、产品名、地名、合规术语写成可复现规则。它们常用于冷启动 NER、补充统计模型漏召回、构造弱标注数据和搜索字段加工。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import spacy
 
nlp = spacy.blank("en")
 
# entity_ruler 会把规则命中的文本写入 doc.ents。
# 放在训练型 ner 前后会影响覆盖关系,生产中要固定 pipeline 顺序。
ruler = nlp.add_pipe("entity_ruler")
 
# patterns 是可版本化的业务词典,比散落在代码里的 if/regex 更容易审计。
ruler.add_patterns([
    {"label": "ORG", "pattern": "OpenAI"},
    {"label": "PRODUCT", "pattern": "ChatGPT"},
])
 
doc = nlp("OpenAI released ChatGPT.")
print([(ent.text, ent.label_) for ent in doc.ents])
训练与 config.cfg

spaCy v3 的训练以 config.cfg 为中心。配置文件定义语言、tokenizer、pipeline components、模型结构、路径、初始化资源、训练循环和优化器。训练产物会携带最终配置,便于复现实验。

Shell
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 先生成基础配置,再填充默认值,避免手写缺失字段。
python -m spacy init config base_config.cfg --lang en --pipeline ner --optimize efficiency
python -m spacy init fill-config base_config.cfg config.cfg
 
# debug data 在正式训练前检查标签、切分、实体边界和数据格式。
python -m spacy debug data config.cfg
 
# 训练数据通常使用 .spacy 格式,内部由 DocBin 保存带注释的 Doc。
python -m spacy train config.cfg \
  --output ./output \
  --paths.train ./train.spacy \
  --paths.dev ./dev.spacy
 
# GPU 训练通过 --gpu-id 指定设备。
python -m spacy train config.cfg --gpu-id 0

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import spacy
from spacy.tokens import DocBin
 
nlp = spacy.blank("en")
doc_bin = DocBin()
 
# 示例数据使用字符级实体边界。
# 真实项目应在转换阶段检查 char_span 是否为空,避免 tokenizer 边界不一致。
text = "Apple bought a startup in London."
doc = nlp.make_doc(text)
span = doc.char_span(0, 5, label="ORG")
 
if span is None:
    raise ValueError("Entity boundary does not align with tokenizer output")
 
doc.ents = [span]
doc_bin.add(doc)
 
# .spacy 文件是 spaCy 训练命令的常用输入格式。
doc_bin.to_disk("train.spacy")
中文、多语言与 Transformers

中文 pipeline 的关键是 tokenizer 配置。spaCy 中文语言类支持字符级分词、jieba 和 pkuseg 路线;训练、评测和线上推理必须使用同一 tokenizer 设置,否则实体边界和 span 对齐会漂移。多语言或语言中立 pipeline 通常使用 xx 语言 ID。

spaCy 的 transformer 支持以 transformer 组件进入 pipeline,为 ner、 textcat 等组件提供上下文表示,并把输出写入 Doc._.trf_data。这条路线提高准确率,但训练和推理成本更高。

Python
1
2
3
4
5
6
7
8
9
import spacy
 
# transformer pipeline 通常以 _trf 结尾,内部包含 transformer component。
nlp = spacy.load("en_core_web_trf")
doc = nlp("This sentence is processed with a transformer-backed pipeline.")
 
# transformer 输出挂在 Doc 扩展字段上,下游组件共享这份上下文表示。
trf_data = doc._.trf_data
print(type(trf_data))
NER 工程边界

spaCy 的默认 ner 组件适合 flat NER:实体是非重叠 labelled spans,结果写入 Doc.ents 和 token 的实体属性。嵌套实体、多标签 span、需要大量候选 span 打分的任务,更适合后文的 GlobalPointer、GLiNER 或自定义 span 分类路线。

批处理构造与样本拼接

batch 构造的关键目标是把“变长样本”转成“规则张量”,并尽量减少无效计算。四个高频机制:collator(怎么把样本列表变成 batch)、padding(对齐长度)、packing(把多个短样本拼到一个长序列里)、masking(构造 loss 的可学习位置)。

batch 构造机制
collator

collator 通常以 collate_fn 形式接入 DataLoader,负责把 List[sample] 转为张量 batch。

padding

padding 的核心是把不同长度序列补齐,并同步产生 attention_mask。对于 encoder 任务,padding 的位置通常 mask 掉注意力;对于 decoder 任务,还要考虑因果 mask 与 label mask。

packing

packing 把多个短序列拼接到固定长度 block 中,减少 padding 浪费。它适合预训练与指令微调中的“很多短样本”场景,但需要正确构造 label 与分段边界(例如用 special token 分隔)。

masking

masking 用来指定 loss 只在哪些位置计算,例如 causal LM 的 label shift、MLM 的随机 mask、SFT 中把提示词部分的 label 设为 ignore。

多模态样本组织与 processor

多模态模型往往通过 processor 把文本 tokenizer 与视觉/音频预处理封装在一起。工程上需要保证:离线预处理与线上推理使用同一套 processor 配置,避免“训练时的输入分布”和“推理时的输入分布”不一致。

合成数据与数据增强工具

合成数据通常用来补齐边界覆盖,并不替代真实数据:格式多样性、语言多样性、脏数据模式、罕见实体组合、以及隐私合规场景下的脱敏替身。合成数据要能回放:生成种子、版本、配置都要进入产物元数据。

结构化合成数据生成
Faker

安装:

Shell
1
pip install Faker
names-dataset

安装:

Shell
1
pip install names-dataset
pycountry

安装:

Shell
1
pip install pycountry

命令/API/函数
Faker

说明
Faker / Faker.seed

示例

Python
1
2
3
4
from faker import Faker
Faker.seed(42)
fake = Faker(locale="zh_CN")
row = {"name": fake.name(), "addr": fake.address()}

命令/API/函数
names-dataset

说明
NameDataset

示例

Python
1
2
3
from names_dataset import NameDataset
nd = NameDataset()
info = nd.search("Zoe")

命令/API/函数
pycountry

说明
pycountry.countries / lookup

示例

Python
1
2
3
import pycountry
cn = pycountry.countries.lookup("China")
langs = pycountry.languages.get(alpha_2="zh")
语言识别与类型检测库

语言识别与类型检测经常用于预处理阶段的路由:多语言混杂数据的分桶、代码/文档/日志的分流、以及不同清洗规则的选择。工程重点是“低成本、可解释、可复现”,检测逻辑应服务数据路由和清洗策略。

自然语言识别

命令/API/函数
langdetect: detect / detect_langs / DetectorFactory.seed

说明

安装方式如下。

Shell
1
pip install langdetect

示例

Python
1
2
3
4
from langdetect import detect, detect_langs, DetectorFactory
DetectorFactory.seed = 0
lang = detect("Hello world")
langs = detect_langs("Otec matka syn.")

命令/API/函数
fastText lid: fasttext.load_model / model.predict

说明

安装方式如下。

Shell
1
pip install fasttext

示例

Python
1
2
3
4
5
6
7
8
9
10
11
12
import fasttext
from huggingface_hub import hf_hub_download
 
model_path = hf_hub_download(
    # 直接从 Hub 下载语言识别模型,不必手工管理二进制文件
    repo_id="facebook/fasttext-language-identification",
    filename="model.bin",
)
# fastText 的 lid 模型加载后即可直接做短文本语言预测
model = fasttext.load_model(model_path)
# 返回标签和置信度,适合做数据清洗前的语言分桶
labels, probs = model.predict("Hello, world!")

命令/API/函数
lingua-language-detector: LanguageDetectorBuilder

说明

安装方式如下。

Shell
1
pip install lingua-language-detector

示例

Python
1
2
3
from lingua import LanguageDetectorBuilder
detector = LanguageDetectorBuilder.from_all_languages().build()
lang = detector.detect_language_of("Hello world")
文件类型检测

文件类型检测的目标是“尽量早发现二进制/压缩/不支持格式”,避免把不可解析内容送进后续清洗链路。python-magic 依赖底层 libmagic,需要系统依赖到位。

命令/API/函数
python-magic: magic.from_file / magic.from_buffer / magic.Magic

说明

安装方式如下。

Shell
1
2
3
pip install python-magic
# Debian/Ubuntu:
sudo apt-get install libmagic1

示例

Python
1
2
3
import magic
mime = magic.from_file("a.pdf", mime=True)
sig = magic.from_buffer(open("a.pdf", "rb").read(2048))
源码语言检测

源码语言检测用于代码数据集清洗(按语言分桶、去除 vendored/generated)与语法级预处理(解析 AST、提取符号)。Pygments 偏启发式 lexer;tree-sitter 提供结构化解析(AST)。

命令/API/函数
Pygments: get_lexer_for_filename / get_lexer_by_name

说明

安装方式如下。

Shell
1
pip install pygments

示例

Python
1
2
from pygments.lexers import get_lexer_for_filename
lexer = get_lexer_for_filename("a.py")

命令/API/函数
tree-sitter: Language / Parser

说明

安装方式如下。

Shell
1
pip install tree-sitter tree-sitter-python

示例

Python
1
2
3
4
5
6
7
8
from tree_sitter import Language, Parser
import tree_sitter_python as tspython
 
# 把 Python 语法定义编译成 tree-sitter 可消费的 Language 对象
PY_LANGUAGE = Language(tspython.language())
parser = Parser(PY_LANGUAGE)                 # parser 会按这套语法把源码切成 AST
# parse 需要字节串输入,产物可继续拿去做节点遍历和结构化抽取
tree = parser.parse(b"print('hi')\n")

“Linguist 类方案”通常指 GitHub Linguist 的启发式:文件扩展名 + shebang + 内容特征 + 语言冲突消歧规则。实际工程里常把这类规则作为“分桶的第一步”,再对少量不确定样本做更重的解析。

大规模离线预处理模式

离线预处理的目标是把昂贵的 CPU 工作(解析、清洗、分词、格式规范化)集中到一次批处理作业中,并把结果写成稳定、可复用、可 memory-map 的 shard。一个可运维的离线预处理作业至少要具备:可重跑、可断点续跑、单 shard 失败不影响全局、输出具备 manifest。

多进程 Pool

多进程的核心收益在于绕过 Python GIL,把 CPU 密集的分词与解析并行化。进程间传输大对象会迅速放大序列化开销,因此更稳妥的做法是把输入切成可流式读取的小对象,把输出写成 shard。

shard 流式写入

shard 是离线预处理的基本单元:每个 shard 控制大小(例如 512MB~2GB)、可单独校验、可单独重跑。WebDataset 的 shard 是 tar;HF Datasets/Arrow 的 shard 是 parquet/arrow 文件集合;LMDB/HDF5 则是数据库/容器文件。

内存安全与失败恢复
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import json
import os
from multiprocessing import Pool
 
def process_line(line: str) -> str:
    obj = json.loads(line)              # 每个 worker 只处理一行,避免跨进程传大对象
    obj["text"] = obj["text"].strip()   # 这里放清洗、规范化或轻量分词逻辑
    # 重新写回 JSONL,保持下游训练管线最常见的输入格式
    return json.dumps(obj, ensure_ascii=False)
 
def write_shards(in_path: str, out_dir: str, shard_lines: int = 200_000, workers: int = 8):
    os.makedirs(out_dir, exist_ok=True)  # 先保证输出目录存在,方便失败后重跑
 
    def shard_path(i: int) -> str:
        # shard 编号固定宽度,后续排序和恢复更稳定
        return os.path.join(out_dir, f"shard-{i:06d}.jsonl")
 
    with open(in_path, "r", encoding="utf-8") as f, Pool(processes=workers) as pool:
        shard_idx = 0
        buf = []  # 先在内存里累计一批处理结果,再整 shard 落盘,减少小文件写入抖动
        for out in pool.imap(process_line, f, chunksize=256):
            buf.append(out)
            if len(buf) >= shard_lines:
                tmp = shard_path(shard_idx) + ".tmp"
                with open(tmp, "w", encoding="utf-8") as wf:
                    wf.write("\n".join(buf) + "\n")
                # 先写 .tmp 再原子替换,避免中途中断留下半成品
                os.replace(tmp, shard_path(shard_idx))
                buf.clear()
                shard_idx += 1
 
        if buf:
            tmp = shard_path(shard_idx) + ".tmp"
            with open(tmp, "w", encoding="utf-8") as wf:
                wf.write("\n".join(buf) + "\n")
            os.replace(tmp, shard_path(shard_idx))
基础训练框架

基础训练框架(Foundational Training Framework)提供三类不可替代的底座能力:张量与设备执行(Tensor & Device Execution)、自动求导(Automatic Differentiation)、以及训练循环所需的基础组件(模块、优化器、数据加载、序列化)。上层训练框架可以封装流程,但底层的梯度、显存与 kernel 行为最终仍由这一层决定。

PyTorch

PyTorch 的编程模型是“Python 先行的动态图(Dynamic Graph)+ 胶带式自动求导(Tape-based Autograd)”。这使训练循环天然可调试:前向是普通 Python 代码,反向由 autograd 记录并回放。工程上更重要的是:PyTorch 生态已经把训练、分布式、编译优化与部署衔接做成了一套可组合部件,既能写研究型循环,也能写生产训练栈。

安装建议遵循官方安装页的选择器:CPU-only 与 CUDA 版本的 pip/conda 命令需要与目标机器驱动、CUDA 版本匹配。CPU 环境通常可以直接 pip install torch;GPU 环境应按官方给出的 index-url 安装对应 CUDA wheel。安装后验证建议至少覆盖两点:能创建张量以及GPU 可见(如适用)。

PyTorch installation (pattern)
Shell
1
2
3
4
5
6
7
# CPU (typical)
# CPU 机器直接安装官方 wheel 即可,不需要额外的 CUDA index-url。
pip install -U torch
 
# CUDA: use the command generated by https://pytorch.org/get-started/locally/
# It typically looks like:
# pip install -U torch --index-url https://download.pytorch.org/whl/cu12x

PyTorch quick verify
Python
1
2
3
4
5
import torch
print(torch.__version__)
x = torch.randn(2, 3)
print(x.shape)
print("cuda_available:", torch.cuda.is_available())
核心抽象

三件事决定 PyTorch 训练代码的结构:张量(Tensor)承载数据与参数、autograd 负责梯度、 nn.Module 组织可训练子图并提供参数与缓冲区的可追踪结构。

Tensor

Tensor 是所有计算的基本载体。训练相关的关键元信息有四类:形状(shape)、数值类型(dtype)、设备(device)与梯度开关(requires_grad)。其中设备与 dtype 会直接改变 kernel 路径与显存占用;requires_grad 决定该张量是否会成为计算图的一部分。

命令/API/函数
torch.tensor

说明
从 Python 对象构造张量

示例

Python
1
x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)

命令/API/函数
torch.randn

说明
随机初始化(参数/输入常用)

示例

Python
1
w = torch.randn(768, 768, device="cuda")

命令/API/函数
Tensor.to

说明
设备/精度迁移

示例

Python
1
x = x.to(device="cuda", dtype=torch.bfloat16)

命令/API/函数
Tensor.requires_grad_

说明
把张量标为需要梯度

示例

Python
1
x = x.requires_grad_(True)

命令/API/函数
torch.no_grad

说明
推理/评估时关闭梯度跟踪

示例

Python
1
2
with torch.no_grad():
    y = model(x)

命令/API/函数
torch.inference_mode

说明
比 no_grad 更强的推理模式(更少开销)

示例

Python
1
2
with torch.inference_mode():
    y = model(x)
autograd

autograd 的编程要点是“图从前向构建,梯度在反向回传”。多数训练代码只用到 loss.backward() + optimizer.step(),但当需要更复杂的梯度形态(如多目标、梯度惩罚、二阶项)时,就会显式使用 torch.autograd.grad。

命令/API/函数
Tensor.backward

说明
反向传播,累计梯度到 leaf 参数

示例

Python
1
2
3
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()

命令/API/函数
torch.autograd.grad

说明
函数式求梯度,返回梯度张量而不写入 .grad

示例

Python
1
grads = torch.autograd.grad(loss, model.parameters(), create_graph=False)

命令/API/函数
torch.autograd.Function

说明
自定义前向/反向(自定义算子或特殊梯度)

示例

Python
1
2
3
4
5
6
# 定义可复用的类。
class MyFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x): ...
    @staticmethod
    def backward(ctx, grad_out): ...
nn.Module

nn.Module 把“可训练参数 + 前向逻辑 + 子模块拓扑”打包成可组合单元。训练与部署的关键接口是 state_dict():它让参数与 buffer 的保存/加载成为稳定约定,从而把“代码结构”与“权重文件”解耦。

Minimal nn.Module
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch
import torch.nn as nn
 
class MLP(nn.Module):
    def __init__(self, in_dim: int, hidden: int, out_dim: int):
        super().__init__()
        # 顺序堆叠层能把“线性层 + 激活 + 线性层”打包成一个可保存子模块。
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, out_dim),
        )
 
    def forward(self, x):
        return self.net(x)
执行与编译

PyTorch 的默认执行是 eager;编译与优化通过 torch.compile 把前向(以及可捕获的反向)分段转换成可优化子图,再由后端生成更高效的执行计划。实践中它最适合用于“训练循环稳定、算子形态固定”的热点路径;高度动态的 Python 控制流与形状多态会降低捕获与复用效果。

动态图

动态图让训练循环具备“逐步可观察性”:每一步前向的张量都可以被打印、断点或插桩;复杂条件分支也能自然表达。这种表达力的代价是:若不做编译或算子融合,小算子密集的模型可能被 Python 调度开销限制。

训练循环

训练循环的工程要点集中在三个地方:梯度清零策略、设备放置与数据搬运、以及异常时可恢复的 checkpoint。下面是一段最小可用的循环骨架。

Minimal training loop (PyTorch)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
from torch.utils.data import DataLoader
 
# 训练开始前先确定 device;后面模型、输入和标签都要落到同一设备上。
device = "cuda" if torch.cuda.is_available() else "cpu"
model = MLP(128, 256, 10).to(device)
# 优化器和损失函数在循环外初始化,避免每个 step 重建对象。
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
loss_fn = torch.nn.CrossEntropyLoss()
 
# DataLoader 负责批处理、shuffle 和 pinned memory。
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4, pin_memory=True)
 
# 切到训练模式后,dropout / batchnorm 等模块会走训练态分支。
model.train()
for step, (x, y) in enumerate(loader):
    # non_blocking=True 和 pin_memory 配合使用时,CPU->GPU 拷贝延迟更低。
    x = x.to(device, non_blocking=True)
    y = y.to(device, non_blocking=True)
 
    # 标准训练 step:清梯度、前向、算 loss、反向、更新参数。
    opt.zero_grad(set_to_none=True)
    logits = model(x)
    loss = loss_fn(logits, y)
    loss.backward()
    opt.step()
torch.compile

最常见的接入方式是在训练开始前包一层: model = torch.compile(model)。在大模型场景下,编译往往与 AMP、FlashAttention、FSDP/ZeRO 等一起协作:编译负责降低算子调度与 kernel 生成开销,其它组件负责显存、带宽与通信。

torch.compile (typical)
Python
1
2
model = MLP(128, 256, 10).to(device)
model = torch.compile(model)  # compile after moving to device
训练基础设施

训练基础设施是把“能跑”变成“可长期迭代”的关键层:数据加载决定吞吐上限,AMP 决定显存与速度,checkpoint 决定可恢复性;而 CRF layer 代表一类“训练框架之外的结构化输出层”组件,常见于序列标注。

DataLoader

DataLoader 是 PyTorch 数据管线的核心抽象:负责 batch、shuffle、collate、多进程加载与 pinned memory。它与 Dataset 的分工边界清晰:Dataset 负责“样本怎样被索引/生成”,DataLoader 负责“样本如何被并发读取并组织成 batch”。

参数 作用 经验用法
batch_size 每步样本数 受显存与序列长度影响,常与 gradient accumulation 联动
num_workers 数据加载进程数 CPU 预处理重时提高;过高会导致调度/内存压力
pin_memory 将 batch 放入 pinned memory GPU 训练通常打开,配合 non_blocking=True
collate_fn 自定义 batch 拼接 NLP 里常做 padding/packing;多模态里做对齐与打包
AMP

自动混合精度(Automatic Mixed Precision, AMP)通过在合适的算子上使用 FP16/BF16,并在数值敏感的算子上保留 FP32,换取吞吐与显存的提升。PyTorch 提供 autocast 与 GradScaler 组合来降低接入成本。

AMP (torch.cuda.amp)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# GradScaler 负责缩放 loss,降低 fp16 路径下梯度下溢的风险。
scaler = torch.cuda.amp.GradScaler()
 
for x, y in loader:
    # 数据搬运逻辑和普通训练循环一致。
    x = x.to(device, non_blocking=True)
    y = y.to(device, non_blocking=True)
    opt.zero_grad(set_to_none=True)
 
    # autocast 让矩阵乘、卷积等算子自动走更省显存的混合精度路径。
    with torch.cuda.amp.autocast(dtype=torch.float16):
        logits = model(x)
        loss = loss_fn(logits, y)
 
    # backward / step / update 必须和 scaler 协同使用。
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()
checkpoint

PyTorch 的 checkpoint 约定是保存 state_dict:至少包含模型参数与优化器状态,必要时加上 scheduler、步数与随机种子。训练恢复的核心是把“运行状态”视为数据,单独保存权重文件不足以恢复完整训练。

Checkpoint save/load (PyTorch)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 把恢复训练最少需要的状态集中到一个字典里。
ckpt = {
    "step": step,
    "model": model.state_dict(),
    "optimizer": opt.state_dict(),
}
# 中间态 checkpoint 通常直接用 torch.save 写成单文件。
torch.save(ckpt, "ckpt.pt")
 
# 恢复时先把 checkpoint 读回当前设备。
ckpt = torch.load("ckpt.pt", map_location=device)
# 先恢复模型参数,再恢复优化器状态,这样学习率和动量都能续上。
model.load_state_dict(ckpt["model"])
opt.load_state_dict(ckpt["optimizer"])
step = ckpt["step"]
CRF layer

线性链条件随机场(Linear-chain Conditional Random Field, CRF)是序列标注(Sequence Labeling)里常见的结构化输出层:它不改变 encoder 的表示学习方式,但把标签预测从“逐 token 独立分类”改成“序列级全局最优路径”,以显式建模相邻标签转移约束。

pytorch-crf / torchcrf

pytorch-crf 是常见的第三方 CRF layer 实现,API 以一个 CRF 模块为中心:前向返回 log-likelihood,训练时通常取负作为 loss;解码使用 decode 输出最优标签序列。

Install CRF layer
Shell
1
pip install pytorch-crf

CRF layer (torchcrf) minimal usage
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch
from torchcrf import CRF
 
num_tags = 5
# CRF 层负责在标签转移层面建模,逐 token softmax 会独立预测每个位置
crf = CRF(num_tags)
 
seq_len, batch = 3, 2
# 每个位置对每个标签的发射分数,通常来自 BiLSTM/BERT 编码器
emissions = torch.randn(seq_len, batch, num_tags)
# 监督标签形状必须和 (seq_len, batch) 对齐
tags = torch.tensor([[0, 1], [2, 4], [3, 1]], dtype=torch.long)
 
# 返回整条标签路径的对数似然,逐 token 独立概率不包含转移约束
log_likelihood = crf(emissions, tags)
loss = -log_likelihood                  # 训练时通常最小化负对数似然
 
# Viterbi 解码得到最优标签序列,适合 NER/POS 这类结构化预测任务
best_paths = crf.decode(emissions)
TensorFlow

TensorFlow 的核心执行模型是 eager + graph:默认 eager 便于调试与交互, @tf.function 将 Python 函数 trace 编译成图执行以提升性能与可移植性。训练循环的两条主线分别是:Keras Model.fit 的高层接口,以及基于 tf.GradientTape 的自定义循环。

TensorFlow installation (pip)
Shell
1
2
3
4
5
# CPU-only
pip install -U tensorflow-cpu
 
# Default package (CPU/GPU depends on platform; follow the official pip install guide)
pip install -U tensorflow

TensorFlow quick verify
Python
1
2
3
import tensorflow as tf
print(tf.__version__)
print(tf.config.list_physical_devices("GPU"))
TensorFlow 计算图与 tf.data

tf.data.Dataset 是 TensorFlow 的输入管线中心:通过 map、 batch、 shuffle、 prefetch 把数据处理做成可并行、可流式的图。生产训练中,输入管线经常成为 GPU 利用率的上限瓶颈,因此应优先把可并行预处理移入 tf.data 体系内。

tf.data input pipeline
Python
1
2
3
4
5
6
7
8
9
10
11
12
import tensorflow as tf
 
# 先把内存中的特征和标签包装成 tf.data 流
ds = tf.data.Dataset.from_tensor_slices((features, labels))
# 打乱顺序,避免样本排列对训练造成偏置
ds = ds.shuffle(10000)
# 把预处理并行化,尽量让输入管线跟上 GPU
ds = ds.map(preprocess_fn, num_parallel_calls=tf.data.AUTOTUNE)
# 固定 batch 形状更利于图编译和设备执行
ds = ds.batch(128, drop_remainder=True)
# 让 CPU 预处理和 GPU 训练重叠,减少设备空转
ds = ds.prefetch(tf.data.AUTOTUNE)

Custom training step with GradientTape + tf.function
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import tensorflow as tf
 
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        # training=True 会打开 dropout、batch norm 更新等训练态行为
        logits = model(x, training=True)
        # loss_fn 负责把标签和 logits 对齐成可反传的标量目标
        loss = loss_fn(y, logits)
    # 从 tape 中回放计算图,求出每个可训练变量的梯度
    grads = tape.gradient(loss, model.trainable_variables)
    # TensorFlow 中显式把 (grad, var) 配对后交给优化器
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
分布式与断点续训

TensorFlow 多机训练的关键入口是 tf.distribute 与 TF_CONFIG。单机多卡常用 MirroredStrategy;多机训练则更常用 MultiWorkerMirroredStrategy。恢复训练时,真正应该恢复的是“模型参数 + 优化器状态 + 当前步数”,因此标准做法通常是 tf.train.Checkpoint 搭配 CheckpointManager;只导出 SavedModel 不足以恢复训练过程。

TensorFlow: strategy + CheckpointManager
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import tensorflow as tf
 
# 单机多卡可换成 MirroredStrategy;多机时由 TF_CONFIG 描述集群拓扑。
strategy = tf.distribute.MultiWorkerMirroredStrategy()
 
with strategy.scope():
    model = build_model()
    optimizer = tf.keras.optimizers.Adam(3e-4)
 
# checkpoint 负责恢复完整训练状态,覆盖权重、优化器和步数。
ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
manager = tf.train.CheckpointManager(ckpt, directory="ckpt_tf", max_to_keep=3)
 
latest = manager.latest_checkpoint
if latest:
    # 这一句会同时恢复模型参数、优化器状态和全局步数。
    ckpt.restore(latest)
Keras 3

Keras 3 是多后端(Multi-backend)深度学习框架:同一份 Keras 代码可以运行在 TensorFlow、JAX、PyTorch 后端上;后端通过 KERAS_BACKEND 环境变量或本地配置文件选择,并且必须在 import Keras 之前确定。对工程而言,这一设计把“模型代码”与“执行后端”分离:可以在同一套高层 API 下切换后端能力,例如在 JAX 上获得更强的编译与 SPMD 体系,或在 PyTorch 上复用既有生态。

Keras 3 installation
Shell
1
pip install -U keras

Select backend before importing keras
Python
1
2
3
4
5
import os
os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow" / "torch"
 
import keras
print("backend:", keras.backend.backend())

Keras 3 minimal training
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import keras
from keras import layers
 
model = keras.Sequential([
    # 第一层先把输入映射到隐藏空间,并用 GELU 保持 Transformer 风格的非线性
    layers.Dense(256, activation="gelu"),
    # 最后一层直接输出 10 类 logits;不加 softmax,交给 loss 统一处理
    layers.Dense(10),
])
model.compile(
    # AdamW 是现代深度学习最常见的默认优化器之一
    optimizer=keras.optimizers.AdamW(learning_rate=3e-4),
    # 告诉 loss 输入是 logits,不要再假设已做 softmax
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    # 训练中同步记录分类准确率,便于和 loss 一起看
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
# fit 会统一接管 epoch 循环、日志和验证
model.fit(train_x, train_y, batch_size=128, epochs=3)
Keras 3 的多后端架构

多后端意味着一组实践约束:后端不能在 import 后热切换;部分底层行为(例如随机数、分布式、数值细节)仍由后端决定;性能调优最终仍会落回后端的编译器与 kernel 体系。Keras 3 更适合承担“统一建模接口”,性能工程仍要回到底层后端处理。

命令/API/函数
keras.layers

说明
层与算子组合

示例

Python
1
x = layers.LayerNormalization()(x)

命令/API/函数
keras.Model

说明
可训练模型单元

示例

Python
1
class MyModel(keras.Model): ...

命令/API/函数
keras.ops

说明
后端无关算子层(多后端 API)

示例

Python
1
y = keras.ops.matmul(a, b)

命令/API/函数
keras.optimizers

说明
优化器族

示例

Python
1
opt = keras.optimizers.AdamW(3e-4)
callbacks 与自定义 train_step

Keras 的工程优势不只在 fit(),还在于 callbacks 体系和可覆写的 train_step。前者接管 best checkpoint、early stopping 与 TensorBoard 日志;后者允许在保留 Keras 训练外壳的同时,插入自定义损失、梯度裁剪或多任务逻辑。

Keras: callbacks + custom train_step
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import keras
import tensorflow as tf
 
class MyModel(keras.Model):
    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            # 这里仍然让 forward 保持 Keras Model 风格,便于继续复用 fit/evaluate 生态。
            logits = self(x, training=True)
            loss = self.compiled_loss(y, logits)
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        self.compiled_metrics.update_state(y, logits)
        return {m.name: m.result() for m in self.metrics}
 
callbacks = [
    keras.callbacks.ModelCheckpoint("best.keras", save_best_only=True, monitor="val_loss"),
    keras.callbacks.EarlyStopping(monitor="val_loss", patience=3, restore_best_weights=True),
    keras.callbacks.TensorBoard(log_dir="tb_logs"),
]
JAX

JAX 的训练编程模型是“纯函数(Pure Function)+ 变换(Transformation)+ 编译(XLA Compilation)”。训练代码通常写成不带副作用的函数,然后用 jit/ grad/ vmap/ pmap 把函数变成可微、可并行、可编译的高性能版本。对工程而言,这意味着:参数与 optimizer state 需要显式放入状态对象;更新步骤常用 jit(value_and_grad(...)) 组织成单一的编译热点。

JAX installation (from official docs)
Shell
1
2
3
4
5
# CPU-only
pip install -U jax
 
# NVIDIA GPU (example, CUDA 13)
pip install -U "jax[cuda13]"
函数变换

JAX 的高频核心 API 集中在“函数变换”上。它们本质上都是高阶函数:输入是 Python 函数,输出是新的函数;输出函数具备更强的可微、可并行或可编译特性。

命令/API/函数
jax.jit

说明
把函数编译成 XLA 可执行版本(并缓存编译结果)

示例

Python
1
step = jax.jit(step_fn)

命令/API/函数
jax.grad

说明
对标量输出函数求梯度(反向模式 AD)

示例

Python
1
g = jax.grad(loss_fn)

命令/API/函数
jax.value_and_grad

说明
一次性返回 (value, grad),减少重复前向

示例

Python
1
vg = jax.value_and_grad(loss_fn)

命令/API/函数
jax.vmap

说明
自动向量化,把“单样本函数”提升为“批函数”

示例

Python
1
batched_loss = jax.vmap(loss_fn, in_axes=(None, 0, 0))

命令/API/函数
jax.pmap

说明
跨多设备 SPMD 并行(常用于数据并行)

示例

Python
1
pstep = jax.pmap(step_fn, axis_name="data")
jit

jit 的工程要点是“可 trace 的纯函数”:Python 侧的动态分支、不可哈希的静态参数、以及频繁变化的输入形状都会导致重新 trace 或重新编译,从而出现性能抖动。训练代码通常会把 step 函数写成固定签名,并把配置项通过静态参数或闭包固定。

grad

grad 适用于标量 loss。多输出或同时需要 aux(例如 metrics)时,通常配合 value_and_grad(has_aux=True)。

vmap

vmap 把 batch 维度“隐式传播”到计算图里,常用于 per-example gradients、对比学习的 pairwise 计算、以及把 Python for-loop 从热点路径挪走。

pmap

pmap 以 named axis 组织跨设备 collectives(如 jax.lax.psum),适合以“每个设备一份副本”的方式做数据并行。更通用的分片(sharding)体系通常落在 pjit/mesh 路线,但训练代码层面仍常见以 pmap 组织最小多卡并行示例。

编译式执行
JAX 执行模型

JAX 训练 step 的典型形态是:“状态输入 → 计算 loss → 求梯度 → 用 optimizer 更新状态”。状态一般用 PyTree 组织(字典、元组、dataclass 等嵌套结构),并作为函数参数显式传递。

Minimal JAX training step (sketch)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import jax
import jax.numpy as jnp
 
def loss_fn(params, batch):
    x, y = batch
    # JAX 里前向通常写成纯函数:参数显式传入,函数内部不改全局状态
    logits = model_apply(params, x)
    loss = cross_entropy(logits, y)
    return loss
 
@jax.jit
def step(params, opt_state, batch):
    # 一次前向同时拿到 loss 和梯度,减少重复计算
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # 优化器状态也是显式对象,需要手动传入传出
    updates, opt_state = optimizer.update(grads, opt_state, params)
    # JAX/Optax 不原地改参数,改为返回更新后的新参数树
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss
XLA

XLA 是 JAX 性能的核心来源:它把 Python 层的计算表达编译成设备侧可执行程序,并做算子融合与内存规划。训练工程里,减少 recompile 与控制输入形状稳定通常比微观优化更有效。

Mesh / NamedSharding / jax.distributed.initialize

JAX 的新一代分片主线围绕 Mesh、 PartitionSpec 与 NamedSharding 展开。它比早期只靠 pmap 更通用,也更适合显式描述参数分片与输入布局。多进程训练时,通常还需要先调用 jax.distributed.initialize() 让各进程加入同一 JAX 集群。

JAX: distributed initialize + NamedSharding
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
 
# 多机时要先建立全局进程组;单机调试通常可以省略。
jax.distributed.initialize()
 
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))
 
# 显式把 batch 按 data 轴切到多个设备,避免依赖隐式复制。
x = jax.device_put(np.ones((len(devices) * 8, 1024), dtype=np.float32), sharding)
Flax

Flax 是基于 JAX 的神经网络库,常见入口是 Linen API:以 Module 表达参数化结构,以 init/apply 显式分离“参数创建”和“前向应用”。训练循环通常围绕 TrainState 把 params 与 optimizer state 统一管理。

Install Flax (typical)
Shell
1
pip install -U flax

命令/API/函数
flax.linen.Module

说明
参数化模块定义

示例

Python
1
2
3
4
5
6
7
8
import flax.linen as nn
 
class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(256)(x)  # compact 允许在前向里直接声明子层,Flax 会自动跟踪参数树
        x = nn.gelu(x)        # 非线性激活放在中间层,提升表达能力
        return nn.Dense(10)(x)  # 输出 10 维 logits,后续交给 loss 或任务头解释

命令/API/函数
Module.init

说明
给定 rng 与输入 shape,初始化参数

示例

Python
1
params = model.init(jax.random.key(0), x)["params"]

命令/API/函数
Module.apply

说明
给定参数执行前向

示例

Python
1
logits = model.apply({"params": params}, x)

命令/API/函数
flax.training.train_state.TrainState

说明
统一管理 step/params/opt_state

示例

Python
1
2
from flax.training.train_state import TrainState
state = TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)
Flax checkpoint:checkpoints 与 Orbax

Flax 训练里最常见的恢复接口是 flax.training.checkpoints;更现代、更通用的保存体系则逐渐转向 Orbax。前者上手更快,后者更适合复杂 PyTree、异步写盘和大规模分片状态。

Flax: TrainState checkpoint
Python
1
2
3
4
5
from flax.training import checkpoints
 
# 直接把 TrainState 存成 checkpoint,便于断点续训。
checkpoints.save_checkpoint("ckpt_flax", target=state, step=state.step, keep=3)
state = checkpoints.restore_checkpoint("ckpt_flax", target=state)

Orbax: PyTreeCheckpointer
Python
1
2
3
4
5
import orbax.checkpoint as ocp
 
checkpointer = ocp.PyTreeCheckpointer()
# Orbax 适合更复杂的树状状态与显式的保存策略。
checkpointer.save("orbax_ckpt", {"state": state})
PaddlePaddle

PaddlePaddle 的训练编程覆盖动态图(Dynamic Graph)与静态图(Static Graph)两种执行路径:动态图强调易用与调试;静态图强调编译优化、部署与稳定性能。官方 API 通过 paddle.enable_static() 显式切换到静态图模式。

PaddlePaddle installation (follow official guide)
Shell
1
2
3
4
5
# CPU
pip install -U paddlepaddle
 
# GPU
pip install -U paddlepaddle-gpu

Paddle verify
Python
1
2
3
import paddle
print(paddle.__version__)
paddle.utils.run_check()
执行模式与工具链
动态图

动态图是默认模式,常见训练代码以 paddle.nn.Layer 组织模型,以 paddle.optimizer 更新参数。

静态图

静态图用于追求更强的图级优化与更稳定的部署路径。切换到静态图通常意味着:需要显式构建 program/graph,并使用对应的 executor/engine 执行。实践里更常见的策略是:训练仍以动态图为主,部署阶段再导出静态图或使用官方推理引擎。

产业化工具链

工程体系上,Paddle 生态通常把“训练、推理、部署、端侧”做成一套配套工具链。若目标是快速把模型落到生产场景(OCR、CV、NLP 服务或端侧),这条生态链会显著降低工程摩擦。

Fleet

Fleet 是 Paddle 的分布式训练统一 API:通过 fleet.init 初始化分布式环境,并用 fleet.distributed_optimizer 把普通 optimizer 包装成分布式 optimizer。工程上最常见的是 collective 路线。

Fleet (collective) minimal pattern
Python
1
2
3
4
5
6
7
8
9
10
11
12
import paddle
import paddle.distributed.fleet as fleet
 
# 让当前进程加入 collective 通信组;多卡启动器会提前准备好 rank/world size。
fleet.init(is_collective=True)
# 示例里只放一个线性层,占位说明 Fleet 接的是“普通 Paddle 模型”。
model = paddle.nn.Linear(10, 10)
optimizer = paddle.optimizer.SGD(learning_rate=1e-3, parameters=model.parameters())
# DistributedStrategy 是分布式策略入口;真实项目会继续在这里打开 AMP/重算等选项。
strategy = fleet.DistributedStrategy()
# 包装后返回的仍是 optimizer 语义,但 step/backward 会走 Fleet 的分布式实现。
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
PaddleNLP

PaddleNLP 是 PaddlePaddle 生态里的 NLP 与大模型套件。它和 Transformers 的角色有相似之处:提供模型库、tokenizer、Trainer、Taskflow、数据处理、微调与推理工具;差异在于它深度绑定 PaddlePaddle、Paddle Fleet、Paddle Inference 与国产硬件适配链路,更适合已经采用百度飞桨生态或需要落到 Paddle 产业工具链的团队。

安装
Shell
1
pip install -U paddlenlp
Taskflow 与 LLM 套件

Taskflow 更偏“任务 API”:分词、信息抽取、情感分析、文本生成等任务可以通过统一入口快速验证;LLM 套件则覆盖预训练、SFT、PEFT、对齐、量化、推理和统一 checkpoint。选 PaddleNLP 的主要理由通常是训练、推理、压缩和部署都留在 Paddle 生态内,减少跨框架模型转换、算子适配与部署链路割裂。

Python
1
2
3
4
5
from paddlenlp import Taskflow
 
# Taskflow 会按任务名装配模型、tokenizer 与后处理,适合先快速验证任务闭环。
seg = Taskflow("word_segmentation")
print(seg("我爱自然语言处理"))
MindSpore

MindSpore 是华为主导的深度学习训练/推理框架,核心特色是与昇腾(Ascend)AI 处理器和全栈工具链协同。它支持动态图与静态图思想下的训练编程,也提供自动并行、图优化、端边云部署等能力。工程上,MindSpore 的选型通常和硬件平台绑定:如果目标环境以 Ascend NPU 为主,MindSpore / MindSpeed / CANN 这条链路的集成度更高。

安装与验证

MindSpore 的安装方式与设备类型强相关。CPU、GPU、Ascend 对应不同 wheel 与运行时依赖;Ascend 环境还需要匹配 CANN、驱动和固件版本。实际项目里应优先按官方安装矩阵固定版本,不建议只凭 pip install 猜测。

Shell
1
2
# CPU 环境可用于语法验证和轻量实验;Ascend/GPU 环境需按官方矩阵安装对应 wheel。
pip install mindspore

Python
1
2
3
4
5
import mindspore as ms
 
# context 决定执行设备和模式;Ascend 环境通常把 device_target 设为 "Ascend"。
ms.set_context(mode=ms.GRAPH_MODE, device_target="CPU")
print(ms.__version__)
特色与选型边界

MindSpore 的特色是围绕图编译、自动并行和 Ascend 硬件协同建立一整套工程路径。它适合长期运行在昇腾集群、需要使用 CANN / MindSpeed / MindSpore ModelZoo / MindFormers 生态的项目;若团队已有大量 PyTorch/Transformers/DeepSpeed 资产,迁移成本需要单独评估。

OneFlow

OneFlow 是国产深度学习框架,设计重点在分布式训练抽象。它提出的 SBP(Split / Broadcast / Partial)语义把张量在设备网格上的放置方式显式化:某个维度可以切分(Split),某个张量可以复制(Broadcast),归约前的部分值可以处于 Partial 状态。这个设计让数据并行、模型并行和混合并行都能统一表达。

Global Tensor 与 SBP

OneFlow 的 Global Tensor 把多设备集群抽象成一个统一计算空间。开发者声明 placement 与 sbp,框架负责把张量切到对应设备并插入必要通信。它更适合研究分布式张量布局、训练系统和国产框架生态;普通 LLM 微调项目若主要使用 Hugging Face 生态,PyTorch 路线的工程摩擦通常更小。

Python
1
2
3
4
5
6
import oneflow as flow
 
placement = flow.placement("cuda", ranks=[0, 1])
# sbp.split(0) 表示把第 0 维按设备切分,常用于 batch 维数据并行。
x = flow.randn(8, 1024, placement=placement, sbp=flow.sbp.split(0))
print(x.placement, x.sbp)
基础训练框架怎么选
框架 核心特色 优先选择场景
PyTorch 生态最大、动态图体验强、Transformers/PEFT/vLLM/DeepSpeed/FSDP 集成最完整 LLM 训练、微调、推理服务和研究原型的默认起点。
TensorFlow / Keras 3 生产部署、SavedModel、TF Serving、Keras 高层 API 与多后端建模体验 已有 TensorFlow 资产、移动端/服务端部署链路成熟,或团队偏好 Keras 训练接口。
JAX / Flax 函数式编程、XLA 编译、显式 sharding、适合大规模研究系统 需要强编译优化、SPMD 分片、研究型训练系统或 TPU/JAX 生态。
PaddlePaddle / PaddleNLP 中文产业生态、Paddle Fleet、NLP/视觉/OCR/推理部署工具链完整 企业已经采用飞桨生态,或希望训练、压缩、推理、部署沿同一国产工具链落地。
MindSpore / MindSpeed Ascend 亲和、图优化、自动并行、CANN 与昇腾硬件栈深度协同 目标算力主要是华为昇腾,且团队需要利用 Ascend 原生训练/推理优化。
OneFlow SBP 分布式张量语义、分布式训练抽象清晰 研究分布式框架、国产训练系统,或已有 OneFlow 生态资产。
经典机器学习工程框架

这类库不属于大模型主线,但在特征工程、表格任务、以及小模型 baseline 中仍然高频存在。它们的接口几乎都围绕 fit/ predict/ predict_proba,工程上更强调数据清洗与特征一致性。

主流库怎么分工

经典机器学习工程的主线可以按职责拆成五层:scikit-learn 负责统一 API 和基线建模,XGBoost / LightGBM / CatBoost 负责高性能表格树模型,statsmodels 负责统计建模与显著性分析,imbalanced-learn 负责类别不均衡处理,Optuna / SHAP / joblib 负责调参、解释和交付。它们覆盖的是表格、稀疏特征、小中型监督学习、聚类、异常检测和可解释建模。

库 / 框架 核心定位 最适合的任务 选型边界
scikit-learn 经典 ML 的统一 estimator API、预处理、Pipeline、模型选择、指标 逻辑回归、SVM、随机森林、KMeans、PCA、IsolationForest、表格 baseline 中小规模 CPU 任务最稳;超大规模训练和 GPU 深度学习任务应转向专门框架。
XGBoost 成熟的梯度提升树实现,正则化、缺失值处理和 GPU 路线完整 表格分类/回归、风控、排序、特征工程强的业务模型 资料和生态最成熟;类别特征需额外处理,训练速度未必总是最快。
LightGBM 高速 GBDT,直方图算法、leaf-wise 生长和大规模稀疏特征支持突出 大样本表格、CTR/CVR、广告预估、需要快速迭代的树模型 速度优势明显;leaf-wise 生长需要控制叶子数、深度和早停,避免过拟合。
CatBoost 对类别特征友好的 GBDT,内置类别处理和 ordered boosting 思路 类别列多、类别基数高、手写 target encoding 风险大的表格任务 类别特征路径稳定;极简数值特征任务上未必优于 LightGBM/XGBoost。
statsmodels 统计建模、公式接口、参数估计、置信区间、显著性检验 线性回归、Logit、时间序列统计模型、可解释分析报告 更偏统计推断;生产预测 pipeline 通常仍由 scikit-learn 或 GBDT 承担。
imbalanced-learn 类别不均衡处理,与 scikit-learn Pipeline 协同 欺诈检测、风控、罕见病识别、告警分类等少数类极少的任务 重采样必须只发生在训练 fold 内;把重采样放在全量数据前处理会造成数据泄漏。
Optuna 自动超参数优化,define-by-run 搜索空间,支持 pruning GBDT、scikit-learn pipeline、深度学习训练脚本的调参 调参预算有限时优先;前提是验证集和目标指标可信。
SHAP 模型解释,尤其适合树模型的 TreeExplainer 特征贡献分析、风控审计、业务解释、模型回归排查 解释结果依赖数据分布和特征相关性;不能把 SHAP 当因果结论。
joblib / skops 模型持久化与安全交付 保存 scikit-learn pipeline、离线推理、批处理服务 pickle/joblib 要控制加载来源;跨组织交付更应考虑更安全的格式或导出方案。
统一安装
Shell
1
2
3
4
5
6
7
# scikit-learn 提供经典 estimator、Pipeline、交叉验证和预处理。
# XGBoost/LightGBM/CatBoost 是表格任务常用的 GBDT 工程库。
pip install -U scikit-learn xgboost lightgbm catboost
 
# statsmodels 服务统计推断;imbalanced-learn 处理不均衡采样。
# Optuna 做超参数搜索;SHAP 做模型解释;joblib 保存 sklearn 制品。
pip install -U statsmodels imbalanced-learn optuna shap joblib
任务到主流库选择
任务 优先库 推荐起点 升级路线
二分类 / 多分类 baseline scikit-learn LogisticRegression、RandomForestClassifier、Pipeline 指标稳定后再切 XGBoost / LightGBM / CatBoost。
表格强特征分类/回归 LightGBM / XGBoost / CatBoost LightGBM 快速迭代,XGBoost 做稳健对照,类别列多时试 CatBoost 加入 Optuna 调参、SHAP 解释、特征选择和校准。
类别极不均衡 imbalanced-learn + scikit-learn / GBDT class_weight、阈值调优、PR-AUC,再评估 SMOTE 等重采样 引入代价敏感学习、分层抽样、hard negative 采样。
统计解释 / 显著性分析 statsmodels OLS、Logit、公式接口、summary 报告 预测系统再转成 scikit-learn pipeline 或 GBDT。
聚类 / 分群 / 探索 scikit-learn PCA + KMeans,或 DBSCAN / AgglomerativeClustering 聚类结果经人工命名和业务验证后,再作为特征或标签候选。
异常检测 scikit-learn IsolationForest、OneClassSVM、LocalOutlierFactor 将异常检测作为召回层,再接人工审核或监督分类器。
调参 Optuna 围绕业务指标定义 objective,控制 n_trials 和搜索边界 加入 pruning、多进程/数据库存储、分阶段搜索。
解释与审计 SHAP TreeExplainer 解释 GBDT,全局 summary + 局部 top features 结合漂移监控、特征稳定性和业务规则审查。
scikit-learn

scikit-learn 是经典机器学习工程的默认 baseline 工具。它把预处理、特征组合、模型训练、交叉验证、指标评估统一在一套 estimator API 下:所有模型基本都遵循 fit、 predict、 predict_proba、 score 这些入口。对 AI 工程来说,它最重要的价值是快速建立强 baseline,并把特征工程写成可复现 pipeline。

Shell
1
2
# 安装 scikit-learn 主包;它依赖 numpy/scipy/joblib/threadpoolctl。
pip install -U scikit-learn
常用API

命令/API/函数
train_test_split

说明
把数据切成训练集和验证/测试集。真实项目里要固定 random_state,并在分类任务中使用 stratify 保持标签比例。

示例

Python
1
2
3
4
5
6
7
8
9
from sklearn.model_selection import train_test_split
 
X_train, X_valid, y_train, y_valid = train_test_split(
    X,
    y,
    test_size=0.2,
    stratify=y,       # 分类任务保持正负样本比例,避免验证集分布漂移。
    random_state=42,  # 固定切分,保证实验可复现。
)

命令/API/函数
Pipeline

说明
把预处理和模型训练串成单个 estimator。这样交叉验证、保存模型和线上推理都会走同一条特征处理路径,减少训练/推理特征不一致。

示例

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
 
pipe = Pipeline(
    steps=[
        # 标准化器会在 fit 阶段只读取训练集统计量,避免验证集信息泄漏。
        ("scale", StandardScaler()),
        # max_iter 给逻辑回归更多迭代步,减少复杂特征下未收敛的概率。
        ("clf", LogisticRegression(max_iter=1000)),
    ]
)
# fit 会按顺序训练 scale 和 clf 两个步骤。
pipe.fit(X_train, y_train)
# 第二列是正类概率,常用于 AUC、阈值选择和概率校准。
proba = pipe.predict_proba(X_valid)[:, 1]

命令/API/函数
ColumnTransformer

说明
按列类型应用不同预处理:数值列做标准化,类别列做 one-hot,文本列可接 TF-IDF。表格任务中这是把特征工程写进模型产物的关键接口。

示例

Python
1
2
3
4
5
6
7
8
9
10
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder, StandardScaler
 
preprocess = ColumnTransformer(
    transformers=[
        ("num", StandardScaler(), ["age", "income"]),
        # handle_unknown="ignore" 让线上遇到新类别时不直接报错。
        ("cat", OneHotEncoder(handle_unknown="ignore"), ["country", "device"]),
    ]
)

命令/API/函数
cross_val_score / GridSearchCV

说明
交叉验证与网格搜索评估的是整条 pipeline;预处理步骤会在每个 fold 内重新 fit,避免验证信息泄漏进训练。

示例

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from sklearn.model_selection import GridSearchCV
 
search = GridSearchCV(
    # estimator 可以是单个模型,也可以是包含预处理的完整 Pipeline。
    estimator=pipe,
    # 双下划线表示进入 Pipeline 中名为 clf 的步骤,再调它的 C 参数。
    param_grid={"clf__C": [0.1, 1.0, 10.0]},
    # roc_auc 使用概率排序质量,比固定阈值 accuracy 更适合二分类筛选。
    scoring="roc_auc",
    # 每个 fold 都会重新 fit 预处理器,验证 fold 不参与特征统计。
    cv=5,
)
search.fit(X_train, y_train)
print(search.best_params_, search.best_score_)

scikit-learn 的工程边界也很清楚:它适合中小规模表格数据、传统特征工程、快速 baseline 和可解释小模型。大规模深度学习训练、GPU 分布式训练、LLM 微调和在线生成服务,应切到 PyTorch/Transformers/DeepSpeed/vLLM 等栈。

经典算法到 sklearn 类名的映射
算法 常用类 工程使用要点
逻辑回归(Logistic Regression) sklearn.linear_model.LogisticRegression 强 baseline;数值特征通常需要标准化;类别不均衡时关注 class_weight、阈值和 AUC/F1。
支持向量机(SVM) sklearn.svm.SVC、 LinearSVC 小中型特征空间表现稳定;核 SVM 训练成本高,线性 SVM 更适合高维稀疏文本特征。
决策树(Decision Tree) sklearn.tree.DecisionTreeClassifier 可解释性强;单树容易过拟合,通常用深度、叶子样本数和剪枝参数控制复杂度。
随机森林(Random Forest) sklearn.ensemble.RandomForestClassifier 通过 bagging 降低方差;适合表格 baseline 和特征重要性分析,推理延迟随树数量增长。
k 近邻(KNN) sklearn.neighbors.KNeighborsClassifier 训练几乎没有成本,推理依赖距离搜索;特征尺度必须处理好,高维空间容易退化。
朴素贝叶斯(Naive Bayes) MultinomialNB、 GaussianNB、 BernoulliNB 文本分类和小数据 baseline 很常见;根据特征分布选择多项式、连续高斯或二值 Bernoulli 版本。
ref-2 算法到工程入口总覆盖表

ref-2 的机器学习算法在工程上可以落到下表。这里强调可执行入口和项目边界;算法原理仍放在 ref-2 的理论章节中讲。

ref-2 算法 / 方法 主流工程入口 使用边界
感知机(Perceptron) sklearn.linear_model.Perceptron、 SGDClassifier(loss="perceptron") 线性可分 baseline、在线学习教学和大规模稀疏特征起点。
逻辑回归 / 最大熵(MaxEnt) LogisticRegression、 SGDClassifier(loss="log_loss") 分类强 baseline;最大熵在工程中通常对应多项逻辑回归。
线性回归、Ridge、Lasso、Elastic Net LinearRegression、 Ridge、 Lasso、 ElasticNet 回归、可解释建模、低延迟服务和强规则特征。
朴素贝叶斯(Naive Bayes) MultinomialNB、 ComplementNB、 BernoulliNB、 GaussianNB 文本分类、小样本 baseline、离散计数特征;ComplementNB 更适合不均衡文本。
k 近邻(KNN) KNeighborsClassifier、 KNeighborsRegressor、 NearestNeighbors 距离检索、小规模分类回归、近邻图构建;高维场景需要降维或向量索引。
决策树 / 随机森林 DecisionTreeClassifier、 DecisionTreeRegressor、 RandomForestClassifier、 RandomForestRegressor 表格 baseline、非线性特征交互、特征重要性;单树需限制复杂度。
支持向量机(SVM) SVC、 LinearSVC、 SVR、 OneClassSVM 中小规模分类回归、文本线性分类、异常检测;核方法训练成本高。
线性判别分析 / 二次判别分析 LinearDiscriminantAnalysis、 QuadraticDiscriminantAnalysis LDA 可做分类和监督降维;QDA 更灵活,也更依赖样本量。
GBDT、XGBoost、LightGBM、CatBoost GradientBoosting*、 XGB*、 LGBM*、 CatBoost* 表格强模型、CTR/CVR、风控、排序和强特征业务任务。
聚类分析 KMeans、 MiniBatchKMeans、 DBSCAN、 OPTICS、 Birch、 AgglomerativeClustering、 SpectralClustering 分群、语料探索、样本去重和近邻图分析;簇编号需要业务解释。
概率密度估计 / 异常检测 GaussianMixture、 KernelDensity、 IsolationForest、 LocalOutlierFactor、 OneClassSVM 密度评分、软聚类、异常候选召回和数据质量检查。
降维与可视化 PCA、 TruncatedSVD、 TSNE、 umap.UMAP PCA/SVD 可进入生产特征链路;t-SNE/UMAP 更常用于探索和可视化。
主题模型:隐含狄利克雷分布(Latent Dirichlet Allocation, LDA) sklearn.decomposition.LatentDirichletAllocation、gensim 无监督主题发现;缩写 LDA 与线性判别分析相同,语境决定含义。
HMM、CRF、MEMM、结构化感知机 hmmlearn、sklearn-crfsuite、torchcrf、seqlearn、pystruct、自定义动态规划 序列标注、词法分析、NER、状态序列;深度模型常把 CRF 作为解码层。
半监督 / 弱监督 / 主动学习 SelfTrainingClassifier、 LabelPropagation、Snorkel、Cleanlab、modAL、scikit-activeml 少标注数据、伪标签、规则标签融合、样本挑选和标注闭环。
通用强化学习 Gymnasium、Stable-Baselines3、CleanRL、Ray RLlib、PettingZoo Q-Learning、SARSA、DQN、PPO、Actor-Critic、多智能体环境;LLM RL 另见 OpenRLHF 和 verl。
感知机、最大熵与在线线性模型

ref-2 中的感知机、最大熵模型和浅层线性模型,在工程上通常走 linear_model。最大熵分类器对应逻辑回归;在线学习场景可用 SGDClassifier 的 partial_fit 分批更新。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from sklearn.linear_model import Perceptron, SGDClassifier
from sklearn.metrics import classification_report
from sklearn.preprocessing import StandardScaler
 
# 感知机只学习线性分界面,适合做在线线性分类的最小 baseline。
perceptron = Perceptron(
    max_iter=20,
    tol=1e-3,
    random_state=42,
)
perceptron.fit(X_train, y_train)
print(classification_report(y_valid, perceptron.predict(X_valid)))
 
# 最大熵模型在 sklearn 中通常写成 log_loss 线性分类器。
# average=True 会对 SGD 权重做平均,常能提升线上稳定性。
maxent_online = SGDClassifier(
    loss="log_loss",
    penalty="l2",
    alpha=1e-5,
    learning_rate="optimal",
    average=True,
    random_state=42,
)
 
# partial_fit 支持分批训练;第一批必须给出完整类别集合。
classes = sorted(set(y_train))
for X_batch, y_batch in stream_training_batches():
    maxent_online.partial_fit(X_batch, y_batch, classes=classes)
 
# 输出正类概率,便于后续做阈值选择、校准和业务路由。
proba = maxent_online.predict_proba(X_valid)[:, 1]
线性模型族:LinearRegression、Ridge、Lasso 与 ElasticNet

ref-2 中提到的线性回归、L1/L2 正则化和 Elastic Net,在工程上主要落到 scikit-learn 的 linear_model 模块。它们适合做可解释 baseline、低延迟模型、强特征表格任务和高维稀疏文本分类。

算法 scikit-learn 类 工程使用要点
线性回归 LinearRegression 最小二乘 baseline;特征共线性强时系数不稳定。
岭回归(Ridge) Ridge、 RidgeClassifier L2 正则化让系数更稳定,适合多重共线性或高维特征。
Lasso Lasso L1 正则化产生稀疏系数,常用于特征选择和解释。
Elastic Net ElasticNet 混合 L1/L2,适合相关特征成组出现的表格任务。
SGD 线性模型 SGDClassifier、 SGDRegressor 适合超大规模稀疏特征和流式训练,但学习率与正则化更敏感。
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from sklearn.linear_model import ElasticNet, Ridge
from sklearn.metrics import mean_absolute_error
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
 
# 线性模型的系数直接受特征尺度影响,标准化必须进入 Pipeline。
ridge_pipe = Pipeline(
    steps=[
        # StandardScaler 只在 fit 阶段从训练集估计均值和方差。
        ("scale", StandardScaler()),
        # alpha 是 L2 正则强度;值越大,系数越保守,方差越低。
        ("model", Ridge(alpha=1.0)),
    ]
)
 
elastic_pipe = Pipeline(
    steps=[
        ("scale", StandardScaler()),
        # alpha 控制总正则强度;l1_ratio 控制 L1 在正则项中的比例。
        # l1_ratio 越接近 1,越容易把弱特征系数压成 0。
        ("model", ElasticNet(alpha=0.01, l1_ratio=0.5, max_iter=5000)),
    ]
)
 
# fit 会依次训练标准化器和 Ridge;验证集只调用 transform,不重新估计统计量。
ridge_pipe.fit(X_train, y_train)
pred = ridge_pipe.predict(X_valid)
print("MAE:", mean_absolute_error(y_valid, pred))
监督降维与判别分析:LDA / QDA

线性判别分析(Linear Discriminant Analysis, LDA)在 ref-2 中有两个身份:分类器,以及利用标签信息做监督降维的线性方法。它假设各类别共享协方差矩阵,因此决策边界是线性的;二次判别分析(Quadratic Discriminant Analysis, QDA)允许每类有不同协方差矩阵,边界更灵活,也更容易受样本量影响。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis, QuadraticDiscriminantAnalysis
from sklearn.metrics import accuracy_score
 
# LDA 可以直接作为分类器使用。
lda = LinearDiscriminantAnalysis()
lda.fit(X_train, y_train)
print("LDA acc:", accuracy_score(y_valid, lda.predict(X_valid)))
 
# n_components 控制监督降维后的维度,上限受类别数限制。
# 这个表示可以继续交给可视化、聚类或轻量分类器。
projector = LinearDiscriminantAnalysis(n_components=2)
X_train_lda = projector.fit_transform(X_train, y_train)
X_valid_lda = projector.transform(X_valid)
 
# QDA 边界更灵活,样本少或特征维高时更需要正则化。
qda = QuadraticDiscriminantAnalysis(reg_param=0.1)
qda.fit(X_train, y_train)
分类与回归 estimator 对照补充
算法族 分类入口 回归 / 相关入口
SVM SVC、 LinearSVC、 NuSVC SVR、 LinearSVR、 OneClassSVM
kNN KNeighborsClassifier KNeighborsRegressor、 NearestNeighbors
决策树 DecisionTreeClassifier DecisionTreeRegressor
随机森林 RandomForestClassifier RandomForestRegressor
朴素贝叶斯 MultinomialNB、 ComplementNB、 BernoulliNB、 GaussianNB 主要服务分类;连续目标通常改用线性/树模型。

ComplementNB 常用于类别不均衡的文本分类; LinearSVC 适合高维稀疏 TF-IDF;核 SVC 更适合中小样本。SVM、KNN 和线性模型通常都需要标准化或合适的特征缩放。

常见分类器家族快速基线

经典分类任务可以先跑一组轻量 baseline:朴素贝叶斯检验文本/计数特征是否已经足够强,线性 SVM 检验高维稀疏边界,KNN 检验距离度量,决策树与随机森林检验非线性特征交互。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score
from sklearn.naive_bayes import ComplementNB, GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import LinearSVC
from sklearn.tree import DecisionTreeClassifier
 
baseline_models = {
    # ComplementNB 适合非负计数/TF-IDF 文本特征,常作为文本分类强 baseline。
    "complement_nb": ComplementNB(alpha=1.0),
    # GaussianNB 假设连续特征在每个类别内近似高斯分布,适合小数据快速试探。
    "gaussian_nb": GaussianNB(),
    # LinearSVC 适合高维稀疏特征;它默认不给 predict_proba,评估时常看 decision_function。
    "linear_svm": LinearSVC(C=1.0, class_weight="balanced"),
    # KNN 的效果取决于特征尺度和距离定义,使用前通常需要标准化或归一化。
    "knn": KNeighborsClassifier(n_neighbors=15, weights="distance"),
    # 单棵树可解释性强,max_depth 用来限制树记住训练集噪声。
    "tree": DecisionTreeClassifier(max_depth=8, min_samples_leaf=20, random_state=42),
    # 随机森林通过多棵树投票降低方差,n_jobs=-1 使用本机所有 CPU 核。
    "forest": RandomForestClassifier(n_estimators=300, class_weight="balanced", n_jobs=-1, random_state=42),
}
 
for name, estimator in baseline_models.items():
    # 每个模型都遵循 fit/predict 协议,便于统一接入评估脚本。
    estimator.fit(X_train, y_train)
    pred = estimator.predict(X_valid)
    # macro F1 给每个类别相同权重,适合观察少数类是否被模型忽略。
    print(name, f1_score(y_valid, pred, average="macro"))
监督学习完整 Pipeline

经典机器学习项目最容易出问题的位置通常在特征处理。推荐把数值列、类别列、文本列的预处理写进 ColumnTransformer,再和模型一起封装为 Pipeline。这样训练、交叉验证、保存和线上推理会使用同一条特征路径。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import joblib
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
 
# parquet 常用于离线特征表;读入后先显式声明标签列和特征列。
df = pd.read_parquet("train.parquet")
target = "label"
numeric_cols = ["age", "income", "days_since_signup"]
categorical_cols = ["country", "device"]
text_col = "comment"
 
# X 只包含模型可见特征;y 单独保存,避免把标签误放进特征工程。
X = df[numeric_cols + categorical_cols + [text_col]]
y = df[target]
 
X_train, X_valid, y_train, y_valid = train_test_split(
    X,
    y,
    test_size=0.2,
    # stratify 保持正负样本比例,避免验证指标受切分偶然性影响。
    stratify=y,
    random_state=42,
)
 
numeric_pipe = Pipeline(
    steps=[
        # 中位数填补对极端值更稳,适合作为数值特征的默认起点。
        ("impute", SimpleImputer(strategy="median")),
        # 线性模型和 SVM 对尺度敏感,标准化能让优化更稳定。
        ("scale", StandardScaler()),
    ]
)
 
categorical_pipe = Pipeline(
    steps=[
        # most_frequent 给缺失类别一个稳定替代值,避免 one-hot 阶段报错。
        ("impute", SimpleImputer(strategy="most_frequent")),
        # 线上出现新类别时忽略该列,保证服务不会因未知枚举直接失败。
        ("onehot", OneHotEncoder(handle_unknown="ignore")),
    ]
)
 
preprocess = ColumnTransformer(
    transformers=[
        # 三个子管线会并行处理不同列,最后拼成同一个特征矩阵。
        ("num", numeric_pipe, numeric_cols),
        ("cat", categorical_pipe, categorical_cols),
        # 文本列直接走 TF-IDF,适合短文本 baseline 和工单/评论类特征。
        ("txt", TfidfVectorizer(max_features=50000, ngram_range=(1, 2)), text_col),
    ]
)
 
pipe = Pipeline(
    steps=[
        ("prep", preprocess),
        # class_weight="balanced" 按类别频率自动加权,缓解正负样本不均衡。
        ("clf", LogisticRegression(max_iter=1000, class_weight="balanced")),
    ]
)
 
# fit 会训练预处理器和分类器;验证集不会参与任何参数估计。
pipe.fit(X_train, y_train)
# predict_proba 的第二列是正类概率,适合 AUC、阈值路由和校准分析。
proba = pipe.predict_proba(X_valid)[:, 1]
# 0.5 只是默认阈值;生产阈值应按召回、精度或成本函数单独选择。
pred = (proba >= 0.5).astype(int)
 
print("AUC:", roc_auc_score(y_valid, proba))
print(classification_report(y_valid, pred))
 
# 保存整条 pipeline,线上推理才能复用同样的填补、编码和向量化规则。
joblib.dump(pipe, "logreg_pipeline.joblib")
回归任务 Pipeline

回归任务和分类任务共享大部分预处理逻辑,差异在模型头与指标。线性回归适合解释性和强 baseline;随机森林回归能捕捉非线性;GBDT 通常是表格回归的强默认项。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error, mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder
 
# 回归任务的标签是连续值;这里用房价作为目标变量。
X = df[["area", "bedrooms", "city", "building_type"]]
y = df["price"]
 
X_train, X_valid, y_train, y_valid = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)
 
preprocess = ColumnTransformer(
    transformers=[
        # 类别列 one-hot;未知城市或楼型在推理时忽略,避免线上报错。
        ("cat", OneHotEncoder(handle_unknown="ignore"), ["city", "building_type"]),
    ],
    # 数值列 area、bedrooms 直接透传给随机森林。
    remainder="passthrough",
)
 
model = Pipeline(
    steps=[
        ("prep", preprocess),
        # 树模型对单调变换和特征尺度不敏感,适合先做非线性回归 baseline。
        ("reg", RandomForestRegressor(n_estimators=300, random_state=42, n_jobs=-1)),
    ]
)
 
# fit 同时保存 one-hot 类别映射和随机森林的树结构。
model.fit(X_train, y_train)
pred = model.predict(X_valid)
# MAE 更接近平均绝对业务误差;RMSE 会放大大额预测错误。
print("MAE:", mean_absolute_error(y_valid, pred))
print("RMSE:", mean_squared_error(y_valid, pred, squared=False))
XGBoost、LightGBM、CatBoost

这三类库都属于梯度提升树(Gradient Boosted Decision Trees, GBDT)工程栈,常用于表格数据、排序、广告预估、风控、CTR/CVR 预估、特征工程强的业务模型。它们和神经网络的差异在于:模型由大量树组成,训练过程围绕残差/梯度逐轮加树,特征缺失、离散特征、非线性组合和小中型表格数据通常处理得很强。

库 特色 优先选择场景
XGBoost 生态成熟、正则化与缺失值处理稳定,CPU/GPU 训练路径都常见。 需要强 baseline、比赛/工业表格任务、希望模型行为和资料最容易查证。
LightGBM 训练速度快,直方图算法、leaf-wise 生长和大规模稀疏特征支持强。 样本量或特征量较大、需要快速迭代、CTR/排序/广告预估类任务。
CatBoost 类别特征处理能力强,对 target leakage 和类别编码有专门设计。 类别特征很多、类别基数高、希望减少手写 target encoding 的表格任务。
Shell
1
2
# 三个库都提供 sklearn 风格 API,也各自保留原生训练接口。
pip install -U xgboost lightgbm catboost
统一的 sklearn 风格训练骨架
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from catboost import CatBoostClassifier
from lightgbm import LGBMClassifier
from xgboost import XGBClassifier
from sklearn.metrics import roc_auc_score
 
models = {
    "xgb": XGBClassifier(
        # n_estimators 是最多加多少棵树;后面可再配合 early stopping。
        n_estimators=500,
        # max_depth 限制单棵树复杂度,越大越容易记住训练集细节。
        max_depth=6,
        # learning_rate 越小,每棵树贡献越小,通常需要更多树。
        learning_rate=0.05,
        # subsample 和 colsample_bytree 做行/列采样,降低过拟合和训练成本。
        subsample=0.8,
        colsample_bytree=0.8,
        eval_metric="logloss",
    ),
    "lgbm": LGBMClassifier(
        n_estimators=500,
        # num_leaves 控制 leaf-wise 树的容量,常比 max_depth 更关键。
        num_leaves=63,
        learning_rate=0.05,
        subsample=0.8,
        colsample_bytree=0.8,
    ),
    "cat": CatBoostClassifier(
        # CatBoost 使用 iterations 表示树的轮数,语义接近 n_estimators。
        iterations=500,
        depth=6,
        learning_rate=0.05,
        loss_function="Logloss",
        # 训练日志由外层实验系统记录时,示例里关闭逐轮输出。
        verbose=False,
    ),
}
 
for name, model in models.items():
    # 三个库都提供 sklearn 风格接口,便于统一纳入 pipeline 和评估脚本。
    model.fit(X_train, y_train)
    # 二分类业务通常看正类概率,再按 AUC、F1 或业务阈值评估。
    proba = model.predict_proba(X_valid)[:, 1]
    auc = roc_auc_score(y_valid, proba)
    print(name, auc)
CatBoost 的类别特征入口

CatBoost 的关键优势之一是直接接收类别列。工程上应显式列出类别特征列,避免把类别 ID 误当连续数值处理。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from catboost import CatBoostClassifier, Pool
 
# cat_features 必须使用原始类别列名或列索引,不能提前错误地转成连续数值。
cat_features = ["country", "device", "campaign_id"]
 
train_pool = Pool(
    data=train_df[feature_cols],
    label=train_df["label"],
    # 这些列会走 CatBoost 的有序目标统计和类别特征处理路径。
    cat_features=cat_features,
)
valid_pool = Pool(
    data=valid_df[feature_cols],
    label=valid_df["label"],
    # 训练集和验证集必须使用同一套类别特征声明。
    cat_features=cat_features,
)
 
model = CatBoostClassifier(
    iterations=1000,
    learning_rate=0.03,
    depth=8,
    loss_function="Logloss",
    eval_metric="AUC",
    # 每 100 轮打印一次,便于观察验证集是否平台化。
    verbose=100,
)
# use_best_model=True 会保留验证集指标最好的迭代,避免固定使用最后一轮。
model.fit(train_pool, eval_set=valid_pool, use_best_model=True)

这些库和深度学习栈经常共存:GBDT 负责表格强特征 baseline 或线上轻量模型,深度模型负责文本、图像、序列、多模态或大规模表示学习。推荐系统和广告预估里也常见两阶段组合:GBDT 生成强 tabular baseline,深度模型再处理 embedding、序列行为和复杂交互。

GBDT 早停与验证集

树模型训练时也需要早停。早停不应只看训练集指标;必须准备独立验证集,让模型在验证集指标不再改善时停止加树。不同库 API 细节不同,但工程原则一致:验证集来自训练切分,测试集只用于最终报告。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from lightgbm import LGBMClassifier, early_stopping, log_evaluation
from sklearn.metrics import roc_auc_score
 
lgbm = LGBMClassifier(
    # n_estimators 设大一些,把停止点交给验证集 early stopping 决定。
    n_estimators=5000,
    learning_rate=0.03,
    # num_leaves 是 LightGBM 控制树容量的核心参数。
    num_leaves=63,
    # 行采样和列采样降低树之间的相关性。
    subsample=0.8,
    colsample_bytree=0.8,
    random_state=42,
)
 
lgbm.fit(
    X_train,
    y_train,
    eval_set=[(X_valid, y_valid)],
    eval_metric="auc",
    # 100 轮无提升就停止,避免盲目把树继续叠深。
    callbacks=[early_stopping(100), log_evaluation(100)],
)
 
proba = lgbm.predict_proba(X_valid)[:, 1]
print("AUC:", roc_auc_score(y_valid, proba))
print("best_iteration:", lgbm.best_iteration_)
XGBoost 原生 DMatrix 路线

XGBoost 的 sklearn API 适合快速接入;原生 DMatrix 路线更适合复杂训练参数、缺失值处理、ranking objective 和更接近底层的调优。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import xgboost as xgb
from sklearn.metrics import roc_auc_score
 
# DMatrix 是 XGBoost 的原生数据容器,会保存特征矩阵、标签和缺失值信息。
dtrain = xgb.DMatrix(X_train, label=y_train)
dvalid = xgb.DMatrix(X_valid, label=y_valid)
 
params = {
    # binary:logistic 输出正类概率,适合二分类 AUC/F1/阈值评估。
    "objective": "binary:logistic",
    # eval_metric 决定 early stopping 观察哪个验证指标。
    "eval_metric": "auc",
    # max_depth 控制单棵树深度,直接影响模型容量和过拟合风险。
    "max_depth": 6,
    # eta 是 XGBoost 原生接口中的学习率。
    "eta": 0.03,
    # 行采样和列采样让每轮树看到不同子空间,提升泛化。
    "subsample": 0.8,
    "colsample_bytree": 0.8,
    # hist 使用直方图算法,通常是 CPU 大表格任务的高效默认项。
    "tree_method": "hist",
}
 
booster = xgb.train(
    params=params,
    dtrain=dtrain,
    # 轮数上限设大一些,真实停止点由验证集 early stopping 决定。
    num_boost_round=5000,
    # 同时打印 train/valid,便于区分欠拟合、过拟合和指标噪声。
    evals=[(dtrain, "train"), (dvalid, "valid")],
    early_stopping_rounds=100,
    verbose_eval=100,
)
 
# 预测时只使用验证集最优轮数以内的树,避免把过拟合轮次带入评估。
proba = booster.predict(dvalid, iteration_range=(0, booster.best_iteration + 1))
print("AUC:", roc_auc_score(y_valid, proba))
半监督、弱监督与主动学习工程库

ref-2 中的半监督学习、弱监督学习和主动学习,在工程上通常表现为一套“少量人工标签 + 大量未标注数据 + 伪标签/弱标注/标注预算”的流程。scikit-learn 提供半监督 baseline;Snorkel、Cleanlab、modAL 等库更偏标签治理和标注闭环。

方向 常用库 / 类 工程定位
Self-Training sklearn.semi_supervised.SelfTrainingClassifier 用高置信预测给未标注样本打伪标签,再迭代训练。
图半监督 LabelPropagation、 LabelSpreading 在样本相似图上传播少量标签,适合流形结构明显的小中规模数据。
Co-Training 两套 view + 两个分类器 + 高置信伪标签循环 scikit-learn 没有核心类;通常在文本 view、结构化 view、图像 view 之间手写编排。
半监督 SVM 历史/专用实现,自定义 Transductive SVM 路线 现代工程里较少作为默认项,更多被伪标签、图传播和深度半监督方法替代。
弱监督 Snorkel、Cleanlab 管理弱标注函数、标签噪声、疑似错标样本和数据质量问题。
主动学习 modAL、scikit-activeml 根据不确定性或多样性挑选最值得人工标注的样本。
多实例学习(MIL) 深度 MIL 自定义训练、PyTorch 生态实现 bag-level 标签训练 instance-level 表示,常见于医学影像、文档包和弱标注检测。
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.semi_supervised import SelfTrainingClassifier
from sklearn.metrics import classification_report
 
# y_train 中 -1 表示未标注样本,这是 sklearn 半监督接口的约定。
y_semi = y_train.copy()
y_semi[unlabeled_mask] = -1
 
base = LogisticRegression(max_iter=1000)
 
# threshold 控制伪标签进入训练集的置信度门槛。
# 门槛过低会把错误伪标签放大,门槛过高则利用不了未标注数据。
model = SelfTrainingClassifier(
    estimator=base,
    threshold=0.95,
    max_iter=10,
)
 
model.fit(X_train, y_semi)
pred = model.predict(X_valid)
print(classification_report(y_valid, pred))

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from sklearn.semi_supervised import LabelSpreading
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
 
label_pipe = Pipeline(
    steps=[
        # 图半监督依赖样本距离,标准化能避免大尺度特征主导相似度图。
        ("scale", StandardScaler()),
        # kernel="rbf" 根据样本间距离构图;gamma 控制近邻影响范围。
        # alpha 越大,模型越相信传播后的软标签;越小,越贴近初始标签。
        ("label", LabelSpreading(kernel="rbf", gamma=0.2, alpha=0.2, max_iter=30)),
    ]
)
 
# y_semi 同样使用 -1 标记未标注样本。
label_pipe.fit(X_train, y_semi)
pred = label_pipe.predict(X_valid)

弱监督和主动学习要单独记录标签来源。每个样本应能追溯是人工标签、规则标签、伪标签、模型标签还是冲突融合标签。没有标签血缘,后续发现模型异常时很难判断问题来自数据、规则还是训练。

类别不均衡:imbalanced-learn

欺诈、风控、故障告警、罕见病识别等任务经常面临少数类极少的问题。重采样必须放在交叉验证或训练 pipeline 内部,只对训练 fold 生效;如果先对全量数据做 SMOTE 再切分,验证集会被合成样本污染。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from imblearn.over_sampling import SMOTE
from imblearn.pipeline import Pipeline as ImbPipeline
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.preprocessing import StandardScaler
 
imb_pipe = ImbPipeline(
    steps=[
        # 先标准化再做 SMOTE,保证近邻搜索不被大尺度特征支配。
        ("scale", StandardScaler()),
        # SMOTE 只在 fit 阶段生成少数类合成样本,验证 fold 不会被重采样。
        ("smote", SMOTE(k_neighbors=5, random_state=42)),
        # LogisticRegression 读取重采样后的训练 fold,验证 fold 保持原始分布。
        ("clf", LogisticRegression(max_iter=1000)),
    ]
)
 
# StratifiedKFold 保持每折标签比例,适合极不均衡分类。
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
scores = cross_val_score(
    imb_pipe,
    X,
    y,
    cv=cv,
    # average_precision 对少数类召回和精度更敏感,适合欺诈/告警类任务。
    scoring="average_precision",
)
print("PR-AUC:", scores.mean(), scores.std())

类别极不均衡时,PR-AUC 通常比 ROC-AUC 更敏感。ROC-AUC 可能在负样本极多时显得很好,但实际正例召回和精度并不满足业务要求。

无监督学习:聚类、降维与异常检测

scikit-learn 的无监督工具常用于数据探索、样本分桶、异常点发现和可视化前处理。生产系统中要谨慎解释聚类标签:聚类编号没有天然语义,通常需要人工命名、业务校验或后续监督模型承接。

任务 常用类 工程用途
降维 PCA、 TruncatedSVD 压缩高维特征、可视化前处理、稀疏文本特征降维。
聚类 KMeans、 MiniBatchKMeans、 DBSCAN、 OPTICS、 Birch、 AgglomerativeClustering、 SpectralClustering 样本分桶、用户分群、语料探索、去重前的粗聚合。
异常检测 IsolationForest、 OneClassSVM、 LocalOutlierFactor 离群样本筛查、数据质量检查、告警候选召回。
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from sklearn.cluster import KMeans, MiniBatchKMeans, OPTICS
from sklearn.decomposition import PCA
from sklearn.ensemble import IsolationForest
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
 
cluster_pipe = Pipeline(
    steps=[
        ("scale", StandardScaler()),
        # PCA 先压缩噪声维度,减少 KMeans 被高维噪声牵引。
        ("pca", PCA(n_components=20, random_state=42)),
        # n_clusters 需要由轮廓系数、业务分群可解释性和人工验收共同决定。
        ("kmeans", KMeans(n_clusters=8, n_init="auto", random_state=42)),
    ]
)
cluster_id = cluster_pipe.fit_predict(X_numeric)
 
# MiniBatchKMeans 用小批量近似全量 KMeans,适合样本量很大的粗分群。
mini_kmeans = MiniBatchKMeans(n_clusters=50, batch_size=4096, random_state=42)
large_cluster_id = mini_kmeans.fit_predict(X_numeric)
 
# OPTICS 能处理不同密度的簇,并输出噪声点,适合先探索簇结构。
optics_id = OPTICS(min_samples=20, xi=0.05).fit_predict(X_numeric)
 
detector = Pipeline(
    steps=[
        ("scale", StandardScaler()),
        # contamination 是预期异常比例,应该来自业务验收或历史告警比例。
        ("iso", IsolationForest(contamination=0.02, random_state=42)),
    ]
)
is_normal = detector.fit_predict(X_numeric)
# IsolationForest 约定 -1 表示异常候选,1 表示正常样本。
anomaly_mask = is_normal == -1
无监督算法扩展:GMM、HDBSCAN、t-SNE、UMAP 与图社区

ref-2 的无监督算法版图比 KMeans/DBSCAN/PCA 更宽。工程上可以按任务拆:GMM/KDE 用于密度估计,HDBSCAN 用于变密度聚类,t-SNE/UMAP 用于可视化与局部结构展示,Leiden/Louvain 用于图社区发现。

算法 / 工具 Python 入口 工程用途
GMM sklearn.mixture.GaussianMixture 软聚类、密度估计、异常候选分数。
KDE sklearn.neighbors.KernelDensity 非参数密度估计,小中规模异常分析。
MiniBatchKMeans sklearn.cluster.MiniBatchKMeans 大样本 KMeans 近似训练,适合先做粗聚类。
OPTICS sklearn.cluster.OPTICS 变密度聚类探索,不需要提前指定全局半径。
Birch sklearn.cluster.Birch 层次化增量聚类,适合较大样本的预聚合。
SpectralClustering sklearn.cluster.SpectralClustering 基于相似图的非凸簇发现,样本量大时成本较高。
HDBSCAN hdbscan.HDBSCAN 密度不均、簇数未知、含噪声样本的聚类任务。
t-SNE sklearn.manifold.TSNE 局部结构可视化,不适合作为稳定生产特征。
UMAP umap.UMAP 可视化、近邻结构探索、部分场景下的低维表示。
Leiden / Louvain igraph、 leidenalg、 networkx 基于 kNN 图或关系图的社区发现。
Shell
1
2
# scikit-learn 覆盖 GMM、KDE、TSNE;HDBSCAN/UMAP/Leiden 需要额外库。
pip install -U scikit-learn hdbscan umap-learn igraph leidenalg networkx

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from sklearn.mixture import GaussianMixture
from sklearn.neighbors import KernelDensity
from sklearn.manifold import TSNE
 
# GMM 给出每个样本属于各高斯成分的概率,适合软聚类和密度分析。
gmm = GaussianMixture(n_components=8, covariance_type="full", random_state=42)
# fit_predict 返回最可能的成分编号;predict_proba 可得到软归属概率。
cluster_prob = gmm.fit_predict(X_numeric)
# score_samples 是每个样本的对数密度,值越低越像异常候选。
log_density = gmm.score_samples(X_numeric)
 
# KDE 估计连续密度;带宽 bandwidth 是最关键超参数。
kde = KernelDensity(kernel="gaussian", bandwidth=0.5)
kde.fit(X_numeric)
anomaly_score = -kde.score_samples(X_numeric)
 
# t-SNE 主要用于二维可视化,结果对随机种子和超参数敏感。
# perplexity 近似控制局部邻域大小,样本量变化时需要重新调。
vis2d = TSNE(n_components=2, perplexity=30, random_state=42).fit_transform(X_numeric)

Python
1
2
3
4
5
6
7
8
9
10
11
import hdbscan
import umap
 
# UMAP 先把高维 embedding 压到较低维,便于聚类和可视化探索。
# n_neighbors 控制局部结构范围,min_dist 控制可视化紧凑程度。
reducer = umap.UMAP(n_neighbors=30, min_dist=0.05, random_state=42)
X_umap = reducer.fit_transform(embeddings)
 
# HDBSCAN 不要求预先指定簇数,并能把低密度点标成 -1 噪声。
clusterer = hdbscan.HDBSCAN(min_cluster_size=30, metric="euclidean")
cluster_id = clusterer.fit_predict(X_umap)
主题模型:Latent Dirichlet Allocation

主题模型里的 LDA 指隐含狄利克雷分布(Latent Dirichlet Allocation),和前文线性判别分析(Linear Discriminant Analysis)的缩写相同。工程上它通常接在词袋或词频矩阵之后,用来发现语料中的潜在主题,并输出“文档-主题分布”和“主题-词分布”。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from sklearn.decomposition import LatentDirichletAllocation
from sklearn.feature_extraction.text import CountVectorizer
 
# LDA 主题模型需要非负词频计数;这里使用 CountVectorizer,不使用 TF-IDF。
vectorizer = CountVectorizer(
    max_features=50000,
    min_df=5,
    max_df=0.8,
    stop_words="english",
)
X_counts = vectorizer.fit_transform(documents)
 
lda_topic = LatentDirichletAllocation(
    # n_components 是主题数量,需要结合困惑度、主题可解释性和人工验收选择。
    n_components=20,
    learning_method="online",
    batch_size=1024,
    random_state=42,
)
# doc_topic 的每一行是一个文档在各主题上的概率分布。
doc_topic = lda_topic.fit_transform(X_counts)
 
# components_ 的每一行对应一个主题,每一列对应词表中的一个词。
terms = vectorizer.get_feature_names_out()
for topic_id, weights in enumerate(lda_topic.components_[:3]):
    top_ids = weights.argsort()[-10:][::-1]
    top_terms = [terms[i] for i in top_ids]
    print(topic_id, top_terms)
序列标注与概率图模型工程库

HMM、CRF、MEMM、结构化感知机和概率图模型在 ref-2 中属于经典序列建模与结构化预测。ref-6 的深度学习章节已经有 torchcrf,这里补工程库边界:HMM 可用 hmmlearn 做经典序列隐状态建模;CRF 可用 sklearn-crfsuite 做特征模板路线,也可用 torchcrf 接深度 encoder;贝叶斯网络和因子图可看 pgmpy、 pomegranate、Pyro/NumPyro。

模型 / 方法 工程入口 使用边界
HMM hmmlearn、 pomegranate 隐状态序列、简单语音/行为状态、金融 regime;特征表达能力弱于深度模型。
CRF sklearn-crfsuite、 torchcrf 序列标签全局解码,适合 NER、分词、词性标注等标签转移约束明显的任务。
MEMM / 结构化感知机 seqlearn、pystruct、教学代码、自定义结构化预测 工业新项目较少直接选用,更多作为理解 CRF、Viterbi 和结构化学习的过渡。
贝叶斯网络 / 因子图 pgmpy、 pomegranate 结构化不确定性、可解释概率依赖、小中规模推断。
可编程概率模型 Pyro、NumPyro 变分推断、MCMC、贝叶斯深度学习和复杂潜变量模型。
Shell
1
2
# 经典序列与概率图模型常用扩展库。
pip install -U hmmlearn sklearn-crfsuite pgmpy pomegranate pyro-ppl numpyro

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import numpy as np
from hmmlearn.hmm import GaussianHMM
 
# X 是连续观测序列,例如行为特征、声学特征或传感器特征。
# lengths 告诉 HMM 多条序列在拼接矩阵中的边界。
X = np.random.randn(300, 4).astype(np.float64)
lengths = [100, 120, 80]
 
# n_components 是隐状态数量,需要结合业务解释和验证集选择。
model = GaussianHMM(n_components=3, covariance_type="diag", random_state=42)
model.fit(X, lengths=lengths)
 
# predict 使用 Viterbi 找最可能的隐状态路径。
hidden_state = model.predict(X, lengths=lengths)
统计建模:statsmodels

statsmodels 更适合需要解释系数、置信区间、p 值和统计报告的场景。它常用于分析报告、经济/金融建模、A/B 分析和可解释性要求高的线性模型。生产预测时,scikit-learn pipeline 的工程封装通常更方便;统计解释时,statsmodels 的 summary 更直接。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import statsmodels.formula.api as smf
 
# statsmodels 的公式接口要求把标签和解释变量放在同一个 DataFrame 中。
train_df = df[["label", "age", "income", "country", "device"]].dropna()
 
model = smf.logit(
    # C(country) 和 C(device) 表示把类别变量展开为虚拟变量。
    formula="label ~ age + income + C(country) + C(device)",
    data=train_df,
)
# fit 做极大似然估计,返回带统计检验信息的结果对象。
result = model.fit()
 
# summary 给出系数、标准误、z 值和置信区间,适合做统计分析报告。
print(result.summary())
 
# predict 输出正类概率,可继续接阈值策略或校准分析。
train_df["score"] = result.predict(train_df)
自动调参:Optuna

Optuna 适合在验证集可靠、搜索预算明确时做自动调参。调参目标应直接对应业务指标,例如 AUC、PR-AUC、F1、RMSE 或 NDCG。搜索空间不宜过宽;先用人工经验给出合理边界,再让 Optuna 在边界内搜索。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import optuna
from lightgbm import LGBMClassifier, early_stopping
from sklearn.metrics import roc_auc_score
 
def objective(trial):
    params = {
        # n_estimators 给足上限,实际轮数由 early stopping 选择。
        "n_estimators": 3000,
        # learning_rate 跨数量级搜索,用 log=True 更符合调参经验。
        "learning_rate": trial.suggest_float("learning_rate", 0.01, 0.1, log=True),
        # num_leaves 决定单棵树可表达的叶子数量,是 LightGBM 的核心容量参数。
        "num_leaves": trial.suggest_int("num_leaves", 31, 255),
        # min_child_samples 越大,每个叶子需要更多样本,模型越保守。
        "min_child_samples": trial.suggest_int("min_child_samples", 10, 200),
        # 行采样和列采样用于控制方差,也能降低训练时间。
        "subsample": trial.suggest_float("subsample", 0.6, 1.0),
        "colsample_bytree": trial.suggest_float("colsample_bytree", 0.6, 1.0),
        "random_state": 42,
    }
    model = LGBMClassifier(**params)
    model.fit(
        X_train,
        y_train,
        eval_set=[(X_valid, y_valid)],
        eval_metric="auc",
        # 每个 trial 内部也要早停,避免无效参数浪费完整训练轮数。
        callbacks=[early_stopping(100, verbose=False)],
    )
    proba = model.predict_proba(X_valid)[:, 1]
    # objective 的返回值必须和 direction 一致;这里越大越好。
    return roc_auc_score(y_valid, proba)
 
# direction="maximize" 表示 Optuna 会寻找 AUC 最大的参数组合。
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=50)
 
print(study.best_value)
print(study.best_params)
模型解释:SHAP

SHAP 常用于解释树模型预测:全局看哪些特征最重要,局部看某个样本为什么被打成高风险。解释代码应和训练时的特征列顺序严格一致,否则解释结果会错位。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import shap
 
# TreeExplainer 针对树模型做过优化,适合 XGBoost/LightGBM/CatBoost。
explainer = shap.TreeExplainer(lgbm)
# 解释样本不宜过大;通常抽取验证集子集做离线报告。
sample = X_valid.iloc[:1000]
# shap_values 的列顺序必须和训练矩阵特征顺序一致。
shap_values = explainer.shap_values(sample)
 
# summary_plot 适合离线分析报告;线上服务通常只保存数值结果和 top features。
shap.summary_plot(shap_values, sample, show=False)
 
# 单样本解释适合排查某个高风险预测由哪些特征推动。
row = X_valid.iloc[[0]]
row_values = explainer.shap_values(row)
print(row_values)

SHAP 值表示特征对模型输出的贡献分解。它解释的是模型行为,不自动等价于现实因果关系。若业务需要因果结论,还需要实验设计、因果图或反事实分析。

模型持久化与交付

经典 ML 模型上线时应保存完整 pipeline,最终 estimator 只是其中一部分。否则线上会丢失缺失值填补、类别编码、标准化、文本向量化等预处理状态。joblib 是最常见的 Python 内部持久化方式;跨团队或不可信来源加载时,需要额外考虑安全边界。

Python
1
2
3
4
5
6
7
8
9
import joblib
 
# 保存完整 pipeline,包含预处理器、特征映射和模型参数。
joblib.dump(pipe, "model_pipeline.joblib")
 
# 加载方必须来自可信制品仓库;pickle/joblib 反序列化不适合加载未知来源文件。
loaded = joblib.load("model_pipeline.joblib")
score = loaded.predict_proba(X_valid.iloc[:5])[:, 1]
print(score)

交付目录应同时保存模型文件、训练数据 schema、特征列顺序、标签映射、评估报告、训练配置和依赖版本。经典 ML 的线上事故很大一部分来自特征列顺序变化、类别编码变化或训练/推理缺失值处理不一致。

通用强化学习框架与 LLM RL 的边界

ref-2 中的 Q-Learning、SARSA、DQN、Policy Gradient、Actor-Critic、PPO 和多智能体 RL,属于通用强化学习算法版图。它们通常围绕环境接口、状态、动作、奖励和 episode 展开;后文的 OpenRLHF/verl 则面向语言模型后训练,把 token 生成当作动作,把 reward function 或 reward model 当作奖励来源。

框架 覆盖算法 / 能力 工程定位
Gymnasium 环境接口标准, reset / step / observation / action / reward。 定义和测试 RL 环境,常作为算法库的环境层。
Stable-Baselines3 DQN、PPO、A2C、SAC、TD3 等经典 baseline。 单机实验、教学、控制任务和中小规模 baseline。
CleanRL 单文件算法实现,覆盖 DQN、PPO、SAC 等。 阅读算法细节、复现实验、修改 loss 和训练循环。
Ray RLlib 分布式采样、分布式训练、多算法、多环境。 大规模通用 RL 和需要 Ray 调度的环境交互任务。
PettingZoo 多智能体环境 API。 多智能体博弈、协作、竞争和环境基准。
Shell
1
2
# 通用 RL 环境与 baseline。
pip install -U gymnasium stable-baselines3 cleanrl "ray[rllib]" pettingzoo

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import gymnasium as gym
from stable_baselines3 import PPO
 
# Gymnasium 环境提供 reset/step 接口。
# observation 是状态,action 是智能体动作,reward 是环境反馈。
env = gym.make("CartPole-v1")
 
# PPO 这里是通用 RL 算法,处理离散动作控制任务。
# LLM RL 中的 PPO 形式相似,但 action 变成 token,rollout 由语言模型生成。
model = PPO(
    policy="MlpPolicy",
    env=env,
    learning_rate=3e-4,
    n_steps=2048,
    batch_size=64,
    verbose=1,
)
 
# total_timesteps 表示环境交互步数,语义和 epoch 不同。
model.learn(total_timesteps=100_000)
model.save("ppo_cartpole")

通用 RL 框架适合控制、游戏、仿真、多智能体和环境交互任务。语言模型后训练通常选择 TRL、OpenRLHF、verl 或 OpenRLHF/vLLM/DeepSpeed 组合,因为它们内置 tokenizer、rollout engine、KL、reference model、reward function、长序列 batch 和 LLM checkpoint 语义。

语言模型训练框架

语言模型训练框架负责把“模型、数据、训练循环、分布式运行时、评估、保存与导出”组织成可复用的工程入口。本节只做框架地图:说明每类工具解决什么问题、适合什么场景、和后文专章的关系。涉及具体 API、完整脚本、参数语义和排障细节的内容,放到后面的 Transformers、PEFT、语言模型强化学习、DeepSpeed、vLLM 等专章中展开。

框架/工具 工程定位 选择依据
Transformers 统一模型加载、tokenizer/processor、Trainer、generate 与 Hub 交付格式。 通用 LLM/NLP 微调、推理、导出都应优先熟悉这条主线;细节见后文 Transformers 详解。
Accelerate 把自定义 PyTorch 脚本接到单卡、多卡、FSDP、DeepSpeed、混合精度等运行时。 已有训练循环,但需要减少分布式样板代码时使用;它常作为 Transformers、LLaMA-Factory、Axolotl 的底层启动与运行时胶水。
PEFT 管理 LoRA、QLoRA、IA3、Prompt Tuning 等 adapter 的注入、保存、加载、合并与多 adapter 生命周期。 大模型微调默认优先考虑;adapter 细节、merge 语义和量化边界见后文 PEFT 与微调技术详解。
TRL 提供 SFT、RewardTrainer、DPO、PPO、GRPO 等后训练 Trainer。 适合 Hugging Face 生态内做 SFT、偏好优化和中小规模 RL 算法验证;RL 训练细节见后文 语言模型强化学习。
OpenRLHF 用 Ray + DeepSpeed + vLLM 编排 actor、critic、reference、reward、rollout 多角色在线 RLHF。 当 rollout、reward、训练和推理需要拆成多个进程或多个 GPU 池时,比单脚本 Trainer 更适合。
verl 面向 LLM RL 后训练的多算法框架,强调统一资源池、HybridFlow 和 Actor-Rollout-Ref Worker。 需要在 PPO、GRPO、RLOO、ReMax、REINFORCE++ 等算法间切换,并保持一套资源管理和 rollout 架构。
Unsloth 面向单卡/少卡高效微调的工作台,常用于快速 LoRA/QLoRA 实验和本地部署验证。 适合低显存、快速试验、导出到本地推理格式;不适合替代大规模分布式训练系统。
LLaMA-Factory 用 YAML、CLI 和 WebUI 组织 SFT、DPO、RM、PPO、导出等流程。 团队希望把常见 LLM 微调流程配置化,减少手写脚本维护成本时使用。
Axolotl 现代 LLM 微调配方库,覆盖 QLoRA、FSDP、DPO、GRPO、sample packing 与 vLLM 协作。 适合复杂 YAML 配方、批量实验、多种后训练路线组合,以及希望显式控制训练语义的团队。
ModelScope 中文模型/数据 Hub、SDK、pipeline 与 Trainer 工作台。 中文生态、魔搭模型、平台化模型获取与任务 pipeline 更重要时使用;底层训练后端取决于具体模型和脚本。
GLiNER / sentence-transformers / SetFit 面向特定任务的高层框架,分别覆盖 span-based NER、embedding/reranker 训练、少样本文本分类等场景。 任务结构明确、希望少写底层训练循环时使用;它们通常仍依赖 PyTorch/Transformers 作为底座。
Lightning / MMEngine 训练流程组织框架,强调 Runner/Trainer、Hook、Callback、Logger、Checkpoint 与配置系统。 团队需要统一训练规范、可复用实验工程、CV 项目管理或复杂 hook 生命周期时使用。
Hugging Face 训练主线

这一章里的框架分处不同抽象层:底座生态、后训练系统、配方工作台、任务封装和训练流程控制层。选型时先判断自己缺的是哪一层能力,再决定引入哪个库。

抽象层 负责什么 典型库
底座生态 模型、tokenizer、dataset、metric、训练循环、adapter 和后训练 Trainer 的基础 API。 Transformers、Datasets、Evaluate、Accelerate、PEFT、TRL
RL 后训练系统 actor、critic、reference、reward、rollout、Ray、vLLM、DeepSpeed/FSDP 的多角色编排。 verl、OpenRLHF
配方工作台 把模型、数据、模板、LoRA、量化、训练参数和导出路径写进 YAML/CLI/WebUI。 Unsloth、LLaMA-Factory、Axolotl、ModelScope
任务封装 围绕 NER、embedding、reranker、少样本分类等任务封装数据格式、模型头、loss 和评估。 GLiNER、sentence-transformers、SetFit
流程控制 统一训练生命周期、配置、hook、callback、logger、checkpoint、runner 和分布式策略。 Lightning、Fabric、MMEngine、OpenMMLab

Hugging Face 生态的常见组合是 Transformers + Datasets + Evaluate + Accelerate + PEFT + TRL。Transformers 管模型与 tokenizer,Datasets 管数据表与流式读取,Evaluate 管指标,Accelerate 管运行时,PEFT 管 adapter,TRL 管 SFT、偏好优化和 RL 后训练。真实项目里很少只使用其中一个库,更多是按任务阶段组合使用。

这条主线的设计重点是统一接口: from_pretrained 负责装载, save_pretrained 负责交付, Trainer 或专用 Trainer 负责训练循环, generate 负责生成式推理。这里给出最低可用入口,后文专章再展开完整 API、工程参数和排障。

Transformers:安装与 QuickStart
Shell
1
pip install -U transformers torch

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from transformers import AutoModelForCausalLM, AutoTokenizer
 
model_id = "Qwen/Qwen2.5-0.5B-Instruct"
 
# tokenizer 决定文本如何切成 token,必须和模型权重来自同一套仓库。
tokenizer = AutoTokenizer.from_pretrained(model_id)
 
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    # 让 Transformers 按模型配置和硬件能力选择加载精度。
    torch_dtype="auto",
    # QuickStart 阶段交给库自动放置设备;正式训练再切到显式分布式配置。
    device_map="auto",
)
 
messages = [{"role": "user", "content": "用一句话解释 LoRA。"}]
 
# chat template 把结构化 messages 渲染成模型训练时熟悉的对话文本。
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer([text], return_tensors="pt").to(model.device)
 
outputs = model.generate(
    **inputs,
    # 用 max_new_tokens 约束回复长度,避免 max_length 把 prompt 长度也算进去。
    max_new_tokens=64,
)
 
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
Datasets 与 Evaluate:安装与 QuickStart
Shell
1
pip install -U datasets evaluate

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from datasets import load_dataset
import evaluate
 
# Datasets 负责下载、缓存、切分和列式访问数据集。
dataset = load_dataset("glue", "sst2")
 
# Evaluate 把指标对象和计算逻辑独立出来,方便训练脚本复用。
accuracy = evaluate.load("accuracy")
 
predictions = [1, 0, 1]
references = [1, 1, 1]
 
# 指标输入要和任务语义对齐;分类任务通常传类别 id。
result = accuracy.compute(predictions=predictions, references=references)
print(result)
Accelerate:安装与 QuickStart
Shell
1
2
3
pip install -U accelerate
accelerate config
accelerate test

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from accelerate import Accelerator
 
accelerator = Accelerator()
 
# prepare 会把模型、优化器和 dataloader 包进当前运行时。
model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)
 
for batch in train_loader:
    outputs = model(**batch)
    loss = outputs.loss
 
    # 使用 accelerator.backward,才能兼容混合精度、梯度累积和分布式后端。
    accelerator.backward(loss)
 
    optimizer.step()
    optimizer.zero_grad()

Shell
1
accelerate launch train.py
PEFT:安装与 QuickStart
Shell
1
pip install -U peft

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from peft import LoraConfig, TaskType, get_peft_model
 
peft_config = LoraConfig(
    # Causal LM 表示 decoder-only 自回归语言模型。
    task_type=TaskType.CAUSAL_LM,
    # rank 控制 LoRA 分支容量。
    r=16,
    # alpha 控制 LoRA 更新量缩放。
    lora_alpha=32,
    # dropout 只作用在 LoRA 分支,用来缓和小数据过拟合。
    lora_dropout=0.05,
    # 注入点必须匹配模型源码里的线性层名字。
    target_modules=["q_proj", "v_proj"],
)
 
# base_model 通常来自 Transformers.from_pretrained。
model = get_peft_model(base_model, peft_config)
 
# 先打印可训练参数,确认训练对象被限制在 adapter 上。
model.print_trainable_parameters()
TRL:安装与 QuickStart
Shell
1
pip install -U trl

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer
 
dataset = load_dataset("trl-lib/Capybara", split="train")
 
config = SFTConfig(
    output_dir="out_sft",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=2e-5,
    num_train_epochs=1,
)
 
trainer = SFTTrainer(
    # 可以传模型 id,也可以传已经加载好的 Transformers/PEFT 模型对象。
    model="Qwen/Qwen2.5-0.5B-Instruct",
    args=config,
    train_dataset=dataset,
)
 
trainer.train()
强化学习后训练框架

verl 和 OpenRLHF 负责完整 RL 后训练系统。它们处理 actor、critic、reference、reward model、rollout engine、资源池、权重同步和 checkpoint 链路。TRL 适合先验证算法和 reward;verl / OpenRLHF 更适合把 RLHF 扩展到多 GPU、多进程和多角色系统。

verl:安装与 QuickStart
Shell
1
2
3
git clone https://github.com/volcengine/verl.git
cd verl
pip install -e .

Shell
1
2
3
4
5
6
7
8
9
10
11
12
13
python3 -m verl.trainer.main_ppo \
  algorithm.adv_estimator=grpo \
  data.train_files=./data/train.parquet \
  data.val_files=./data/val.parquet \
  data.prompt_key=prompt \
  actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \
  actor_rollout_ref.actor.strategy=fsdp \
  actor_rollout_ref.rollout.name=vllm \
  actor_rollout_ref.rollout.n=4 \
  trainer.n_gpus_per_node=1 \
  trainer.nnodes=1 \
  trainer.project_name=verl_quickstart \
  trainer.experiment_name=qwen_grpo

这条命令体现了 verl 的基本结构:训练数据给 prompt,actor 用 FSDP 更新策略,rollout 用 vLLM 生成候选,GRPO 用组内奖励估计 advantage。正式项目还需要补 reward function、日志、保存策略和多节点资源配置。

OpenRLHF:安装与 QuickStart
Shell
1
pip install -U openrlhf

Shell
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
ray start --head --num-gpus 1
 
openrlhf.cli.train_ppo_ray \
  --actor.model_name_or_path ./sft_model \
  --reward.model_name_or_path ./reward_model \
  --data.prompt_dataset ./prompts.jsonl \
  --actor_learning_rate 1e-6 \
  --critic_learning_rate 5e-6 \
  --micro_train_batch_size 1 \
  --train_batch_size 16 \
  --micro_rollout_batch_size 1 \
  --rollout_batch_size 32 \
  --zero_stage 2 \
  --bf16 \
  --save_path ./openrlhf_ppo_actor

OpenRLHF 的 QuickStart 先看角色边界:actor 负责被更新的策略,critic 负责价值估计,reward model 负责打分,Ray 负责调度角色,DeepSpeed 负责训练态显存和优化器状态管理。多卡生产训练通常再接 vLLM rollout engine。

大模型微调工作台

Unsloth、LLaMA-Factory 和 Axolotl 的共同目标是把大模型微调从“手写 Python 脚本”推进到“配置化训练配方”。三者的差异在重心:Unsloth 偏单卡效率和快速导出,LLaMA-Factory 偏一站式工作台和易用配置,Axolotl 偏复杂训练配方和现代后训练组合。

工具 适合场景 边界
Unsloth 单卡/少卡 LoRA、QLoRA、快速 SFT、导出到本地推理制品。 关注效率和落地速度;复杂多角色 RLHF 与大规模分布式训练通常需要别的系统承接。
LLaMA-Factory 用 YAML/CLI/WebUI 管理 SFT、RM、DPO、PPO 和导出流程。 强项是把常见路径产品化;极端自定义训练循环仍会回到 Transformers、Accelerate、DeepSpeed 或自写脚本。
Axolotl QLoRA、FSDP、sample packing、DPO、GRPO、vLLM 协作等复杂配方。 配置能力强,但需要读懂每个字段对应的训练语义,不能把 YAML 当作黑盒。
Unsloth:安装与 QuickStart
Shell
1
pip install -U unsloth

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from unsloth import FastLanguageModel
 
model, tokenizer = FastLanguageModel.from_pretrained(
    "unsloth/Meta-Llama-3.1-8B",
    # QuickStart 先给出明确上下文预算,方便估算显存。
    max_seq_length=2048,
    # 4bit 加载用于降低底座显存,训练时只更新 LoRA。
    load_in_4bit=True,
)
 
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
LLaMA-Factory:安装与 QuickStart
Shell
1
2
3
git clone https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory
pip install -e ".[torch,metrics]"

llamafactory_quickstart.yaml
YAML
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# sft 表示监督微调阶段。
stage: sft
 
# 使用 LoRA adapter,避免全参更新。
finetuning_type: lora
 
# 底座模型路径或 Hub id。
model_name_or_path: Qwen/Qwen2.5-0.5B-Instruct
 
# 数据集名称由 LLaMA-Factory 的 dataset_info 管理。
dataset: identity
 
# 模板必须匹配目标模型的对话格式。
template: qwen
 
# 输出目录保存 adapter、日志和配置快照。
output_dir: saves/qwen_quickstart/lora/sft
 
# QuickStart 先跑小 batch,验证链路。
per_device_train_batch_size: 1
gradient_accumulation_steps: 4
learning_rate: 2.0e-4
num_train_epochs: 1.0

Shell
1
llamafactory-cli train llamafactory_quickstart.yaml
Axolotl:安装与 QuickStart
Shell
1
pip install -U axolotl

axolotl_quickstart.yml
YAML
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# 底座模型身份。
base_model: Qwen/Qwen2.5-0.5B-Instruct
 
# tokenizer 类型显式写出,避免模型仓库默认配置不完整。
tokenizer_type: AutoTokenizer
 
# LoRA adapter 训练。
adapter: lora
 
datasets:
  # QuickStart 使用 alpaca 风格数据。
  - path: tatsu-lab/alpaca
    type: alpaca
 
# 模板决定训练文本如何拼接。
chat_template: chatml
 
# 先用短上下文跑通流程。
sequence_len: 1024
micro_batch_size: 1
gradient_accumulation_steps: 4
learning_rate: 2.0e-4
output_dir: ./outputs/axolotl_quickstart

Shell
1
axolotl train axolotl_quickstart.yml
ModelScope:安装与 QuickStart
Shell
1
pip install -U modelscope

Python
1
2
3
4
5
6
7
8
9
10
from modelscope.pipelines import pipeline
 
word_segmentation = pipeline(
    "word-segmentation",
    # 显式指定模型,避免默认模型版本变化影响结果。
    model="damo/nlp_structbert_word-segmentation_chinese-base",
)
 
result = word_segmentation("语言模型训练框架需要区分底座生态和配方工作台。")
print(result)
任务特定训练框架

任务特定框架把某类任务的模型头、损失函数、数据格式和评估方式封装起来。GLiNER 适合开放标签 NER 与 span-based NER;sentence-transformers 适合 embedding、相似度、检索和 reranker;SetFit 适合少样本文本分类。这类框架的优势是快速形成可用 baseline,代价是底层训练循环的自由度较低。

GLiNER:安装与 QuickStart
Shell
1
pip install -U gliner

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from gliner import GLiNER
 
model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1")
 
text = "John works at Google and lives in Paris."
 
# GLiNER 的标签来自自然语言列表,适合快速扩展新实体类型。
labels = ["person", "organization", "location"]
 
entities = model.predict_entities(
    text,
    labels,
    # threshold 控制召回和精度平衡,正式任务需要按验证集调。
    threshold=0.5,
)
 
print(entities)
sentence-transformers:安装与 QuickStart
Shell
1
pip install -U sentence-transformers

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from sentence_transformers import SentenceTransformer
 
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
 
sentences = [
    "A dog is playing in the park.",
    "A puppy runs outside.",
]
 
# encode 输出可直接用于相似度、聚类、检索或向量数据库写入。
embeddings = model.encode(sentences, normalize_embeddings=True)
 
similarity = embeddings[0] @ embeddings[1]
print(float(similarity))
SetFit:安装与 QuickStart
Shell
1
pip install -U setfit

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from datasets import load_dataset
from setfit import SetFitModel, Trainer, TrainingArguments
 
dataset = load_dataset("SetFit/sst2")
 
model = SetFitModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
 
args = TrainingArguments(
    batch_size=16,
    num_epochs=1,
)
 
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset["train"].select(range(64)),
    eval_dataset=dataset["validation"].select(range(64)),
)
 
trainer.train()
metrics = trainer.evaluate()
print(metrics)
训练流程组织框架

Lightning、Fabric、MMEngine、OpenMMLab 这类框架关注训练工程的组织方式。它们处理 callback、hook、logger、checkpoint、runner、配置和分布式策略,让团队把研究代码、训练规范和实验产物放进统一结构。对于长期维护的 CV 项目、多模型实验平台或跨团队训练规范,这类框架比临时脚本更稳。

Lightning 与 Fabric:安装与 QuickStart
Shell
1
pip install -U lightning

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import lightning as L
import torch
import torch.nn.functional as F
from torch import nn
 
class LitClassifier(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(10, 2)
 
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.net(x)
        loss = F.cross_entropy(logits, y)
        self.log("train_loss", loss)
        return loss
 
    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=3e-4)
 
trainer = L.Trainer(
    max_epochs=1,
    accelerator="auto",
    devices="auto",
)
 
trainer.fit(LitClassifier(), train_dataloaders=train_loader)

Python
1
2
3
4
5
6
7
8
from lightning import Fabric
 
fabric = Fabric(accelerator="auto", devices="auto")
fabric.launch()
 
# Fabric 保留自定义训练循环,只接管设备、精度和分布式运行时。
model, optimizer = fabric.setup(model, optimizer)
train_loader = fabric.setup_dataloaders(train_loader)
MMEngine:安装与 QuickStart
Shell
1
2
pip install -U openmim
mim install mmengine

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel
 
class MMResNet50(BaseModel):
    def __init__(self):
        super().__init__()
        self.backbone = torchvision.models.resnet50(num_classes=2)
 
    def forward(self, imgs, labels=None, mode="loss"):
        logits = self.backbone(imgs)
 
        if mode == "loss":
            # MMEngine 的 Runner 期望 loss 模式返回 dict,便于统一日志和反传。
            return {"loss": F.cross_entropy(logits, labels)}
 
        # predict 模式把结果交给 evaluator 或可视化流程。
        return logits
OpenMMLab:安装与 QuickStart
Shell
1
2
3
pip install -U openmim
mim install mmengine
mim install mmdet

Shell
1
2
3
4
5
6
# 下载配置和 checkpoint 后,可以先用 demo 验证推理链路。
python demo/image_demo.py \
  demo/demo.jpg \
  configs/faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py \
  checkpoints/faster_rcnn_r50_fpn_1x_coco.pth \
  --out-file result.jpg

OpenMMLab 的 QuickStart 通常围绕具体任务仓库展开,例如 MMDetection、MMSegmentation、MMPose。MMEngine 是底层训练引擎,OpenMMLab 各任务库在它之上提供模型、数据集、配置和评估流程。

训练场景代码模板

第 4 篇从理论和策略角度讨论了任务特定语言模型、embedding 训练、继续预训练、分类微调、序列标注、条件生成、生成模型 SFT、拒绝采样、DPO 和蒸馏等训练场景。本节把这些场景落到工程代码:同一个训练目标,先判断应该使用表示模型、embedding 模型还是生成模型,再选择 Transformers、sentence-transformers、SetFit、PEFT/TRL 等框架入口。

训练场景到框架入口
训练场景 主要框架 核心代码入口 产物形态
小数据集分类微调 Transformers + PEFT AutoModelForSequenceClassification、 Trainer、 LoraConfig 分类模型目录或 LoRA adapter
表示模型继续预训练 Transformers AutoModelForMaskedLM、 DataCollatorForLanguageModeling 领域化 encoder checkpoint
embedding 训练与检索微调 sentence-transformers SentenceTransformerTrainer、 MultipleNegativesRankingLoss embedding 模型目录
基于表示模型的重排训练 sentence-transformers CrossEncoder、query-document-label、hard negatives cross-encoder reranker
少样本文本分类 SetFit SetFitModel、 Trainer sentence-transformer body + 分类头
生成模型 SFT / QLoRA Transformers + PEFT + TRL SFTTrainer、 BitsAndBytesConfig、 LoraConfig LoRA adapter 或合并后的 Causal LM
拒绝采样回写 SFT vLLM / Transformers + TRL 多候选生成、规则/模型评分、JSONL 回写、 SFTTrainer 筛选后的 SFT 数据 + 新 adapter
DPO 偏好调优 TRL + PEFT DPOTrainer、chosen/rejected 数据 DPO adapter 或合并模型
Token 级序列标注 / NER Transformers AutoModelForTokenClassification、word_ids 对齐、 DataCollatorForTokenClassification token classification 模型目录
T5 / BART 条件生成 Transformers AutoModelForSeq2SeqLM、 DataCollatorForSeq2Seq、 Seq2SeqTrainer 摘要、翻译、改写、text-to-text 分类模型
冻结表示模型 + 轻量分类器 Transformers + scikit-learn encoder 特征抽取、 LogisticRegression 特征化分类 pipeline
嵌入零样本分类 sentence-transformers 标签描述向量、文本向量、余弦相似度 无需训练的标签匹配器
LLM 教师蒸馏到 Encoder-only 学生 LLM API / 本地生成 + Transformers 弱标注、置信度过滤、学生分类器微调 低延迟线上分类器
DoRA / Q-DoRA 高容量微调 PEFT + TRL LoraConfig(use_dora=True)、QLoRA 量化底座 更高容量 adapter
小数据集分类微调:Transformers + LoRA

小数据集分类任务应先限制可训练容量,再观察验证集和尾部样本。BERT、RoBERTa、DeBERTa 这类 encoder 模型适合闭集分类;LoRA 可把更新集中到注意力投影层,降低过拟合和显存压力。

Shell
1
2
3
4
5
6
# transformers 提供模型、tokenizer 和 Trainer。
# datasets 负责读取 CSV 并生成 DatasetDict。
# evaluate 提供 F1 等指标实现。
# peft 提供 LoRA adapter 注入。
# accelerate 是 Trainer 运行单卡/多卡/混合精度的底层运行时。
pip install -U transformers datasets evaluate peft accelerate

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# evaluate 用来加载可复用指标,避免自己手写 F1 细节。
import evaluate
# numpy 用来把 logits 转成类别 id。
import numpy as np
# load_dataset 读取本地 CSV 并返回 train/validation 两个 split。
from datasets import load_dataset
# LoraConfig 定义 adapter;TaskType 告诉 PEFT 当前是序列分类任务。
# get_peft_model 把 LoRA adapter 挂到已有 Transformers 模型上。
from peft import LoraConfig, TaskType, get_peft_model
# AutoModelForSequenceClassification 会自动加载带分类头的 encoder。
# AutoTokenizer 保证分词规则和底座模型一致。
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# DataCollatorWithPadding 在组 batch 时动态 padding。
# Trainer / TrainingArguments 负责训练循环和训练参数。
from transformers import DataCollatorWithPadding, Trainer, TrainingArguments
 
# 底座模型 id;这里选择 DeBERTa encoder,适合闭集文本分类。
model_id = "microsoft/deberta-v3-base"
# data_files 显式声明训练集和验证集文件,DatasetDict 会生成对应 split。
dataset = load_dataset("csv", data_files={"train": "train.csv", "validation": "valid.csv"})
# 从训练集收集标签名,并排序保证 label id 可复现。
label_names = sorted(set(dataset["train"]["label"]))
# label2id 把业务标签映射成模型 loss 需要的整数类别。
label2id = {name: i for i, name in enumerate(label_names)}
# id2label 写回模型配置,推理输出时能还原成人类可读标签。
id2label = {i: name for name, i in label2id.items()}
 
# tokenizer 必须和 model_id 一致,否则 token id 会和 embedding 表错位。
tokenizer = AutoTokenizer.from_pretrained(model_id)
 
def preprocess(batch):
    # 文本列名要和业务数据保持一致;truncation 防止异常长文本撑爆 batch。
    encoded = tokenizer(batch["text"], truncation=True, max_length=256)
    # Trainer 约定监督标签字段叫 labels。
    encoded["labels"] = [label2id[x] for x in batch["label"]]
    return encoded
 
# batched=True 让 tokenizer 批量处理样本,吞吐更高。
# remove_columns 删除原始 text/label 列,只保留模型 forward 能消费的字段。
tokenized = dataset.map(preprocess, batched=True, remove_columns=dataset["train"].column_names)
model = AutoModelForSequenceClassification.from_pretrained(
    # 从同一个底座加载 encoder 和分类头初始化配置。
    model_id,
    # 分类头输出维度必须等于标签数量。
    num_labels=len(label_names),
    # 把 id -> 标签名写入 config,便于推理和保存后复用。
    id2label=id2label,
    # 把标签名 -> id 写入 config,便于 pipeline 或下游工具识别。
    label2id=label2id,
)
 
lora = LoraConfig(
    # SEQ_CLS 告诉 PEFT 当前任务是序列级分类。
    task_type=TaskType.SEQ_CLS,
    # rank 越小,可训练容量越低;小数据分类先用较小 rank 控制过拟合。
    r=8,
    # alpha 控制 LoRA 更新幅度,通常和 rank 配套调整。
    lora_alpha=16,
    # dropout 只作用在 LoRA 分支,降低小数据上过度记忆的风险。
    lora_dropout=0.05,
    # DeBERTa 注意力投影层常见命名;真实项目应先打印模块名确认。
    target_modules=["query_proj", "value_proj"],
)
# 包装后模型主体冻结,训练主要更新 LoRA 参数和必要的分类头参数。
model = get_peft_model(model, lora)
 
# 加载 macro F1 指标;长尾分类比 accuracy 更能暴露少数类退化。
metric = evaluate.load("f1")
 
def compute_metrics(eval_pred):
    # Trainer 传入的是 numpy logits 和 labels。
    logits, labels = eval_pred
    # 分类任务取最大 logit 对应的类别 id。
    preds = np.argmax(logits, axis=-1)
    # macro F1 能暴露长尾类别退化,比单看 accuracy 更稳。
    return metric.compute(predictions=preds, references=labels, average="macro")
 
args = TrainingArguments(
    # checkpoint、日志、trainer_state.json 都写入这个目录。
    output_dir="./cls_lora",
    # LoRA 可训练参数少,学习率通常高于全参微调。
    learning_rate=2e-4,
    # 训练 batch 控制显存占用;小数据也不宜开得过大。
    per_device_train_batch_size=16,
    # 验证不反传,batch 可以比训练更大。
    per_device_eval_batch_size=32,
    # 小数据任务应结合验证集早停;这里给出上限 epoch。
    num_train_epochs=5,
    # 每个 epoch 结束跑一次验证集,避免只看训练 loss。
    eval_strategy="epoch",
    # 每个 epoch 保存一次,和评估节奏对齐。
    save_strategy="epoch",
    # 训练结束自动恢复验证指标最优的 checkpoint。
    load_best_model_at_end=True,
    # Trainer 会寻找 eval_f1;它来自 compute_metrics 返回的 f1。
    metric_for_best_model="eval_f1",
    # F1 越大越好,不能按 loss 的“越小越好”逻辑处理。
    greater_is_better=True,
    # fp16 降低显存并提高吞吐;老卡或数值不稳时可关闭。
    fp16=True,
)
 
trainer = Trainer(
    # 已经注入 LoRA 的分类模型。
    model=model,
    # 训练超参数和保存/评估策略。
    args=args,
    # tokenized train split,字段应包含 input_ids/attention_mask/labels。
    train_dataset=tokenized["train"],
    # validation split 用于选 best checkpoint。
    eval_dataset=tokenized["validation"],
    # 动态 padding,避免所有样本都补到 max_length。
    data_collator=DataCollatorWithPadding(tokenizer),
    # 把 logits 转成 F1 等业务指标。
    compute_metrics=compute_metrics,
)
# 启动训练;Trainer 会自动执行评估、保存和 best model 恢复。
trainer.train()
# 保存最终可加载模型目录;如果是 PEFT 模型,产物主要是 adapter。
trainer.save_model("./cls_lora/best")

DeBERTa 的注意力层命名常见为 query_proj / value_proj;BERT/RoBERTa 常见为 query / value。实际项目应先打印模块名,再设置 target_modules。

Token 级序列标注:NER / Slot Filling

Token 级序列标注适合 NER、槽位填充、关键词边界识别等任务。它和句子分类的差异在输出形态:句子分类是一段文本一个标签,token classification 是每个 token 一个标签。工程难点集中在 word-level 标注与 subword token 之间的对齐。

Shell
1
2
3
4
# transformers 提供 token classification 模型与 Trainer。
# datasets 读取 tokens/tags 形式的数据。
# seqeval 提供实体级 precision/recall/F1。
pip install -U transformers datasets evaluate seqeval accelerate

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# evaluate 加载 seqeval,按实体边界计算 NER 指标。
import evaluate
# load_dataset 读取 JSONL 中的 tokens 和 ner_tags 字段。
from datasets import load_dataset
# AutoModelForTokenClassification 加载每个 token 输出标签 logits 的模型。
# AutoTokenizer 提供 is_split_into_words 和 word_ids 对齐能力。
from transformers import AutoModelForTokenClassification, AutoTokenizer
# DataCollatorForTokenClassification 会动态 padding input_ids 和 labels。
# Trainer / TrainingArguments 负责训练循环、评估和保存。
from transformers import DataCollatorForTokenClassification, Trainer, TrainingArguments
 
# Encoder-only 模型适合高吞吐 NER。
model_id = "microsoft/deberta-v3-base"
# 数据每行应包含 tokens: List[str] 和 ner_tags: List[str]。
dataset = load_dataset("json", data_files={"train": "ner_train.jsonl", "validation": "ner_valid.jsonl"})
 
# 从训练集收集标签集合,并排序保证 id 映射可复现。
label_names = sorted({tag for row in dataset["train"]["ner_tags"] for tag in row})
# label2id 把 BIO/IOBES 字符串标签转成整数。
label2id = {name: i for i, name in enumerate(label_names)}
# id2label 写入模型配置,便于推理输出还原。
id2label = {i: name for name, i in label2id.items()}
 
# tokenizer 必须和底座模型一致。
tokenizer = AutoTokenizer.from_pretrained(model_id)
 
def tokenize_and_align_labels(batch):
    # is_split_into_words=True 表示输入已经是词列表,不让 tokenizer 再按空格猜词边界。
    tokenized = tokenizer(
        batch["tokens"],
        is_split_into_words=True,
        truncation=True,
        max_length=256,
    )
 
    aligned_labels = []
    # 逐条样本对齐,因为每条样本的 word_ids 都不同。
    for sample_index, tags in enumerate(batch["ner_tags"]):
        # word_ids 把每个 subword token 映射回原始第几个 word。
        word_ids = tokenized.word_ids(batch_index=sample_index)
        previous_word_id = None
        label_ids = []
 
        for word_id in word_ids:
            if word_id is None:
                # 特殊 token 和 padding 不参与 loss,统一设为 -100。
                label_ids.append(-100)
            elif word_id != previous_word_id:
                # 一个 word 的首个 subword 继承原始标签。
                label_ids.append(label2id[tags[word_id]])
            else:
                # 非首个 subword 忽略,避免一个实体词被重复计算 loss。
                label_ids.append(-100)
            previous_word_id = word_id
 
        aligned_labels.append(label_ids)
 
    # Trainer 约定 token 级监督字段仍叫 labels。
    tokenized["labels"] = aligned_labels
    return tokenized
 
# batched=True 批量处理;remove_columns 删除原始 tokens/tags,避免 Trainer forward 收到无关字段。
tokenized = dataset.map(tokenize_and_align_labels, batched=True, remove_columns=dataset["train"].column_names)
 
model = AutoModelForTokenClassification.from_pretrained(
    model_id,
    # 输出维度等于 BIO/IOBES 标签数量。
    num_labels=len(label_names),
    # 写入 id -> label,便于 pipeline 和保存后推理。
    id2label=id2label,
    # 写入 label -> id,便于加载后保持标签语义。
    label2id=label2id,
)
 
# seqeval 按实体边界评估,比 token accuracy 更可靠。
seqeval = evaluate.load("seqeval")
 
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    # 对每个 token 取最高 logit 对应的标签。
    predictions = logits.argmax(axis=-1)
 
    true_predictions = []
    true_labels = []
    for pred_row, label_row in zip(predictions, labels):
        pred_tags = []
        gold_tags = []
        for pred_id, label_id in zip(pred_row, label_row):
            if label_id == -100:
                # -100 位置包括特殊 token、padding 和被忽略的 subword。
                continue
            pred_tags.append(id2label[int(pred_id)])
            gold_tags.append(id2label[int(label_id)])
        true_predictions.append(pred_tags)
        true_labels.append(gold_tags)
 
    scores = seqeval.compute(predictions=true_predictions, references=true_labels)
    return {
        # overall_f1 是实体级 F1,适合做 best checkpoint 指标。
        "f1": scores["overall_f1"],
        "precision": scores["overall_precision"],
        "recall": scores["overall_recall"],
    }
 
args = TrainingArguments(
    # 输出模型、日志和 checkpoint。
    output_dir="./ner_deberta",
    # token classification 通常可以从 2e-5 或 3e-5 起步。
    learning_rate=3e-5,
    # 训练 batch 控制显存;长文本 NER 应适当减小。
    per_device_train_batch_size=16,
    # 验证不反传,batch 可以更大。
    per_device_eval_batch_size=32,
    # NER 数据较小时先给出上限 epoch,再按验证 F1 选最优。
    num_train_epochs=5,
    # 每轮评估一次实体级指标。
    eval_strategy="epoch",
    # 保存节奏和评估节奏对齐。
    save_strategy="epoch",
    # 训练结束恢复实体级 F1 最优 checkpoint。
    load_best_model_at_end=True,
    # compute_metrics 返回 f1,Trainer 会映射成 eval_f1。
    metric_for_best_model="eval_f1",
    # F1 越大越好。
    greater_is_better=True,
)
 
trainer = Trainer(
    # token classification 模型。
    model=model,
    # 训练参数。
    args=args,
    # 对齐后的训练集。
    train_dataset=tokenized["train"],
    # 对齐后的验证集。
    eval_dataset=tokenized["validation"],
    # 动态 padding input_ids 和 labels。
    data_collator=DataCollatorForTokenClassification(tokenizer),
    # 实体级指标计算函数。
    compute_metrics=compute_metrics,
)
 
# 启动 NER 微调。
trainer.train()
# 保存最佳 token classification 模型目录。
trainer.save_model("./ner_deberta/best")
T5 / BART 条件生成:摘要、翻译与 Text-to-Text 分类

T5、BART 这类 Encoder-Decoder 模型适合“输入文本到输出文本”的条件生成任务,包括摘要、翻译、改写、问答和 text-to-text 分类。它们的训练重点是同时处理输入侧 tokenization、输出侧 label tokenization,以及生成式评估。

Shell
1
2
3
4
# transformers 提供 Seq2SeqTrainer 和 Encoder-Decoder 模型。
# datasets 读取 source/target 数据。
# evaluate 可加载 ROUGE、BLEU 等生成任务指标。
pip install -U transformers datasets evaluate accelerate sentencepiece

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# load_dataset 读取 JSONL 中的 source 和 target 字段。
from datasets import load_dataset
# AutoModelForSeq2SeqLM 加载 T5/BART 这类条件生成模型。
# AutoTokenizer 处理输入和输出两侧文本。
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# DataCollatorForSeq2Seq 会动态 padding 输入和 labels。
# Seq2SeqTrainer 支持 predict_with_generate 生成式评估。
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments
 
# T5 适合 text-to-text 任务;中文任务可换成对应中文/多语 T5。
model_id = "google/flan-t5-small"
# 数据应包含 source_text 和 target_text 两列。
dataset = load_dataset("json", data_files={"train": "seq2seq_train.jsonl", "validation": "seq2seq_valid.jsonl"})
 
# tokenizer 同时服务 encoder 输入和 decoder 标签。
tokenizer = AutoTokenizer.from_pretrained(model_id)
# 加载条件生成模型,forward 会根据 labels 自动计算 seq2seq loss。
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
 
def preprocess(batch):
    # prefix 把任务写进输入,T5 常用这种方式区分摘要、翻译、分类等任务。
    inputs = ["summarize: " + text for text in batch["source_text"]]
    # 输入侧长度上限控制 encoder 成本。
    model_inputs = tokenizer(inputs, max_length=512, truncation=True)
 
    # text_target 表示 tokenizer 正在处理 decoder 目标文本。
    labels = tokenizer(text_target=batch["target_text"], max_length=128, truncation=True)
    # Trainer 约定 labels 保存 decoder 目标 token id。
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
 
# 删除原始文本列,保留模型可消费字段。
tokenized = dataset.map(preprocess, batched=True, remove_columns=dataset["train"].column_names)
 
collator = DataCollatorForSeq2Seq(
    # collator 需要 tokenizer 做动态 padding。
    tokenizer=tokenizer,
    # 传入 model 后,collator 可更好地准备 decoder_input_ids。
    model=model,
)
 
args = Seq2SeqTrainingArguments(
    # 输出 checkpoint、日志和生成评估结果。
    output_dir="./flan_t5_seq2seq",
    # 输入到输出任务通常显存开销较高,先用小 batch 起步。
    per_device_train_batch_size=4,
    # 验证生成也占显存,batch 可与训练保持一致。
    per_device_eval_batch_size=4,
    # 通过累积增加有效 batch。
    gradient_accumulation_steps=4,
    # Seq2Seq 微调常用 1e-4 到 5e-5 量级学习率。
    learning_rate=5e-5,
    # 示例跑 1 轮;正式任务按 ROUGE/BLEU/业务指标判断。
    num_train_epochs=1,
    # 评估时调用 generate,指标更贴近真实生成质量。
    predict_with_generate=True,
    # 生成摘要或标签的最大长度。
    generation_max_length=128,
    # 每个 epoch 评估一次。
    eval_strategy="epoch",
    # 每个 epoch 保存一次。
    save_strategy="epoch",
)
 
trainer = Seq2SeqTrainer(
    # 条件生成模型。
    model=model,
    # Seq2Seq 训练参数。
    args=args,
    # 训练集。
    train_dataset=tokenized["train"],
    # 验证集。
    eval_dataset=tokenized["validation"],
    # 动态 padding 输入和 labels。
    data_collator=collator,
    # tokenizer 保存进输出目录,部署时保持预处理一致。
    tokenizer=tokenizer,
)
 
# 启动条件生成微调。
trainer.train()
# 保存 T5/BART 风格的 text-to-text 模型。
trainer.save_model("./flan_t5_seq2seq/final")
表示模型继续预训练:MLM

表示模型继续预训练适合企业文档、医学、金融、法律等领域语料。训练目标通常是掩码语言模型(Masked Language Modeling, MLM):随机遮住一部分 token,让 encoder 根据上下文恢复它们。产物仍是 encoder checkpoint,后续可继续做分类、NER、检索或 reranker。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# load_dataset 负责读取纯文本语料,并产出 Hugging Face Dataset。
from datasets import load_dataset
# AutoModelForMaskedLM 加载带 MLM 预测头的 encoder。
# AutoTokenizer 保证分词规则和底座模型一致。
from transformers import AutoModelForMaskedLM, AutoTokenizer
# DataCollatorForLanguageModeling 在组 batch 时动态随机 mask token。
# Trainer / TrainingArguments 负责训练循环、日志和保存。
from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments
 
# 中文 BERT 是继续做中文领域 MLM 的常见起点。
model_id = "bert-base-chinese"
# text loader 会把每一行或文本块读入 text 字段。
dataset = load_dataset("text", data_files={"train": "domain_corpus.txt"})
# tokenizer 必须来自同一个 model_id,保证 token id 和 embedding 表匹配。
tokenizer = AutoTokenizer.from_pretrained(model_id)
 
def tokenize(batch):
    # return_special_tokens_mask 让 MLM collator 知道哪些 token 不能被随机 mask。
    return tokenizer(
        batch["text"],
        truncation=True,
        max_length=512,
        return_special_tokens_mask=True,
    )
 
# batched=True 批量分词,remove_columns 删除原始文本列,减少训练时无关字段。
tokenized = dataset.map(tokenize, batched=True, remove_columns=["text"])
# 加载带 MLM head 的 BERT;loss 会由模型根据 labels 自动计算。
model = AutoModelForMaskedLM.from_pretrained(model_id)
collator = DataCollatorForLanguageModeling(
    # collator 需要 tokenizer 来识别特殊 token、pad token 和 mask token。
    tokenizer=tokenizer,
    # mlm=True 表示使用 BERT 式随机 mask 目标,不做自回归 next-token 训练。
    mlm=True,
    # 15% 是 BERT MLM 的经典 masking 比例;领域继续预训练通常从这里开始。
    mlm_probability=0.15,
)
 
args = TrainingArguments(
    # 保存领域化 encoder checkpoint、日志和 trainer 状态。
    output_dir="./bert_domain_mlm",
    # encoder MLM 显存压力通常低于同规模 Causal LM,可从较大 batch 起步。
    per_device_train_batch_size=32,
    # 两个微步累积一次更新,有效 batch 为 64。
    gradient_accumulation_steps=2,
    # 继续预训练常用较小学习率,避免过快破坏通用表示。
    learning_rate=5e-5,
    # warmup_ratio 让前 3% step 逐步升学习率,降低初期不稳定。
    warmup_ratio=0.03,
    # 领域语料继续预训练给出上限 epoch,实际应结合下游验证决定停点。
    num_train_epochs=3,
    # 每 1000 step 保存一次,长语料训练中便于断点恢复。
    save_steps=1000,
    # 每 50 step 打日志,用来观察 MLM loss 和吞吐。
    logging_steps=50,
    # fp16 降低显存和带宽压力;数值异常时可关掉或切 bf16。
    fp16=True,
)
 
trainer = Trainer(
    # 带 MLM head 的 encoder。
    model=model,
    # 训练参数、日志和保存策略。
    args=args,
    # 分词后的领域语料。
    train_dataset=tokenized["train"],
    # MLM collator 负责动态 mask;不要在 map 阶段提前固定 mask。
    data_collator=collator,
)
# 启动继续预训练。
trainer.train()
# 保存最终领域化 encoder,后续可继续用于分类、NER、检索等任务。
trainer.save_model("./bert_domain_mlm/final")

MLM 继续预训练之后通常还要做下游验证。仅观察 MLM loss 下降不足以证明业务收益;需要在分类、NER、检索或问答验证集上确认领域表示确实改善。

Transformers 常见监督任务模板

很多 NLP 监督任务共享同一个 Trainer 骨架,差异主要在模型头、数据字段和 data collator。下面这张表把常见任务映射到工程入口,后续可以按任务替换模型类和预处理函数。

任务 模型入口 关键数据字段 训练要点
文本分类 / 回归 AutoModelForSequenceClassification text、 labels 分类看 accuracy/F1/AUC;回归设置 problem_type="regression" 并看 RMSE/MAE。
NER / 序列标注 AutoModelForTokenClassification tokens、BIO/IOBES 标签、word_ids 对齐后的 labels subword 对齐要把非首个子词设为 -100,避免重复计算 loss。
抽取式问答 AutoModelForQuestionAnswering question、context、answer start/end 长 context 需要 stride 滑窗;指标通常看 EM/F1。
摘要 / 翻译 AutoModelForSeq2SeqLM source text、target text 用 DataCollatorForSeq2Seq;评估看 ROUGE、BLEU 或业务指标。
因果语言模型 AutoModelForCausalLM 连续文本、chat messages 或 prompt/response 需要正确 mask prompt 与 padding;SFT 常交给 TRL 的 SFTTrainer。
冻结表示模型 + 轻量分类器

冻结表示模型适合标注数据少、训练成本敏感、且需要快速上线 baseline 的闭集分类。做法是先用 Encoder-only 模型抽取句向量,再用 scikit-learn 训练逻辑回归或线性分类器。它牺牲一部分端到端适配能力,换取训练快、显存低、可解释和易回滚。

Shell
1
2
3
4
5
# transformers 负责抽取 encoder 表示。
# datasets 读取 CSV 分类数据。
# scikit-learn 训练轻量分类器并评估。
# joblib 保存完整分类器制品。
pip install -U transformers datasets scikit-learn joblib torch

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# numpy 用来拼接向量和标签。
import numpy as np
# torch 提供 no_grad 和张量设备管理。
import torch
# load_dataset 读取 train/validation CSV。
from datasets import load_dataset
# joblib 保存 sklearn 分类器。
import joblib
# LogisticRegression 是轻量线性分类器,适合冻结特征 baseline。
from sklearn.linear_model import LogisticRegression
# classification_report 输出 precision/recall/F1。
from sklearn.metrics import classification_report
# AutoModel 加载没有任务头的 encoder 主干。
# AutoTokenizer 保持分词与 encoder 一致。
from transformers import AutoModel, AutoTokenizer
 
# 选择一个 encoder 表示模型;中文任务可换成 MacBERT、ModernBERT 或 mDeBERTa。
model_id = "microsoft/deberta-v3-base"
# CSV 需要 text 和 label 两列。
dataset = load_dataset("csv", data_files={"train": "train.csv", "validation": "valid.csv"})
 
# tokenizer 负责把原始文本转成 token ids。
tokenizer = AutoTokenizer.from_pretrained(model_id)
# AutoModel 只返回隐藏状态,不包含分类头。
encoder = AutoModel.from_pretrained(model_id)
# eval 模式关闭 dropout,保证特征抽取稳定。
encoder.eval()
 
# 有 GPU 就把 encoder 放到 GPU;sklearn 分类器仍在 CPU 上训练。
device = "cuda" if torch.cuda.is_available() else "cpu"
encoder.to(device)
 
def encode_texts(texts, batch_size=32):
    vectors = []
    for start in range(0, len(texts), batch_size):
        # 当前 batch 的原始文本。
        batch_texts = texts[start:start + batch_size]
        # padding=True 动态补齐当前 batch;truncation 防止超长文本撑爆显存。
        inputs = tokenizer(batch_texts, padding=True, truncation=True, max_length=256, return_tensors="pt")
        # 把 tokenizer 输出搬到 encoder 所在设备。
        inputs = {k: v.to(device) for k, v in inputs.items()}
 
        # 冻结特征抽取不需要梯度,no_grad 可降低显存和计算开销。
        with torch.no_grad():
            outputs = encoder(**inputs)
 
        # DeBERTa/BERT 常取第一个 token 的隐藏状态作为句向量 baseline。
        cls_vec = outputs.last_hidden_state[:, 0]
        # sklearn 只能消费 CPU numpy 数组。
        vectors.append(cls_vec.cpu().numpy())
 
    # 把多个 batch 的向量拼成 [N, hidden_size] 特征矩阵。
    return np.concatenate(vectors, axis=0)
 
# 抽取训练和验证文本向量。
X_train = encode_texts(dataset["train"]["text"])
X_valid = encode_texts(dataset["validation"]["text"])
# 标签保持原始类别 id 或字符串;LogisticRegression 可处理离散标签。
y_train = np.array(dataset["train"]["label"])
y_valid = np.array(dataset["validation"]["label"])
 
clf = LogisticRegression(
    # max_iter 给优化器足够迭代次数,避免未收敛警告。
    max_iter=1000,
    # class_weight="balanced" 缓解类别不均衡。
    class_weight="balanced",
)
 
# 训练线性分类器;encoder 参数保持冻结。
clf.fit(X_train, y_train)
# 在验证集上输出分类报告。
pred = clf.predict(X_valid)
print(classification_report(y_valid, pred))
 
# 保存分类器;线上还必须同时固定 tokenizer/encoder 版本。
joblib.dump(clf, "frozen_encoder_logreg.joblib")
Embedding 训练:sentence-transformers

embedding 训练的核心数据形态是“哪些文本应该靠近,哪些文本应该远离”。检索任务常用 query-positive 对,并依赖 in-batch negatives;若能提供 hard negatives,训练目标会更贴近真实召回错误。

Shell
1
2
3
# sentence-transformers 提供 embedding 模型、loss 和 Trainer。
# datasets 用来构造或读取 query-positive 训练表。
pip install -U sentence-transformers datasets

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# Dataset 让示例数据符合 Trainer 期望的数据集接口。
from datasets import Dataset
# SentenceTransformer 加载可输出句向量的 bi-encoder。
from sentence_transformers import SentenceTransformer
# Trainer 和 TrainingArguments 是 sentence-transformers v3 风格训练入口。
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
# MultipleNegativesRankingLoss 使用同 batch 其它 positive 作为负例。
from sentence_transformers.losses import MultipleNegativesRankingLoss
 
train_dataset = Dataset.from_dict(
    {
        # anchor 通常是 query、问题、搜索词或用户输入。
        "anchor": [
            "如何配置 DeepSpeed ZeRO-3?",
            "vLLM 的 prefix caching 有什么作用?",
        ],
        # positive 是和 anchor 语义匹配的答案或文档。
        "positive": [
            "ZeRO-3 会分片参数、梯度和优化器状态,并在计算时临时 gather。",
            "prefix caching 可以复用相同提示词前缀的 KV cache,降低 prefill 成本。",
        ],
    }
)
 
# 选择中文/多语 embedding 底座;训练后仍保存为 SentenceTransformer 目录。
model = SentenceTransformer("BAAI/bge-small-zh-v1.5")
# loss 绑定模型对象,训练时直接计算 anchor-positive 相似度矩阵。
loss = MultipleNegativesRankingLoss(model)
 
args = SentenceTransformerTrainingArguments(
    # 保存模型、日志和训练配置。
    output_dir="./embed_bge_domain",
    # 示例跑 1 个 epoch;正式任务按检索验证集决定训练轮数。
    num_train_epochs=1,
    # in-batch negatives 数量随 batch 增大,embedding 训练通常受益于较大 batch。
    per_device_train_batch_size=64,
    # embedding 微调常从 2e-5 起步,避免快速破坏原始语义空间。
    learning_rate=2e-5,
    # 前 10% step warmup,缓和训练初期相似度空间震荡。
    warmup_ratio=0.1,
    # fp16 提高吞吐并降低显存;不稳定时切 bf16 或 fp32。
    fp16=True,
    # 每 20 step 打日志,观察 loss 和训练速度。
    logging_steps=20,
    # 每个 epoch 保存一次,便于做离线检索评估。
    save_strategy="epoch",
)
 
trainer = SentenceTransformerTrainer(
    # 要训练的 bi-encoder。
    model=model,
    # 训练参数。
    args=args,
    # 必须包含 loss 所需列,这里是 anchor/positive。
    train_dataset=train_dataset,
    # 指定检索训练目标。
    loss=loss,
)
# 启动 embedding 微调。
trainer.train()
# 保存为 sentence-transformers 标准目录,可直接 encode 或写入向量库。
model.save_pretrained("./embed_bge_domain/final")

MultipleNegativesRankingLoss 会把同一 batch 里的其他 positive 当作当前 anchor 的负例,因此 batch size 直接影响负例数量。若训练数据来自点击日志或 FAQ 匹配,必须去重并过滤同义答案,避免把真实正例误当负例。

带 hard negative 的检索微调

带 hard negative 的数据通常包含三列:query、positive、negative。negative 可以来自 BM25 召回错误、旧 embedding 模型召回错误、人工构造的混淆答案或业务线上 bad case。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# Dataset 构造包含 anchor/positive/negative 三列的训练样本。
from datasets import Dataset
# SentenceTransformer 加载待微调的 embedding 模型。
from sentence_transformers import SentenceTransformer
# Trainer 和参数对象负责训练循环。
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
# 这里仍使用 MNRL;部分版本会按列约定消费 hard negative。
from sentence_transformers.losses import MultipleNegativesRankingLoss
 
train_dataset = Dataset.from_dict(
    {
        # anchor 是用户查询。
        "anchor": ["如何保存 PEFT adapter?"],
        # positive 是业务上应被召回的正确答案。
        "positive": ["使用 save_pretrained 保存 adapter_config 和 adapter_model。"],
        # negative 是语义接近但不应作为答案的困难负例。
        "negative": ["使用 torch.save 保存整个 Python 对象会带来可移植性和安全问题。"],
    }
)
 
# 从已有中文 embedding 模型继续训练。
model = SentenceTransformer("BAAI/bge-small-zh-v1.5")
# loss 决定相似度学习目标。
loss = MultipleNegativesRankingLoss(model)
args = SentenceTransformerTrainingArguments(
    # 输出目录单独区分 hard negative 版本,便于和普通版本对比。
    output_dir="./embed_with_hard_neg",
    # hard negative 训练更容易过拟合,batch 可先保守设置。
    per_device_train_batch_size=32,
    # 示例跑 2 轮,正式任务看验证集 recall/MRR/NDCG。
    num_train_epochs=2,
    # 小学习率保护原 embedding 空间。
    learning_rate=2e-5,
)
 
trainer = SentenceTransformerTrainer(
    # 待训练模型。
    model=model,
    # 训练参数。
    args=args,
    # 包含 hard negative 的数据集。
    train_dataset=train_dataset,
    # 检索排序损失。
    loss=loss,
)
# 启动训练;完成后应在真实检索集上评估。
trainer.train()
基于表示模型的重排训练:CrossEncoder

基于表示模型的重排训练对应检索系统里的第二阶段精排。这里的表示模型通常是 Encoder-only CrossEncoder,例如 BERT、DeBERTa、ModernBERT、BGE Reranker 或 MS MARCO 系列 reranker。第一阶段由 BM25、向量检索或混合检索召回 topK 候选;CrossEncoder 把 query 与每个候选文档拼接成一个输入序列,在同一次 encoder 前向里建模 token 级交互,并输出一个相关性分数。训练数据通常是 query-document-label 三元组,label 可以是人工 0/1 标签、点击转化标签、人工相关性等级,或由更强教师模型生成的软标签。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# math 用来根据训练样本量计算 warmup steps,避免手写固定步数。
import math
 
# DataLoader 负责把 InputExample 组成 mini-batch。
from torch.utils.data import DataLoader
 
# CrossEncoder 会把 query 和 document 拼在一起交给同一个 encoder 打分。
# InputExample 是 sentence-transformers 的轻量训练样本对象。
from sentence_transformers import CrossEncoder, InputExample
 
# 每一行对应一个 query-document 训练对。
# label=1.0 表示文档应被排到前面,label=0.0 表示它是负例。
# source 标出样本来源,便于后续分析 hard negative 是否覆盖真实线上错误。
train_rows = [
    {
        "query": "ZeRO 是什么?",
        "doc": "ZeRO 会把优化器状态、梯度和参数分片到多张 GPU 上,从而降低单卡显存。",
        "label": 1.0,
        "source": "gold_positive",
    },
    {
        "query": "ZeRO 是什么?",
        "doc": "Beam search 是生成阶段的候选路径搜索方法,常用于机器翻译和文本生成。",
        "label": 0.0,
        "source": "bm25_hard_negative",
    },
    {
        "query": "LoRA adapter 如何保存?",
        "doc": "PEFT 的 save_pretrained 会保存 adapter_config.json 和 adapter_model.safetensors。",
        "label": 1.0,
        "source": "gold_positive",
    },
    {
        "query": "LoRA adapter 如何保存?",
        "doc": "torch.save 可以序列化任意 Python 对象,但它不等价于标准 PEFT adapter 导出。",
        "label": 0.0,
        "source": "dense_hard_negative",
    },
]
 
# CrossEncoder.fit 消费 InputExample。
# texts[0] 是 query,texts[1] 是候选文档,label 是监督分数。
train_examples = [
    InputExample(texts=[row["query"], row["doc"]], label=row["label"])
    for row in train_rows
]
 
# 选择已经面向 reranking 预训练过的底座,可以显著降低领域微调成本。
# num_labels=1 表示每个 query-document 对只输出一个标量相关性分数。
# max_length 控制拼接后的最大 token 数,防止少数长文档撑爆显存。
model = CrossEncoder("BAAI/bge-reranker-base", num_labels=1, max_length=512)
 
# shuffle=True 打散正负样本顺序,避免连续同类样本造成梯度偏置。
# batch_size 越大吞吐越好,但 CrossEncoder 要联合编码每个文档,显存压力高于 Bi-Encoder。
loader = DataLoader(train_examples, shuffle=True, batch_size=16)
 
# warmup_steps 通常取总训练步数的 5% 到 10%。
# 这里按样本量自动计算,保证小数据集也至少有 1 个 warmup step。
epochs = 2
steps_per_epoch = math.ceil(len(loader))
warmup_steps = max(1, int(steps_per_epoch * epochs * 0.1))
 
model.fit(
    # CrossEncoder.fit 使用 DataLoader 作为训练输入。
    train_dataloader=loader,
    # 正式任务按验证集 nDCG@K、MRR@K 和线上延迟决定训练轮数。
    epochs=epochs,
    # warmup 让学习率从较小值平滑升高,降低训练初期破坏预训练表示的风险。
    warmup_steps=warmup_steps,
    # reranker 微调通常从 2e-5 起步,过大容易让模型只记住小规模标注集。
    optimizer_params={"lr": 2e-5},
    # weight_decay 抑制分类头和 encoder 权重过度放大,降低过拟合。
    weight_decay=0.01,
    # max_grad_norm 裁剪异常梯度,hard negative 很强时能减少训练抖动。
    max_grad_norm=1.0,
    # 支持 GPU 时启用 AMP,降低显存并提高吞吐;数值不稳定时切回 fp32。
    use_amp=True,
    # 保存精排模型目录,RAG 第二阶段可直接加载。
    output_path="./reranker_domain",
    # 展示进度条,便于本地实验观察训练是否卡住。
    show_progress_bar=True,
)

上线时,CrossEncoder 位于向量库召回之后。向量库先取 topK,例如 50 到 200 条;CrossEncoder 对每个 query-document 对打分;系统再按分数重排,取 topN,例如 5 到 20 条进入最终 RAG prompt 或搜索展示层。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# 从训练输出目录加载领域 reranker。
reranker = CrossEncoder("./reranker_domain")
 
# query 是用户当前问题;线上通常来自搜索框、RAG 问句或推荐上下文。
query = "ZeRO-3 为什么能省显存?"
 
# candidates 是第一阶段召回结果;真实系统里通常来自 BM25、向量库或混合检索。
candidates = [
    "ZeRO-3 会分片参数、梯度和优化器状态,并在计算时临时 gather。",
    "梯度累积通过多次 forward/backward 模拟更大的 batch size。",
    "KV cache 用于复用自回归生成阶段的历史 attention 状态。",
]
 
# CrossEncoder 要逐对读取 query 和候选文档,不能像 embedding 那样提前只算文档向量。
pairs = [[query, doc] for doc in candidates]
 
# predict 返回每个 query-document 对的相关性分数。
# batch_size 控制推理吞吐;候选很长或显存较小时应调小。
scores = reranker.predict(pairs, batch_size=16, convert_to_numpy=True)
 
# zip 把文档和分数绑定;按分数从高到低排序就是重排结果。
ranked = sorted(zip(candidates, scores), key=lambda item: item[1], reverse=True)
 
# top_n 是最终进入生成模型上下文或搜索展示页的候选数量。
top_n = 2
 
# 只保留最相关的少量片段,避免后续 LLM prompt 被弱相关内容稀释。
top_docs = [doc for doc, score in ranked[:top_n]]
 
# 打印重排后的证据片段,真实服务里通常会把它们写入 RAG prompt。
for doc in top_docs:
    print(doc)

这类训练最关键的是负例质量。随机负例能让模型学会粗粒度主题区分,但很难提升精排能力;hard negatives 才能训练模型识别“主题相似但没有回答问题”“实体相似但对象不同”“时间版本不一致”“只回答部分条件”等真实线上错误。评估也应围绕排序指标展开,优先看 nDCG@K、MRR@K、Recall after rerank 和端到端回答引用命中率。二分类 accuracy 只能作为辅助指标。

少样本分类:SetFit

SetFit 适合每类只有少量样本的短文本闭集分类。它先把少量标注样本扩展成句子对,用对比学习微调 sentence-transformer body,再训练一个轻量分类头。

Shell
1
2
3
# setfit 提供少样本分类训练流程。
# datasets 用来构造少量标注样本。
pip install -U setfit datasets

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# Dataset 用于构造内存中的少样本训练集。
from datasets import Dataset
# SetFitModel 加载 sentence-transformer body 和分类头。
# Trainer / TrainingArguments 负责少样本对比学习和分类头训练。
from setfit import SetFitModel, Trainer, TrainingArguments
 
train_dataset = Dataset.from_dict(
    {
        # text 是待分类的短文本。
        "text": [
            "物流很快,整体满意",
            "包装破损,客服也没有解决",
            "价格合适,还会回购",
            "收到后无法使用",
        ],
        # label 是闭集类别 id;这里 1 表示正向,0 表示负向。
        "label": [1, 0, 1, 0],
    }
)
 
# 选择多语 sentence-transformer,适合中文短文本少样本分类。
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
args = TrainingArguments(
    # SetFit 先构造句子对,batch_size 控制对比学习阶段的显存。
    batch_size=8,
    # 训练轮数;少样本任务应小心过拟合。
    num_epochs=1,
    # 每个样本生成多少对比训练对;数值越大,少样本扩增越强。
    num_iterations=20,
)
 
trainer = Trainer(
    # 待训练的 SetFit 模型。
    model=model,
    # 训练参数。
    args=args,
    # 少量有标签样本。
    train_dataset=train_dataset,
    # 用 F1 评估类别不均衡时的分类质量。
    metric="f1",
)
# 启动少样本训练。
trainer.train()
# 保存 sentence-transformer body 和分类头。
model.save_pretrained("./setfit_sentiment")

SetFit 的关键在于从少量样本中构造更多“同类靠近、异类远离”的监督关系,原始样本条数只是起点。类别边界清晰、文本较短、标签闭集时,它通常比直接微调大模型分类头更稳。

嵌入零样本分类:标签描述相似度

嵌入零样本分类把类别改写成自然语言描述,再比较输入文本向量与标签描述向量的相似度。它不需要训练分类头,适合标签临时变化、冷启动分类、弱标注和规则探索。代价是分类边界完全依赖标签描述质量与 embedding 模型能力。

Shell
1
2
3
# sentence-transformers 提供文本向量模型。
# numpy 用来做矩阵相似度计算。
pip install -U sentence-transformers numpy

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# numpy 用来计算向量相似度矩阵。
import numpy as np
# SentenceTransformer 把文本和标签描述编码到同一个向量空间。
from sentence_transformers import SentenceTransformer
 
# 选择中文/多语 embedding 模型。
model = SentenceTransformer("BAAI/bge-small-zh-v1.5")
 
# 待分类文本。
texts = [
    "包装破损,联系客服后一直没有处理。",
    "物流很快,价格也合适,下次还会买。",
]
 
# 标签描述写成自然语言,可让类别语义更充分地进入向量空间。
label_texts = {
    "negative": "这是一条负面用户评价,表达投诉、不满、损坏、失败或差评。",
    "positive": "这是一条正面用户评价,表达满意、推荐、喜欢、顺利或好评。",
}
 
# normalize_embeddings=True 让点积等价于余弦相似度。
text_vecs = model.encode(texts, normalize_embeddings=True)
# 保持标签顺序稳定,便于从相似度列还原标签名。
label_names = list(label_texts.keys())
# 编码标签描述,得到 [num_labels, hidden_size] 矩阵。
label_vecs = model.encode([label_texts[name] for name in label_names], normalize_embeddings=True)
 
# 相似度矩阵形状是 [num_texts, num_labels]。
scores = text_vecs @ label_vecs.T
# 每条文本选择相似度最高的标签。
best_label_ids = np.argmax(scores, axis=1)
 
for text, label_id, row_scores in zip(texts, best_label_ids, scores):
    # 取出预测标签名。
    label = label_names[int(label_id)]
    # 同时打印分数,便于人工判断阈值是否需要调整。
    print(text, label, row_scores.tolist())

这条路线适合快速建立标签体系,但正式上线前应补一小批标注验证集,检查标签描述是否引入偏差。若两个类别语义非常接近,直接训练 SetFit 或 encoder 分类头通常更稳。

无监督 embedding:TSDAE 风格训练

无监督 embedding 训练适合只有领域语料、缺少人工配对数据的场景。TSDAE(Transformer-based Sequential Denoising Auto-Encoder)会破坏输入句子,再训练模型恢复原句,从而让句向量承载足够的信息用于重构。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# DataLoader 把去噪自编码样本组成 batch。
from torch.utils.data import DataLoader
# SentenceTransformer 加载 encoder;losses 提供 TSDAE 损失。
from sentence_transformers import SentenceTransformer, losses
# DenoisingAutoEncoderDataset 会对句子做扰动,形成重构训练样本。
from sentence_transformers.datasets import DenoisingAutoEncoderDataset
 
# 无监督语料只需要领域句子,不需要人工 pair 或 label。
sentences = [
    "DeepSpeed ZeRO 会把训练状态切分到多个 rank。",
    "vLLM 使用 PagedAttention 管理 KV cache。",
    "RAG 系统需要处理 chunking、embedding、retrieval 和 rerank。",
]
 
# 使用通用 MiniLM 作为初始 encoder。
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# 数据集会生成“扰动输入 -> 原句恢复”的训练目标。
train_dataset = DenoisingAutoEncoderDataset(sentences)
# shuffle=True 避免每轮固定样本顺序;batch_size 控制显存。
loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
loss = losses.DenoisingAutoEncoderLoss(
    # encoder 模型,训练目标会推动句向量保留重构所需信息。
    model,
    # decoder 初始化来源;QuickStart 直接复用同一模型名。
    decoder_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
    # 共享 encoder/decoder embedding,减少参数并保持词表一致。
    tie_encoder_decoder=True,
)
 
model.fit(
    # sentence-transformers 旧式 fit 接口使用 (DataLoader, Loss) 元组。
    train_objectives=[(loader, loss)],
    # 示例跑 1 轮;正式训练应看检索/分类验证集。
    epochs=1,
    # warmup_steps 缓和训练初期学习率冲击。
    warmup_steps=100,
    # 保存领域适配后的 embedding 模型。
    output_path="./tsdae_domain_embed",
)

TSDAE 训练后仍需用少量检索或分类验证集评估。无监督目标能做领域适配,但它不能替代 hard negative、点击日志或人工相关性标签带来的任务边界。

生成模型 QLoRA SFT:TRL + PEFT

生成模型高效微调通常用 QLoRA 起步:4-bit 加载基座权重,只训练 LoRA adapter,用 SFTTrainer 处理 chat template、监督文本拼接和标签 mask。它适合指令跟随、格式控制、轻量领域适配和风格注入。

Shell
1
2
3
4
5
6
7
# transformers 提供 Causal LM 和 tokenizer。
# datasets 读取 json/jsonl SFT 数据。
# peft 提供 LoRA adapter。
# trl 提供 SFTTrainer。
# bitsandbytes 提供 4bit 量化加载。
# accelerate 是 Trainer 的分布式/混合精度运行时。
pip install -U transformers datasets peft trl bitsandbytes accelerate

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# load_dataset 读取本地 SFT JSONL。
from datasets import load_dataset
# LoraConfig 定义要训练的 LoRA adapter。
from peft import LoraConfig
# AutoModelForCausalLM 加载自回归生成模型。
# AutoTokenizer 加载对话模板和词表。
# BitsAndBytesConfig 定义 4bit 量化加载方式。
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# SFTConfig / SFTTrainer 是 TRL 的监督微调入口。
from trl import SFTConfig, SFTTrainer
 
# 指令模型底座;SFT 数据的 chat template 应与它匹配。
model_id = "Qwen/Qwen2.5-7B-Instruct"
# 数据通常是 messages、prompt/response 或 text 字段,具体取决于 TRL 版本和格式化函数。
dataset = load_dataset("json", data_files={"train": "sft_train.jsonl"})
 
# use_fast=True 优先使用 Rust tokenizer,提高批量预处理速度。
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
quant = BitsAndBytesConfig(
    # 以 4bit 形式加载 base 权重,显著降低显存。
    load_in_4bit=True,
    # NF4 是 QLoRA 常用量化格式,适合神经网络权重分布。
    bnb_4bit_quant_type="nf4",
    # 4bit 权重反量化后的计算 dtype;bf16 在新 GPU 上更稳。
    bnb_4bit_compute_dtype="bfloat16",
    # 对量化常数再量化,进一步节省显存。
    bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
    # 加载同一个底座模型。
    model_id,
    # 把 4bit 量化配置接入模型加载流程。
    quantization_config=quant,
    # 自动把量化后的模型放到可见设备。
    device_map="auto",
)
 
peft_config = LoraConfig(
    # rank 控制 LoRA 分支容量;7B SFT 常从 8/16/32 试起。
    r=16,
    # alpha 控制 LoRA 更新幅度,通常和 rank 搭配调整。
    lora_alpha=32,
    # LoRA 分支 dropout 用于缓和小数据过拟合。
    lora_dropout=0.05,
    # 覆盖注意力投影和 FFN 投影,容量更强但训练参数更多。
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    # 声明任务是自回归语言模型。
    task_type="CAUSAL_LM",
)
 
args = SFTConfig(
    # 保存 LoRA adapter、日志和 trainer 状态。
    output_dir="./qwen_sft_lora",
    # 单条训练序列最大长度,直接决定显存和上下文覆盖。
    max_length=2048,
    # 4bit + LoRA 仍可能受长上下文限制,单卡从 1 起步更稳。
    per_device_train_batch_size=1,
    # 16 个微步累积一次更新,有效 batch 更大,梯度更平滑。
    gradient_accumulation_steps=16,
    # LoRA 参数少,学习率通常高于全参微调。
    learning_rate=2e-4,
    # 前 3% step warmup,降低初期 loss 抖动。
    warmup_ratio=0.03,
    # 训练轮数上限;正式实验应按验证指标或人工评估停。
    num_train_epochs=2,
    # bf16 适合 Ampere/Hopper 等新卡;不支持时改 fp16 或 fp32。
    bf16=True,
    # 每 10 step 记录 loss 和吞吐,便于快速发现模板错误。
    logging_steps=10,
    # 每轮保存一次 adapter checkpoint。
    save_strategy="epoch",
)
 
trainer = SFTTrainer(
    # 已按 4bit 加载的 Causal LM。
    model=model,
    # SFT 训练参数。
    args=args,
    # 训练 split;字段格式要符合 TRL 的 SFT 数据约定。
    train_dataset=dataset["train"],
    # processing_class 通常传 tokenizer,用于模板渲染和分词。
    processing_class=tokenizer,
    # 让 SFTTrainer 在模型上注入 LoRA adapter。
    peft_config=peft_config,
)
# 启动 SFT;Trainer 会处理反传、累积、保存和日志。
trainer.train()
# 保存最终 adapter 或模型目录,后续可继续 DPO/RL 或 merge 导出。
trainer.save_model("./qwen_sft_lora/final")

target_modules 要按模型结构确认。Qwen/LLaMA/Mistral 常见投影名相近;BERT/DeBERTa 的投影名不同。上线前还要确认训练 chat template 与推理 chat template 完全一致。

DoRA / Q-DoRA 高容量微调

DoRA(Weight-Decomposed Low-Rank Adaptation)把权重更新拆成方向和幅度两部分,比普通 LoRA 更接近全参数微调的表达力。Q-DoRA 则把 DoRA 与量化底座结合,适合显存受限但又希望 adapter 容量更强的深领域适配、困难分类边界和高质量指令微调。

Shell
1
2
3
4
# peft 提供 use_dora 开关。
# trl 提供 SFTTrainer。
# bitsandbytes 提供 4bit 量化底座。
pip install -U transformers datasets peft trl bitsandbytes accelerate

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# load_dataset 读取 SFT 数据。
from datasets import load_dataset
# LoraConfig 同时支持 LoRA 和 DoRA;use_dora=True 会切换到 DoRA 路线。
from peft import LoraConfig
# AutoModelForCausalLM 加载生成模型。
# AutoTokenizer 保持 chat template 与词表一致。
# BitsAndBytesConfig 定义 Q-DoRA 的量化底座。
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# SFTConfig / SFTTrainer 负责监督微调流程。
from trl import SFTConfig, SFTTrainer
 
# 深领域适配通常从 SFT 起点或 instruct 模型开始。
model_id = "Qwen/Qwen2.5-7B-Instruct"
# 数据格式应与前面的 SFTTrainer 保持一致。
dataset = load_dataset("json", data_files={"train": "deep_domain_sft.jsonl"})
 
# tokenizer 负责模板渲染和分词。
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
 
quant = BitsAndBytesConfig(
    # Q-DoRA 使用 4bit 底座降低显存。
    load_in_4bit=True,
    # NF4 是 QLoRA/Q-DoRA 常见量化格式。
    bnb_4bit_quant_type="nf4",
    # bf16 作为计算 dtype,兼顾速度和稳定性。
    bnb_4bit_compute_dtype="bfloat16",
    # double quant 进一步降低量化常数开销。
    bnb_4bit_use_double_quant=True,
)
 
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    # 以量化底座加载,训练时只更新 DoRA adapter。
    quantization_config=quant,
    # 自动放置设备,QuickStart 阶段减少手工切分。
    device_map="auto",
)
 
dora_config = LoraConfig(
    # DoRA 仍沿用 LoRA 的低秩配置接口。
    r=32,
    # 更高 rank 给深领域任务更强 adapter 容量。
    lora_alpha=64,
    # 深领域数据也可能过拟合,保留轻度 dropout。
    lora_dropout=0.05,
    # 覆盖注意力与 FFN 投影,适合更强表达力需求。
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    # 自回归语言模型任务。
    task_type="CAUSAL_LM",
    # 打开 DoRA;关闭时就是普通 LoRA。
    use_dora=True,
)
 
args = SFTConfig(
    # 单独保存 Q-DoRA 实验,避免和普通 LoRA 混淆。
    output_dir="./qwen_sft_qdora",
    # 深领域样本常需要更长上下文。
    max_length=4096,
    # 长上下文 + 7B 量化底座仍建议从 batch=1 起步。
    per_device_train_batch_size=1,
    # 梯度累积提高有效 batch。
    gradient_accumulation_steps=16,
    # DoRA 容量更强,学习率可比普通 LoRA 更保守。
    learning_rate=1e-4,
    # warmup 缓和训练早期不稳定。
    warmup_ratio=0.03,
    # 训练轮数上限;正式实验看验证集和人工评估。
    num_train_epochs=1,
    # bf16 用于支持的 GPU。
    bf16=True,
    # 记录 loss 和吞吐。
    logging_steps=10,
    # 每轮保存一次 adapter。
    save_strategy="epoch",
)
 
trainer = SFTTrainer(
    # 量化加载后的策略模型。
    model=model,
    # SFT 参数。
    args=args,
    # 深领域 SFT 数据。
    train_dataset=dataset["train"],
    # tokenizer / processor。
    processing_class=tokenizer,
    # 注入 DoRA adapter。
    peft_config=dora_config,
)
 
# 启动 Q-DoRA SFT。
trainer.train()
# 保存 DoRA adapter,后续可评估、合并或继续偏好优化。
trainer.save_model("./qwen_sft_qdora/final")
拒绝采样回写 SFT

拒绝采样微调适合答案容易自动验证的任务,例如数学题、代码单测、格式校验和结构化抽取。流程是:同一个 prompt 生成多个候选,评分器筛出最佳候选,回写成新的 SFT 数据,再继续监督训练。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# json 用于读取 prompt JSONL 和写回新 SFT JSONL。
import json
# Path 提供更清晰的文件写入接口。
from pathlib import Path
# LLM 是 vLLM 离线批量推理入口;SamplingParams 定义采样策略。
from vllm import LLM, SamplingParams
 
# 从每行 JSON 中取 prompt 字段,形成批量生成输入。
prompts = [json.loads(line)["prompt"] for line in open("prompts.jsonl")]
# 加载上一阶段 SFT 模型或 adapter 合并后的模型目录。
llm = LLM(model="./qwen_sft_lora/final")
sampling = SamplingParams(
    # 每个 prompt 生成 4 个候选,供后续评分器筛选。
    n=4,
    # temperature 控制随机性;拒绝采样需要一定多样性。
    temperature=0.7,
    # top_p 限制累积概率质量,过滤长尾低质量 token。
    top_p=0.9,
    # max_tokens 限制每个候选最大生成长度,控制成本和异常长输出。
    max_tokens=512,
)
 
def score_answer(prompt, answer):
    # 真实项目里这里通常是规则校验、单元测试、reward model 或 LLM judge。
    if "```json" in answer and answer.count("{") == answer.count("}"):
        # 返回 1.0 表示候选通过质量门槛。
        return 1.0
    # 返回 0.0 表示候选不应回写进 SFT 数据。
    return 0.0
 
# rows 收集筛选后的新 SFT 样本。
rows = []
# llm.generate 返回每个 prompt 对应的一组候选输出。
for prompt, output in zip(prompts, llm.generate(prompts, sampling)):
    # output.outputs 中每个 item 是一个候选 completion。
    candidates = [item.text for item in output.outputs]
    # 对每个候选打分,保留分数和文本。
    scored = [(score_answer(prompt, text), text) for text in candidates]
    # 选出分数最高的候选作为回写候选。
    best_score, best_answer = max(scored, key=lambda x: x[0])
    # 只把达到质量门槛的候选写回,避免把低质量生成继续蒸馏进模型。
    if best_score >= 1.0:
        # 按 chat messages 格式写回,便于 SFTTrainer 继续消费。
        rows.append({"messages": [{"role": "user", "content": prompt}, {"role": "assistant", "content": best_answer}]})
 
# 打开输出 JSONL 文件;每行是一条可继续 SFT 的 messages 样本。
with Path("sft_rejection_sampled.jsonl").open("w") as f:
    # 逐行写入,避免一次性构造巨大字符串。
    for row in rows:
        # ensure_ascii=False 保留中文,便于人工抽查和下游读取。
        f.write(json.dumps(row, ensure_ascii=False) + "\n")

回写后的 JSONL 可以直接接入上一节的 SFTTrainer。拒绝采样的关键风险是分布变窄:筛选过严会让模型只学习少数高分模板,因此应保留原始 SFT 数据的一部分,避免输出风格和覆盖面坍缩。

DPO 偏好调优:TRL

DPO 用 prompt、chosen、rejected 三元组训练生成模型偏向更优回答。它通常接在 SFT 之后,适合已有偏好数据、但暂时不想建立完整 reward model + PPO 链路的场景。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# load_dataset 读取偏好数据,通常包含 prompt/chosen/rejected。
from datasets import load_dataset
# LoraConfig 让 DPO 更新只落在 adapter 上。
from peft import LoraConfig
# AutoModelForCausalLM 加载当前策略模型和参考模型。
# AutoTokenizer 保证偏好样本按同一模板切分。
# BitsAndBytesConfig 降低 7B 模型加载显存。
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# DPOConfig / DPOTrainer 是 TRL 的偏好优化入口。
from trl import DPOConfig, DPOTrainer
 
# DPO 通常从 SFT 后模型开始;示例用同一个 instruct 底座表达结构。
model_id = "Qwen/Qwen2.5-7B-Instruct"
# preference_train.jsonl 应包含 DPO 所需的 prompt、chosen、rejected 字段。
dataset = load_dataset("json", data_files={"train": "preference_train.jsonl"})
# tokenizer 必须和策略模型/参考模型一致。
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
quant = BitsAndBytesConfig(
    # 4bit 加载降低策略模型和参考模型的显存占用。
    load_in_4bit=True,
    # NF4 是 QLoRA/DPO 常见量化类型。
    bnb_4bit_quant_type="nf4",
    # 计算 dtype 使用 bf16;硬件不支持时需要调整。
    bnb_4bit_compute_dtype="bfloat16",
)
 
model = AutoModelForCausalLM.from_pretrained(
    # 当前要更新的策略模型。
    model_id,
    # 以量化形式加载,减少显存。
    quantization_config=quant,
    # 自动分配到可见设备。
    device_map="auto",
)
ref_model = AutoModelForCausalLM.from_pretrained(
    # 参考模型通常冻结,用于 DPO 的相对概率约束。
    model_id,
    # 参考模型也量化加载,降低双模型显存压力。
    quantization_config=quant,
    # 与策略模型一样交给加载器安排设备。
    device_map="auto",
)
 
peft_config = LoraConfig(
    # LoRA rank 控制 DPO 阶段可训练容量。
    r=16,
    # alpha 控制 LoRA 更新幅度。
    lora_alpha=32,
    # dropout 防止偏好数据上过拟合。
    lora_dropout=0.05,
    # DPO 示例只覆盖注意力投影层,降低训练风险。
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    # 自回归语言模型任务。
    task_type="CAUSAL_LM",
)
 
args = DPOConfig(
    # DPO adapter、日志和 checkpoint 输出目录。
    output_dir="./qwen_dpo_lora",
    # 双模型 + 长上下文显存压力大,单卡 batch 从 1 起步。
    per_device_train_batch_size=1,
    # 通过梯度累积放大有效 batch,减少偏好梯度噪声。
    gradient_accumulation_steps=16,
    # DPO 通常用比 SFT 更保守的学习率。
    learning_rate=5e-5,
    # beta 控制偏好优化相对参考模型的约束强度。
    beta=0.1,
    # prompt + response 的总长度上限。
    max_length=2048,
    # prompt 部分长度上限;超长 prompt 会挤压回答 token 空间。
    max_prompt_length=1024,
    # bf16 降低显存并保持数值稳定。
    bf16=True,
    # 每个 epoch 保存一次,方便按验证集或人工评估挑选 checkpoint。
    save_strategy="epoch",
)
 
trainer = DPOTrainer(
    # 当前策略模型,会被 LoRA adapter 更新。
    model=model,
    # 冻结参考模型,用来计算 chosen/rejected 的相对偏好目标。
    ref_model=ref_model,
    # DPO 训练参数。
    args=args,
    # 偏好训练数据。
    train_dataset=dataset["train"],
    # tokenizer / processor,负责文本模板化和 tokenization。
    processing_class=tokenizer,
    # 在策略模型上注入 LoRA,避免全参 DPO。
    peft_config=peft_config,
)
# 启动 DPO 偏好优化。
trainer.train()
# 保存 DPO 后 adapter 或模型目录。
trainer.save_model("./qwen_dpo_lora/final")

偏好数据质量决定 DPO 上限。chosen 与 rejected 应当足够接近,才能训练细粒度偏好边界;若 rejected 过差,模型只会学习排除明显坏答案,对真实线上排序帮助有限。

LLM 教师蒸馏到 Encoder-only 学生

工业系统常用强 LLM 做弱标注、难例发现或标签归并,再把结果蒸馏到 DeBERTa、ModernBERT、MacBERT 这类 Encoder-only 学生模型上线。这样可以保留 LLM 的语义泛化能力,同时把线上延迟、吞吐和成本压回判别式小模型水平。

Shell
1
2
3
4
# openai 代表任意 LLM API 客户端,也可替换成本地 vLLM 服务。
# datasets 读取未标注文本和写回弱标注数据。
# transformers 训练学生分类模型。
pip install -U openai datasets transformers evaluate accelerate

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# json 用来读写 JSONL。
import json
# Path 提供文件写入接口。
from pathlib import Path
# OpenAI 客户端也可指向 OpenAI-compatible 的本地 vLLM 服务。
from openai import OpenAI
 
# 客户端配置应放在环境变量里,代码中不写 API key。
client = OpenAI()
 
# 候选标签由业务定义,教师模型只能从这些标签中选择。
labels = ["投诉", "咨询", "表扬", "其它"]
 
def teacher_label(text):
    # system prompt 固定输出约束,降低教师模型自由发挥。
    system = "你是文本分类标注器。只输出 JSON,字段为 label 和 confidence。"
    # user prompt 提供标签集合和待标注文本。
    user = f"候选标签:{labels}\n文本:{text}"
 
    response = client.chat.completions.create(
        # 教师模型可换成内部强模型或本地服务。
        model="gpt-4.1-mini",
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
        # 低温度降低同一文本多次标注的随机性。
        temperature=0,
    )
 
    # 解析教师输出;生产系统应增加 JSON schema 校验和异常重试。
    return json.loads(response.choices[0].message.content)
 
# unlabeled.jsonl 每行包含 text 字段。
rows = []
for line in Path("unlabeled.jsonl").read_text().splitlines():
    item = json.loads(line)
    result = teacher_label(item["text"])
 
    # 只保留高置信弱标注,降低错误标签污染学生模型。
    if result["confidence"] >= 0.8 and result["label"] in labels:
        rows.append({"text": item["text"], "label": result["label"]})
 
# 写回学生模型可直接读取的弱标注训练集。
with Path("student_train.jsonl").open("w") as f:
    for row in rows:
        # ensure_ascii=False 保留中文标签和文本。
        f.write(json.dumps(row, ensure_ascii=False) + "\n")

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# evaluate 提供学生模型验证指标。
import evaluate
# numpy 用于 logits -> class id。
import numpy as np
# load_dataset 读取教师生成的弱标注数据。
from datasets import load_dataset
# AutoModelForSequenceClassification 加载学生分类器。
# AutoTokenizer 保证学生模型分词一致。
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# DataCollatorWithPadding / Trainer / TrainingArguments 负责训练。
from transformers import DataCollatorWithPadding, Trainer, TrainingArguments
 
# 学生模型选择高吞吐 encoder。
student_id = "microsoft/deberta-v3-base"
# 读取弱标注数据和人工验证集;验证集必须尽量人工标注。
dataset = load_dataset("json", data_files={"train": "student_train.jsonl", "validation": "human_valid.jsonl"})
 
# 固定标签顺序,保证教师标签和学生 id 一致。
label_names = ["投诉", "咨询", "表扬", "其它"]
label2id = {name: i for i, name in enumerate(label_names)}
id2label = {i: name for name, i in label2id.items()}
 
# 加载学生 tokenizer。
tokenizer = AutoTokenizer.from_pretrained(student_id)
 
def preprocess(batch):
    # 对文本做截断和分词。
    encoded = tokenizer(batch["text"], truncation=True, max_length=256)
    # 把教师字符串标签映射成学生分类 loss 需要的整数 id。
    encoded["labels"] = [label2id[x] for x in batch["label"]]
    return encoded
 
# 删除原始字段,保留模型 forward 所需张量。
tokenized = dataset.map(preprocess, batched=True, remove_columns=dataset["train"].column_names)
 
model = AutoModelForSequenceClassification.from_pretrained(
    student_id,
    # 分类头输出维度等于标签数。
    num_labels=len(label_names),
    # 写入 id -> label,便于线上解释输出。
    id2label=id2label,
    # 写入 label -> id,便于保存后复用。
    label2id=label2id,
)
 
# macro F1 避免多数类掩盖小类退化。
metric = evaluate.load("f1")
 
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    # 取最高 logit 对应类别。
    preds = np.argmax(logits, axis=-1)
    # 返回 macro F1。
    return metric.compute(predictions=preds, references=labels, average="macro")
 
args = TrainingArguments(
    # 学生模型输出目录。
    output_dir="./student_deberta_distilled",
    # 蒸馏弱标注可能有噪声,学习率先用保守值。
    learning_rate=2e-5,
    # 学生 encoder 训练 batch 可大于 LLM 微调 batch。
    per_device_train_batch_size=32,
    # 验证 batch 可更大。
    per_device_eval_batch_size=64,
    # 给出上限 epoch,最终按人工验证集选最优。
    num_train_epochs=3,
    # 每轮评估。
    eval_strategy="epoch",
    # 每轮保存。
    save_strategy="epoch",
    # 恢复人工验证集 F1 最优 checkpoint。
    load_best_model_at_end=True,
    # compute_metrics 返回 f1,对应 eval_f1。
    metric_for_best_model="eval_f1",
    # F1 越大越好。
    greater_is_better=True,
)
 
trainer = Trainer(
    # 学生分类模型。
    model=model,
    # 训练参数。
    args=args,
    # 教师弱标注训练集。
    train_dataset=tokenized["train"],
    # 人工验证集用于防止教师偏差被学生继承。
    eval_dataset=tokenized["validation"],
    # 动态 padding。
    data_collator=DataCollatorWithPadding(tokenizer),
    # 计算 F1。
    compute_metrics=compute_metrics,
)
 
# 启动学生模型训练。
trainer.train()
# 保存低延迟线上分类器。
trainer.save_model("./student_deberta_distilled/best")
训练脚本的基本组成

训练脚本的价值在于把训练过程工程化:数据输入稳定、训练状态可恢复、指标可观测、实验可对比、产物可追溯。一个可维护的训练脚本通常围绕四件事组织:训练循环、状态管理、配置入口、可观测性与评估。

最小训练循环

训练循环的目标是把“损失函数关于参数的梯度”转化为“参数更新”。在 PyTorch 中,这条链路可写成:前向得到 loss,反向计算梯度,优化器 step 更新参数。工程上再叠加三类必需机制:学习率调度、数值稳定/效率策略(累积、混合精度、梯度裁剪)、训练状态的保存与恢复。

训练核心对象
模型

模型在脚本里承担两种职责:定义参数化映射,以及提供可复现的前向路径。训练脚本中最容易被忽略的细节是模式切换与设备放置:训练时必须 model.train(),评估时必须 model.eval();参数与输入必须在同一设备与兼容精度上。

Python
1
2
3
4
5
6
7
8
9
device = "cuda"  # or "cpu"
model = MyModel(...)
model.to(device)  # 模型参数和输入必须落在同一设备上,否则前向会直接报 device mismatch
 
for batch in train_loader:
    model.train()  # 明确切回训练态,打开 dropout、batch norm 更新等训练行为
    # batch 也要搬到同一设备,避免前向时隐式拷贝或报错
    x, y = batch["x"].to(device), batch["y"].to(device)
    logits = model(x)
损失

损失函数是训练脚本的“唯一可优化目标”。工程上需要把损失拆成两层:第一层是数学定义(例如 CE/BCE/MSE);第二层是数据与张量形状约定(logits vs probabilities、label dtype、ignore_index、padding mask)。脚本里应当显式处理这层约定,避免模型输出与 loss 之间隐含转换。

Python
1
2
3
4
import torch.nn.functional as F
 
logits = model(x)                 # [B, C]
loss = F.cross_entropy(logits, y) # y: [B], dtype=torch.long
优化器

优化器把梯度转成参数更新。训练脚本里,优化器的“正确性”主要取决于三件事:参数组(parameter groups)是否分对、 zero_grad 是否用 set_to_none=True 清零、以及 step 的节奏是否与梯度累积/混合精度一致。

Python
1
2
3
4
5
6
7
8
9
10
import torch
 
# AdamW 是 Transformer/LLM 微调里最常见的默认优化器
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=0.1)
 
# 把 grad 设为 None 比直接置零更省内存,也更容易暴露未被写入梯度的参数
optimizer.zero_grad(set_to_none=True)
loss.backward()                        # 反向传播只负责累计梯度,不会自动更新参数
# 真正的参数更新发生在这里;调度器通常也围绕这个节奏触发
optimizer.step()
常用优化器 安装 典型入口 脚本要点
torch.optim.AdamW 随 PyTorch 提供 torch.optim.AdamW(...) LLM/Transformer 微调的默认选择之一;建议显式设置 weight_decay;必要时拆 parameter groups 让 bias/Norm 走 0 weight_decay。
torch.optim.SGD 随 PyTorch 提供 torch.optim.SGD(...) 常用于 CNN/视觉训练;注意 momentum、nesterov 与 weight_decay 的组合。
自定义/函数式优化器 依赖实现 optimizer.step() 若使用函数式 API(functional optimizers),需要把 grad_scale/found_inf 等 AMP 信息正确传递给优化器。
scheduler

学习率调度器的工程关键在于 step 的触发时机;调度器选型通常是第二顺位。常见两类:

  • epoch 级 step:每个 epoch 结束后调用一次。
  • step 级 step:每个 optimizer update 后调用一次(常见于 warmup、OneCycle 等)。

在 PyTorch 中,调度器通常在 optimizer.step() 之后调用,避免跳过初始学习率。恢复训练时也应保存/加载 scheduler 的 state。

Python
1
2
3
4
5
6
7
8
9
10
11
12
from torch.optim.lr_scheduler import CosineAnnealingLR
 
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
# T_max 必须和你选的 step 粒度一致;这里按 epoch 计
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
 
for epoch in range(num_epochs):
    train_one_epoch(...)
    # 这里示意“完成一次 update 后再调度”;真实脚本里通常在 train step 内部调用
    optimizer.step()
    # scheduler 的节奏必须和 optimizer update 对齐,否则学习率曲线会错位
    scheduler.step()
scheduler step 粒度 典型用途 注意点
StepLR epoch 分段衰减 与里程碑 epoch 对齐,通常在 epoch 末调用。
CosineAnnealingLR epoch 或 step 平滑衰减 需要明确 T_max 的含义(epoch 数或 step 数)。
OneCycleLR step warmup + 衰减的一体化策略 必须提供 total_steps 或 epochs + steps_per_epoch;在每次 optimizer update 后 step。
ReduceLROnPlateau eval 事件驱动 指标不提升就降 LR step 时需要传入监控指标(例如 val_loss)。
warmup 与主调度器的串联

很多训练脚本核心是先经历一个短 warmup,再切到主调度器。PyTorch 原生的表达方式通常是 LinearLR / ConstantLR 配合 SequentialLR 把两段曲线串起来。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
 
# 2e-4 是 warmup 结束后真正生效的基础学习率
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
warmup = LinearLR(
    optimizer,
    start_factor=0.1,  # 第一个 update 只用 10% 基础学习率,降低开训初期数值震荡
    end_factor=1.0,    # warmup 结束后回到目标学习率
    # 这里按 optimizer.step() 次数计;用了梯度累积时不要误填成 dataloader step
    total_iters=500,
)
main = CosineAnnealingLR(
    optimizer,
    T_max=9500,        # 余弦衰减阶段的 update 数;应扣除前面 500 个 warmup update
)
scheduler = SequentialLR(
    optimizer,
    schedulers=[warmup, main], # 先线性升温,再切到余弦退火
    milestones=[500],          # 第 500 次 scheduler.step() 后切到第二段调度器
)

如果 warmup 阶段希望学习率保持常数而非线性爬升,可把第一段换成 ConstantLR。核心原则不变:调度器的步数必须对齐 optimizer update,而非 micro-batch 数。

稳定性与效率机制
gradient accumulation

梯度累积(Gradient Accumulation)用多次 backward() 模拟更大的 batch:每个 micro-batch 只反向,不更新;累积到指定步数后再统一 optimizer.step()。工程上需要把 loss 除以累积步数,保证梯度尺度不被放大。

Python
1
2
3
4
5
6
7
8
9
10
11
12
grad_accum_steps = 8
optimizer.zero_grad(set_to_none=True)
 
for step, batch in enumerate(train_loader):
    loss = compute_loss(batch)
    # 必须先除累积步数,否则等价于把学习率放大了 grad_accum_steps 倍
    loss = loss / grad_accum_steps
    loss.backward()
 
    if (step + 1) % grad_accum_steps == 0:
        optimizer.step()                 # 只在累积满一个有效 batch 后更新一次参数
        optimizer.zero_grad(set_to_none=True)

在 DDP 中,默认每次 backward() 都会触发梯度同步;如果仍按上面的写法累积 micro-batch,会在前 \(N-1\) 次 micro-step 上白白做 all-reduce。更常见的工程写法是把同步推迟到最后一个 micro-step。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
grad_accum_steps = 8
optimizer.zero_grad(set_to_none=True)
 
for step, batch in enumerate(train_loader):
    # 只有最后一个 micro-step 才需要同步梯度并更新参数
    is_update_step = (step + 1) % grad_accum_steps == 0
 
    if is_update_step:
        loss = compute_loss(batch) / grad_accum_steps
        # 最后一个 micro-step 走正常 backward,DDP 会在这里执行 all-reduce
        loss.backward()
    else:
        with model.no_sync():  # 前几个 micro-step 只在本 rank 累积梯度,避免重复通信
            loss = compute_loss(batch) / grad_accum_steps
            loss.backward()
 
    if is_update_step:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

一些训练脚本会改写 model.require_backward_grad_sync 来达到相同目的。两种写法的工程语义一致:前几个 micro-step 只攒梯度,最后一步再同步。

mixed precision

混合精度(Automatic Mixed Precision, AMP)在前向与反向中对不同算子选择不同精度,提升吞吐并降低显存占用。PyTorch 推荐使用 torch.autocast 与 torch.amp.GradScaler 组合;旧的 torch.cuda.amp.autocast 已被标注为弃用入口,脚本应迁移到 torch.amp.autocast("cuda") 风格。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
 
# GradScaler 负责放大 loss,减少 fp16 下的小梯度下溢
scaler = torch.amp.GradScaler("cuda")
optimizer.zero_grad(set_to_none=True)
 
for batch in train_loader:
    with torch.amp.autocast("cuda", dtype=torch.float16):
        # autocast 让前向里的大部分算子自动选更省显存的精度执行
        loss = compute_loss(batch)
 
    # 先放大后的 loss 再反传,梯度更不容易在 fp16 中被截成 0
    scaler.scale(loss).backward()
 
    # 梯度裁剪前必须先还原真实梯度尺度,否则 max_norm 没意义
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
 
    scaler.step(optimizer)          # 如果本步检测到 inf/nan,GradScaler 会跳过这次更新
    scaler.update()                 # 根据本步是否溢出动态调整下一步的缩放因子
    optimizer.zero_grad(set_to_none=True)
clipping

梯度裁剪(Gradient Clipping)用于抑制梯度爆炸与异常尖峰更新。脚本里常用两种方式:按范数裁剪与按值裁剪。混合精度场景下,裁剪通常发生在 scaler.unscale_(optimizer) 之后、 scaler.step(optimizer) 之前。

训练看板里建议同时记录 grad_norm。它表示当前 update step 上所有可训练参数梯度的整体 L2 范数,用于判断“这一步模型准备更新多大”。PyTorch 的 torch.nn.utils.clip_grad_norm_ 会返回裁剪前的总梯度范数,因此可以直接把返回值写入 TensorBoard、W&B 或 MLflow。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
from torch.utils.tensorboard import SummaryWriter
 
writer = SummaryWriter(log_dir="runs/grad-norm-demo")
scaler = torch.amp.GradScaler("cuda")
max_grad_norm = 1.0
 
for global_step, batch in enumerate(train_loader):
    optimizer.zero_grad(set_to_none=True)
 
    with torch.amp.autocast("cuda", dtype=torch.float16):
        # loss 是当前 micro-batch 的优化目标,后续所有梯度都从它反传得到。
        loss = compute_loss(model, batch)
 
    # fp16 训练先放大 loss,减少小梯度在半精度里下溢成 0 的概率。
    scaler.scale(loss).backward()
 
    # 裁剪和日志记录都要基于真实梯度尺度,先撤销 GradScaler 的放大。
    scaler.unscale_(optimizer)
 
    total_grad_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(),
        max_norm=max_grad_norm,
        norm_type=2.0,
    )
    # clip_grad_norm_ 返回裁剪前总范数,可用来观察是否频繁触发裁剪。
    writer.add_scalar("train/grad_norm", float(total_grad_norm), global_step)
    writer.add_scalar("train/loss", float(loss.detach()), global_step)
    writer.add_scalar("train/lr", optimizer.param_groups[0]["lr"], global_step)
 
    # 若本步出现 inf/nan,GradScaler 会跳过 optimizer.step,避免污染权重。
    scaler.step(optimizer)
    # GradScaler 根据本步溢出情况调整下一步的 loss scale。
    scaler.update()

这段代码的关键点是顺序:先 backward,再 unscale_,接着计算并裁剪梯度范数,最后执行优化器更新。若在 unscale_ 之前记录范数,看到的是被 GradScaler 放大后的数值;若在 optimizer.step() 之后记录,梯度可能已经被清理或不再代表本次更新。

命令/API/函数
torch.nn.utils.clip_grad_norm_

说明
按整体范数裁剪梯度

示例

Python
1
2
3
4
5
torch.nn.utils.clip_grad_norm_(
    model.parameters(),
    max_norm=1.0,
    norm_type=2.0,
)

命令/API/函数
torch.nn.utils.clip_grad_value_

说明
按绝对值范围裁剪梯度

示例

Python
1
2
3
4
torch.nn.utils.clip_grad_value_(
    model.parameters(),
    clip_value=0.5,
)
训练状态管理
checkpoint

checkpoint 的工程目标是可恢复性与可追溯性。推荐保存 state_dict,而非直接 pickle 整个模型对象。一个可恢复 checkpoint 至少包含:

  • model.state_dict()
  • optimizer.state_dict()
  • scheduler.state_dict()(如果使用)
  • GradScaler.state_dict()(如果使用 AMP)
  • 当前 epoch、global_step、最佳指标与早停计数器
  • 必要时保存 RNG 状态(CPU/CUDA)以便复现实验
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
 
ckpt = {
    "epoch": epoch,
    "global_step": global_step,
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict() if scheduler else None,
    "scaler": scaler.state_dict() if scaler else None,
    "best_metric": best_metric,
    "patience": patience_counter,
    "rng_state": torch.get_rng_state(),
}
if torch.cuda.is_available():
    # 多卡恢复若要尽量复现,需要把每张卡对应的 RNG 状态一起带上。
    ckpt["cuda_rng_state_all"] = torch.cuda.get_rng_state_all()
 
torch.save(ckpt, ckpt_path)
resume

恢复训练(resume)要求脚本严格区分两种加载:

  • 只加载权重(用于推理或 warmstart):只读 model 的 state_dict。
  • 恢复训练:除了 model,还要恢复 optimizer/scheduler/scaler 与计数器。

跨设备恢复时,应使用 map_location 控制张量落点。对于大型权重文件,PyTorch 提供了 mmap 相关建议与加载技巧,可用于降低峰值内存。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
ckpt = torch.load(ckpt_path, map_location="cpu")
 
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
if scheduler and ckpt.get("scheduler") is not None:
    scheduler.load_state_dict(ckpt["scheduler"])
if scaler and ckpt.get("scaler") is not None:
    scaler.load_state_dict(ckpt["scaler"])
 
# 续训时从下一个 epoch 开始,避免重复训练已完成的那一轮。
start_epoch = ckpt["epoch"] + 1
# global_step 常用于恢复学习率调度、日志步数和 checkpoint 命名。
global_step = ckpt["global_step"]
# best_metric 决定 best checkpoint 与 early stopping 能否无缝接上。
best_metric = ckpt.get("best_metric", None)
early stopping

早停(Early Stopping)的正确写法是“基于业务真正关心的指标触发”。分类任务通常监控 F1/Accuracy;生成任务通常监控 ROUGE/BLEU 或下游任务指标;有些任务 loss 的上升代表校准变差但决策指标仍提升,因此脚本应把 monitor 指标显式参数化,而非写死为 val_loss。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
patience = 3
best = None
bad_epochs = 0
 
for epoch in range(num_epochs):
    train_one_epoch(...)
    # monitor 指标应该和业务目标一致,例如 F1、ROUGE 或 token accuracy。
    metric = evaluate(...)
 
    if best is None or metric > best:
        best = metric
        bad_epochs = 0
        save_best_checkpoint(...)
    else:
        bad_epochs += 1
        if bad_epochs >= patience:
            break
配置系统

训练脚本的配置系统需要解决两类问题:参数入口(CLI/环境变量)与配置结构(分层配置、默认值、校验)。通用格式(例如 YAML)只承担“配置文件承载体”的角色,真正的工程收益来自:覆盖语法、层级合并、以及把配置变成强类型对象。

命令行构建
argparse
Python
1
2
3
4
5
6
import argparse
 
parser = argparse.ArgumentParser()
parser.add_argument("--lr", type=float, default=2e-4)
parser.add_argument("--batch-size", type=int, default=8)
args = parser.parse_args()
Click
Python
1
2
3
4
5
6
7
8
9
10
import click
 
@click.command()
@click.option("--lr", type=float, default=2e-4)
@click.option("--batch-size", type=int, default=8)
def main(lr, batch_size):
    ...
 
if __name__ == "__main__":
    main()
Typer
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from typing import Annotated
import typer
 
app = typer.Typer()
 
@app.command()
def train(
    # 直接把 CLI 元数据写进类型标注,help/校验会自动生成。
    lr: Annotated[float, typer.Option()] = 2e-4,
    # Typer 会据此生成 --batch-size 选项并做类型转换。
    batch_size: Annotated[int, typer.Option()] = 8,
):
    ...
 
if __name__ == "__main__":
    app()
配置与模式校验
Hydra

Hydra 的核心价值是“分层配置 + 命令行覆盖 + 默认输出目录管理”。训练脚本里常把超参数、数据路径、模型结构与运行参数拆分成多个 config group,再通过 overrides 组合出一次实验。

Python
1
2
3
4
5
6
7
8
9
10
import hydra
from omegaconf import DictConfig
 
@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(cfg: DictConfig):
    # cfg.lr, cfg.train.batch_size, cfg.model.name, ...
    ...
 
if __name__ == "__main__":
    main()
OmegaConf
Python
1
2
3
4
5
from omegaconf import OmegaConf
 
cfg = OmegaConf.load("conf/config.yaml")
cfg = OmegaConf.merge(cfg, {"train": {"batch_size": 8}})
cfg_dict = OmegaConf.to_container(cfg, resolve=True)
Pydantic

Pydantic 的价值是把“松散字典配置”收敛为“可验证的强类型配置对象”,在脚本启动阶段就能把拼写错误与类型错误拒之门外。

Python
1
2
3
4
5
6
7
from pydantic import BaseModel, Field
 
class TrainConfig(BaseModel):
    lr: float = Field(default=2e-4, ge=0.0)
    batch_size: int = Field(default=8, ge=1)
 
cfg = TrainConfig(lr=2e-4, batch_size=8)
日志与可视化
TensorBoard

TensorBoard 的工程用法是“在训练循环中持续写入事件文件”,再用 TensorBoard UI 查询。PyTorch 提供 torch.utils.tensorboard.SummaryWriter 作为主入口。

Shell
1
2
pip install tensorboard
tensorboard --logdir runs

Python
1
2
3
4
5
6
from torch.utils.tensorboard import SummaryWriter
 
writer = SummaryWriter(log_dir="runs/exp-001")
writer.add_scalar("train/loss", loss.item(), global_step)
writer.add_scalar("train/lr", optimizer.param_groups[0]["lr"], global_step)
writer.flush()
实验管理与跟踪
Weights & Biases

W&B 的训练脚本集成围绕三个动作:init 建立 run,log 写入指标与超参,finish 结束 run。离线环境可使用 offline 模式把日志落盘后再同步。

Shell
1
2
pip install wandb
wandb login

Python
1
2
3
4
5
6
import wandb
 
run = wandb.init(project="exp", config={"lr": 2e-4, "batch_size": 8})
for step in range(100):
    wandb.log({"train/loss": float(loss), "train/lr": optimizer.param_groups[0]["lr"]}, step=step)
run.finish()
MLflow

MLflow Tracking 的核心是 run:在 run 上记录 params、metrics 与 artifacts。最小闭环是:设置 experiment,启动 run,上报指标,必要时启动本地 tracking server 查看 UI。

Shell
1
2
pip install mlflow
mlflow server --port 5000

Python
1
2
3
4
5
6
7
8
9
10
11
import mlflow
 
# experiment 是 MLflow UI 的第一层分组;先固定它,后续多个 run 才能按实验归档。
mlflow.set_experiment("exp")
with mlflow.start_run():
    # params 记录的是本轮训练配置;它们是后续筛选 run 的主要维度。
    mlflow.log_params({"lr": 2e-4, "batch_size": 8})
    # metric 要带 step,曲线才会按训练过程展开,而非只剩一个最终数值。
    mlflow.log_metric("train_loss", float(loss), step=global_step)
    # artifact 用来绑定 checkpoint、评估报告等二进制产物。
    mlflow.log_artifact("checkpoints/best.pt")
LLM 可观测性
Langfuse

Langfuse 在训练脚本中的典型价值是把“训练过程中的 LLM 调用、数据生成、评测调用”以 trace/span/generation 的方式串成可查询的链路,并与指标平台形成分工:W&B/MLflow 负责 run 级指标与产物,Langfuse 负责调用链与上下文。短生命周期脚本要显式 flush 或 shutdown,确保事件被发送。

Shell
1
pip install langfuse

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import os
from langfuse import get_client
 
os.environ["LANGFUSE_PUBLIC_KEY"] = "pk-lf-..."
os.environ["LANGFUSE_SECRET_KEY"] = "sk-lf-..."
os.environ["LANGFUSE_BASE_URL"] = "https://cloud.langfuse.com"
 
langfuse = get_client()
 
with langfuse.start_as_current_observation(as_type="span", name="train-step") as span:
    # 训练逻辑
    # metadata 最适合挂训练步数、样本批次或 checkpoint 版本这类排障字段。
    span.update(metadata={"global_step": global_step})
 
    with langfuse.start_as_current_observation(as_type="generation", name="synth-data", model="gpt-4.1") as gen:
        # LLM 生成数据/评测逻辑
        gen.update(output="...")
 
langfuse.flush()
OpenTelemetry / OpenInference:把 LLM 调用接入统一 tracing 体系

Langfuse 适合看 prompt、generation 与评测链路;如果系统里还同时存在 HTTP 服务、数据库、队列和检索链路,就需要把 LLM span 接到统一 tracing 体系里。OpenTelemetry 负责 trace/span/export 机制,OpenInference 与 GenAI 语义约定负责把“模型名、token 用量、工具调用、retrieval 命中”等字段标准化。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
 
provider = TracerProvider()  # 统一收集服务端 span;训练/评测脚本也可以复用同一套 provider
provider.add_span_processor(
    BatchSpanProcessor(
        # OTLP 是最常见的跨平台 trace 导出协议
        OTLPSpanExporter(endpoint="http://otel-collector:4318/v1/traces")
    )
)
trace.set_tracer_provider(provider)
tracer = trace.get_tracer(__name__)
 
with tracer.start_as_current_span("rag.answer") as span:
    # 把模型身份写成标准化属性,方便跨后端统一检索
    span.set_attribute("gen_ai.request.model", "Qwen/Qwen3-0.6B")
    # token 用量是容量规划与成本分析的关键维度
    span.set_attribute("gen_ai.usage.input_tokens", 512)
    # 检索链路参数也应进 span,便于回放与归因
    span.set_attribute("retrieval.top_k", 20)

真正做故障归因时,日志、指标和 trace 应共享同一条关联键,例如 trace_id。否则只能分别看到“慢”“贵”“错”,却很难知道它们是否来自同一条请求链路。

评估与基准
通用评估指标

训练脚本的评估模块需要满足两个工程要求:可重复(同一 checkpoint 评估一致)与可对比(同一指标口径跨 run 可比较)。分类任务的 Accuracy/F1、检索任务的 Recall@K/NDCG、生成任务的 ROUGE/BLEU 与任务自定义评分,应在脚本里拆成独立的 evaluate 函数,避免训练循环与评估逻辑互相污染。

中文评估指标
rouge-chinese

rouge-chinese 提供中文场景的 ROUGE 计算实现,针对中文标点分句与 ROUGE-L 内存占用做了工程优化。训练脚本中通常把它放在验证阶段,用于摘要、生成式问答等任务的离线评估。

Shell
1
pip install rouge-chinese

Python
1
2
3
4
5
6
7
8
9
from rouge_chinese import Rouge
 
rouge = Rouge()  # 在验证阶段复用同一个 Rouge 实例,避免每个 batch 反复初始化。
hyps = ["模型生成的摘要。"]  # 系统生成文本列表;库接口按“多条样本”设计。
refs = ["参考摘要。"]       # 参考答案列表;长度需要和 hyps 对齐。
 
# avg=True 让多样本结果先聚合,再统一写入日志。
scores = rouge.get_scores(hyps, refs, avg=True)
# scores["rouge-1"]["f"], scores["rouge-2"]["f"], scores["rouge-l"]["f"]
lm-evaluation-harness

lm-evaluation-harness 是“给定一个模型后,快速在标准任务集上跑出可复现分数”的轻量评测主线。它特别适合训练完成后的 checkpoint 筛选、不同推理后端的一致口径对比、以及把本地模型服务接入公开 benchmark。到 2025 年底,这个项目的 CLI 已重构为 run / ls / validate 子命令,并支持 YAML 配置;安装也按后端拆分为可选 extra,例如 lm_eval[hf]、 lm_eval[vllm]。这类拆分很重要,因为评测机往往只需要一个后端,没必要把整个推理生态全装进去。

命令/API/函数

pip install "lm_eval[hf]"

说明

只安装 Hugging Face backend 的评测依赖。若实际跑的是 vLLM、SGLang 或本地 OpenAI-compatible 服务,应改装对应 extra,避免环境体积和依赖冲突无谓膨胀。

示例

Shell
1
2
pip install "lm_eval[hf]"
pip install "lm_eval[vllm]"

命令/API/函数

lm_eval ls

说明

列出当前环境可见的任务、任务组或模型适配器。做 CI 或大规模批评测时,这一步相当于“环境探针”:可以先确认任务名是否变化,再决定是否启动整批评测。

示例

Shell
1
lm_eval ls

命令/API/函数

lm_eval run

说明

执行评测主入口。模型适配器由 --model 决定,任务集合由 --tasks 指定; --model_args 负责把权重路径、dtype、并行参数、服务地址等 backend 特有配置串起来。对于推理阶段已经单独部署好的超大模型,官方 README 直接建议通过 OpenAI-compatible 接口接入,例如先用 vLLM 挂服务,再让 harness 用 local-completions 类适配器做评测。

示例

Shell
1
2
3
4
5
6
7
lm_eval run \
  --model hf \
  --model_args pretrained=meta-llama/Meta-Llama-3-8B-Instruct,dtype=bfloat16 \
  --tasks hellaswag,arc_easy \
  --device cuda:0 \
  --batch_size auto \
  --output_path outputs/lm_eval/llama3_8b

命令/API/函数

think_end_token

说明

这是 2025 年新增的重要参数,用来裁掉 reasoning 模型显式输出的思维链尾标记。对“答案正确但带有推理痕迹”的模型,如果不做这层截断,评测器的答案提取往往会被污染,尤其是多项选择与短答案任务。

示例

Shell
1
2
3
4
5
lm_eval run \
  --model vllm \
  --model_args pretrained=/models/DeepSeek-R1-Distill-Qwen-7B,think_end_token="</think>" \
  --tasks gsm8k \
  --batch_size auto

命令/API/函数

Task YAML: doc_to_text / filter_list / metric_list

说明

lm-eval 的真正工程价值在任务 YAML。它把 prompt 模板、答案抽取、聚合与计分写成可版本化资产,而非散落在临时脚本里。对生成式任务,常见流程是“生成多个候选 -> 用 regex 或自定义 filter 抽取答案 -> 多数投票或取首个合法答案 -> 再按 metric 计分”。模型分数是否可信,往往取决于这层配置是否明确且可复刻。

示例

YAML
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
task: gsm8k_cot_local
dataset_path: gsm8k
test_split: test
doc_to_text: "{{question}}\nLet's think step by step."
doc_to_target: "{{answer}}"
filter_list:
  - name: extract_answer
    filter:
      - function: regex
        regex_pattern: "####\\s*([-0-9\\.]+)"
      - function: take_first
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true

命令/API/函数

--log_samples / --use_cache / --cache_requests

说明

这组参数决定评测是否适合工程回归。 --log_samples 让错误分析有据可查;缓存参数则避免重复打 API 或重复本地推理。对付费模型、长上下文任务或 nightly regression,这些选项通常应该默认开启,而非临时想起才加。

示例

Shell
1
2
3
4
5
6
7
8
lm_eval run \
  --model local-completions \
  --model_args base_url=http://localhost:8000/v1/completions,model=local-model \
  --tasks hellaswag \
  --log_samples \
  --use_cache ./cache/lm_eval.sqlite \
  --cache_requests true \
  --output_path ./outputs/lm_eval/regression
OpenCompass

OpenCompass 更接近“评测编排平台”而非单一命令行工具。它把一次评测拆成 Configure、Inference、Evaluation、Visualization 四个阶段,适合管理“多个模型 × 多个数据集 × 多种 judge/后处理”的评测矩阵。相比 lm-evaluation-harness,OpenCompass 的长处在于配置体系更完整、任务组织更重、对 LLM-as-judge、长上下文、污染检测、推理模型评测等场景覆盖更深;代价是上手复杂度更高,更像一个独立工程而非一个轻量脚本库。

近两年的版本演进也需要单独记住。OpenCompass 在 0.4.0 之后把不少旧式配置从仓库顶层 configs/ 目录迁入包内路径,很多旧文章里的配置引用在新版本里会直接失效。到 2026 年初,OpenCompass 又补入了 CascadeEvaluator、GenericLLMEvaluator、MATHVerifyEvaluator 等更偏“复杂 judge 流水线”的组件,已经不再只是传统选择题 benchmark 的跑分脚本。

命令/API/函数

python run.py config.py

说明

OpenCompass 的入口仍然是 run.py。配置文件同时声明模型、数据集、推理器、评测器与汇总策略;一次命令会自动拆出并行子任务,分别做推理和评测,再把结果汇总成表格、CSV 和 TXT。

示例

Shell
1
python run.py configs/eval_demo.py

命令/API/函数

python run.py ... -a vllm

说明

OpenCompass 可以把原本基于 Hugging Face 的模型配置自动切到 vLLM 或 LMDeploy 推理后端,用于加速大模型评测。这一点在长上下文、数学推理、批量生成型任务里价值很高,因为推理速度往往比 judge 逻辑本身更容易成为瓶颈。

示例

Shell
1
2
python run.py configs/eval_gsm8k.py -a vllm
python run.py configs/eval_gsm8k.py -a lmdeploy

命令/API/函数

GenericLLMEvaluator / CascadeEvaluator

说明

GenericLLMEvaluator 用于“规则难以完全覆盖”的 judge 场景,例如自由文本答案、复杂事实判断、开放式回应; CascadeEvaluator 则先跑规则评测器,再把规则无法稳定判定的样本交给 LLM judge。这样做的工程意义很明确:把昂贵的 LLM-as-judge 只用在真正模糊的样本上,评测成本和延迟会明显下降。

示例

Python
1
2
3
4
5
6
7
8
9
from opencompass.evaluator import CascadeEvaluator, GenericLLMEvaluator, MATHVerifyEvaluator
 
eval_cfg = dict(
    type=CascadeEvaluator,
    evaluators=[
        dict(type=MATHVerifyEvaluator),      # 先用规则或符号级校验筛掉能直接判分的样本
        dict(type=GenericLLMEvaluator),      # 再把剩余难例交给 LLM judge
    ],
)

命令/API/函数

--mode infer|eval|viz + --reuse

说明

OpenCompass 把一次评测拆成有状态实验目录。 infer 阶段最耗时;当只调整评测器、judge 模板或汇总逻辑时,用 eval 或 viz 配合 --reuse 复用历史输出,能省掉绝大部分推理成本。这是 OpenCompass 和轻量级单次跑分脚本最本质的差别之一。

示例

Shell
1
2
3
4
opencompass --models hf_internlm2_5_1_8b_chat \
  --datasets demo_gsm8k_chat_gen \
  --mode eval \
  --reuse latest

命令/API/函数

models / datasets / work_dir

说明

OpenCompass 的主配置是 Python 文件而非简单命令拼接。工程上最常维护的三个顶层对象是 models、 datasets 与 work_dir。这样做的意义是把“模型矩阵 × 数据集矩阵 × 推理/评估配置”直接版本化,便于团队共享与回滚。

示例

Python
1
2
3
4
5
6
7
from mmengine.config import read_base
 
with read_base():
    from opencompass.configs.datasets.demo.demo_gsm8k_chat_gen import datasets
    from opencompass.configs.models.hf_internlm.hf_internlm2_5_1_8b_chat import models
 
work_dir = 'outputs/my_eval'  # 所有推理产物、评估结果与汇总报表都挂在这个实验目录下。
数据与标注资产管理
DVC

DVC 把“数据/模型产物”从 Git 中分离出来,同时保留可追溯版本。训练脚本常配合 DVC 使用两条链路:数据版本管理(dvc add/pull/push)与流水线复现(dvc.yaml + dvc repro)。

Shell
1
2
3
4
5
6
# DVC 进入项是 CLI;装好后数据版本和流水线复现都从同一套命令面进入。
pip install dvc
dvc init
dvc add data/train.jsonl
git add data/train.jsonl.dvc data/.gitignore
git commit -m "track dataset with dvc"
Label Studio

Label Studio 是标注平台。训练脚本侧通常把它当作“数据生成与质量控制”的外部系统:标注阶段产出数据,训练阶段只消费导出的标注结果。最小可运行入口是安装并启动服务。

Shell
1
2
pip install label-studio
label-studio start
分布式训练与硬件加速组件

分布式训练与硬件加速的工程工作围绕四个入口展开:进程如何启动、通信如何建立、显存如何被切分与回收、关键算子是否落在高性能 kernel。本节以“能直接跑起来”的安装、启动、API、配置与部署约束为中心,覆盖 PyTorch distributed(DP/DDP/FSDP/torchrun)、DeepSpeed(ZeRO)、Megatron-LM/Megatron Core、CUDA/cuDNN/NCCL、Triton、FlashAttention、flash-linear-attention(fla)、xFormers、bitsandbytes,以及数值精度与重算策略。

设备与并行基础
计算设备

训练代码在多卡场景里通常遵循“一进程一 GPU”的约定:每个进程只绑定一个 GPU,并通过进程间通信完成梯度同步或参数分片。这一约定直接对应 torchrun/DDP/FSDP/DeepSpeed 的默认启动方式,也决定了日志、随机数与数据采样需要按 rank 做隔离。

并行策略

工程上最常用的并行拆分有两类:

  • 数据并行(Data Parallel, DP):每张卡持有一份模型副本,各自处理不同 batch,然后同步梯度。
  • 模型并行(Model Parallel):把一个模型拆到多张卡上。常见细分是张量并行(Tensor Parallel, TP)与流水并行(Pipeline Parallel, PP)。

对于 LLM 预训练,TP/PP 往往与数据并行同时存在;对多数微调任务,数据并行 + 参数高效微调(LoRA/QLoRA)是更常见的起点。

PyTorch 分布式主线
安装与环境校验

PyTorch 分布式训练的最低要求是:PyTorch 构建启用了 distributed,并且 GPU 通信后端可用(NVIDIA 场景通常是 NCCL)。工程上先做三类校验:CUDA 运行时可用性、distributed 模块可用性、以及当前进程可否正确枚举到 GPU。

Python
1
2
3
4
5
6
7
8
import torch
import torch.distributed as dist
print("torch:", torch.__version__)
print("cuda available:", torch.cuda.is_available())
print("torch cuda:", torch.version.cuda)
print("distributed available:", dist.is_available())
if torch.cuda.is_available():
    print("gpu0:", torch.cuda.get_device_name(0))
DataParallel (DP)

DataParallel 是单进程、多 GPU 的封装( torch.nn.DataParallel)。它的工程局限很明确:单进程会成为瓶颈、参数分散与通信控制不够细,且与现代分布式生态(torchrun/elastic/DCP)不在同一路线上。实际工程里更常把 DP 当作“快速验证多卡可跑”的临时方案,正式训练一般直接用 DDP 或 FSDP。

Python
1
2
3
4
5
import torch
import torch.nn as nn
 
model = nn.Linear(1024, 1024).cuda()
model = nn.DataParallel(model)  # 单进程,多卡
DDP

DDP(DistributedDataParallel)是 PyTorch 数据并行主线:多进程各自持有模型副本,反向传播时进行梯度 AllReduce。工程上,DDP 的三个稳定性入口是:进程组初始化、每个进程绑定本地 GPU、以及数据采样在 rank 之间的正确切分。

启动方式(torchrun)
Shell
1
2
# 单机 8 卡(每个进程绑定 1 张 GPU)
torchrun --standalone --nproc-per-node=8 train.py

Shell
1
2
3
4
5
6
7
8
# 多机(示例:2 台机器,每台 8 卡)
# node0:
# node-rank=0 表示主节点;master_addr/master_port 必须所有节点一致。
torchrun --nnodes=2 --node-rank=0 --nproc-per-node=8 --master_addr=$MASTER_ADDR --master_port=29500 train.py
 
# node1:
# node-rank=1 表示第二台节点;其余 rendezvous 参数保持完全一致。
torchrun --nnodes=2 --node-rank=1 --nproc-per-node=8 --master_addr=$MASTER_ADDR --master_port=29500 train.py
最小 DDP 训练骨架
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import os
import argparse
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def main():
    parser = argparse.ArgumentParser()
    # 同时兼容 torchrun 常见的两种 local rank 参数命名。
    parser.add_argument("--local-rank", "--local_rank", type=int, default=None)
    _ = parser.parse_args()
 
    dist.init_process_group(backend="nccl")  # NVIDIA GPU 通常用 NCCL
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    torch.cuda.set_device(local_rank)
    model = torch.nn.Linear(1024, 1024).cuda()
    model = DDP(model, device_ids=[local_rank])
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    x = torch.randn(8, 1024, device="cuda")
    y = model(x).sum()
    y.backward()
    opt.step()
    opt.zero_grad(set_to_none=True)
 
    dist.destroy_process_group()
if __name__ == "__main__":
    main()
数据切分(DistributedSampler)

DDP 下数据切分通常使用 torch.utils.data.distributed.DistributedSampler。它的工程意义是:每个 rank 只看见数据集的一个分片,并且在每个 epoch 以相同随机种子但不同偏移做 shuffle。训练循环里需要在每个 epoch 调用 sampler.set_epoch(epoch),否则多卡的 shuffle 行为容易退化为“每个 epoch 都是同一切分”。

Python
1
2
3
4
5
6
7
8
9
10
11
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
dataset = ...
sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True)
loader = DataLoader(dataset, batch_size=bs, sampler=sampler, num_workers=4, pin_memory=True)
for epoch in range(num_epochs):
    sampler.set_epoch(epoch)
    for batch in loader:
        ...
常见约束与调参入口
  • backend 选择:NVIDIA GPU 通常选 NCCL;CPU 场景通常用 Gloo(性能不同)。
  • find_unused_parameters:动态图/分支模型可能需要,但会引入开销;结构固定的训练尽量避免。
  • 梯度桶(bucket):DDP 会把梯度聚合成 bucket 做 AllReduce,bucket 大小与拓扑会影响吞吐与尾延迟。
FSDP

FSDP(Fully Sharded Data Parallel)把参数、梯度与优化器状态按 data-parallel rank 做分片,以显著降低“模型状态显存”。实践上分两条 API 主线:FSDP2(当前推荐)与 FSDP1(传统 wrapper 形态)。两者共享一个工程事实:优化器应在模型被分片之后创建,因为参数对象会被重映射。

最小 FSDP2 骨架(fully_shard)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard, FSDPModule
def main():
    dist.init_process_group("nccl")
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    torch.cuda.set_device(local_rank)
    model = Transformer()  # 伪代码:包含 model.layers
    for layer in model.layers:
        # 先按 block 粒度分片,避免整模型一次性 all-gather 的峰值过高。
        fully_shard(layer)
    # 最外层再包一次,让剩余未包裹参数也进入 FSDP 管理。
    fully_shard(model)
    assert isinstance(model, FSDPModule)
    # 优化器必须在 fully_shard 之后创建,否则拿到的仍是分片前参数引用。
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    ...
最小 FSDP1 骨架(FullyShardedDataParallel)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
def main():
    dist.init_process_group("nccl")
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    torch.cuda.set_device(local_rank)
    model = torch.nn.Linear(1024, 1024).cuda()
    model = FSDP(model)
 
    # 注意:optimizer 要在 FSDP wrap 之后创建
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    x = torch.randn(8, 1024, device="cuda")
    loss = model(x).sum()
    loss.backward()
    opt.step()
    opt.zero_grad(set_to_none=True)
 
    dist.destroy_process_group()
if __name__ == "__main__":
    main()
auto wrap 与分片策略

分片边界直接决定通信形态:对极小模块做分片会导致频繁 all-gather/reduce-scatter,吞吐明显下降。实践上常把分片边界放在较大的 Transformer block 级别,并在框架侧配合 activation checkpointing 来降低激活占用。

torchrun

torchrun 是 PyTorch 提供的分布式启动器,等价于 python -m torch.distributed.run。它负责为每个进程注入 rank/world_size/local_rank 等环境变量,并管理多机训练的 rendezvous。对于 GPU 训练,torchrun 的默认模型是“每进程一 GPU”。

常用启动参数
参数 含义 示例
--nproc-per-node 每台机器启动的进程数(GPU 训练通常等于每机 GPU 数)
Shell
1
torchrun --standalone --nproc-per-node=8 train.py
--nnodes / --node-rank 多机训练的节点数量与当前节点序号
Shell
1
torchrun --nnodes=2 --node-rank=0 --nproc-per-node=8 --master_addr=$MASTER --master_port=29500 train.py
--standalone 单机训练使用本地 rendezvous,省去显式配置
Shell
1
torchrun --standalone --nproc-per-node=8 train.py
launcher

在更大规模场景里,torchrun 之外通常还会叠一层集群 launcher(例如 Slurm 的 srun,或 K8s job controller),负责资源分配与节点编排。工程边界一般是:launcher 负责“分配哪些机器/卡”,torchrun 负责“每台机器上起哪些进程并建立通信”。

大模型分布式系统
DeepSpeed

DeepSpeed 把“大模型训练需要的显存管理与并行策略”产品化:通过 deepspeed.initialize 与一个配置文件,让训练脚本在不重写大量底层逻辑的情况下获得 ZeRO、offload、优化器与调度能力。它的关键工程入口是:安装、配置文件、启动命令与与现有训练循环的接入点。

DeepSpeed 是 Microsoft 开源的大模型训练系统,主线生态围绕 PyTorch、CUDA/NCCL、Transformers/Accelerate 和 Azure/HPC 场景展开。它在中文大模型训练项目里非常常见,但来源和维护主体属于海外开源训练系统;和 MindSpore/MindSpeed、PaddlePaddle/PaddleNLP、OneFlow 这类国产训练栈应分开理解。

ZeRO 是什么

ZeRO(Zero Redundancy Optimizer)是 DeepSpeed 最核心的显存优化机制。传统数据并行(Data Parallelism)里,每张 GPU 都完整保存一份模型参数、梯度和优化器状态;ZeRO 的做法是把这些原本重复保存的训练状态切成 shard,分散到不同 data-parallel rank 上。这样,每张 GPU 只保留自己负责的那一片,需要计算时再通过通信临时收集所需状态。

用 AdamW 训练一个参数量为 \(P\) 的模型时,训练态显存里至少会出现三类模型状态:参数 \(\theta\)、梯度 \(\nabla\theta\)、优化器状态 \(m,v\)(以及很多实现里的 FP32 master weights)。普通数据并行会在每张 GPU 上复制这些状态;若 data-parallel 规模为 \(N\),ZeRO 的目标就是把可分片状态的单卡占用从“接近完整一份”压到“约 \(1/N\) 份”。

ZeRO 阶段 分片对象 工程含义
Stage 1 优化器状态 先切 Adam 的一阶/二阶动量、FP32 master weights 等最占空间的 optimizer state,训练循环改动最小。
Stage 2 优化器状态 + 梯度 进一步切梯度,通常是中等规模大模型训练和微调的常见起点,显存收益明显,通信复杂度仍可控。
Stage 3 优化器状态 + 梯度 + 参数 把模型参数本身也切开,单卡显存最省;前向/反向期间需要频繁 all-gather / repartition,通信、checkpoint 和调试复杂度最高。

ZeRO 与 FSDP 的目标相近,都是减少数据并行里的重复模型状态。差别主要在工程入口:ZeRO 通常通过 DeepSpeed 配置文件和 DeepSpeedEngine 接管训练循环;FSDP 属于 PyTorch 原生分片数据并行,通常通过 wrapper 或 FSDP2 的 fully_shard 接入。实际选型更多取决于团队已有训练栈、checkpoint 体系、offload 需求和框架集成方式。

安装
Shell
1
pip install deepspeed
启动
Shell
1
2
# 典型用法:用 deepspeed 作为 launcher
deepspeed --num_gpus=8 train.py --deepspeed_config ds_config.json
最小配置(ds_config.json)

DeepSpeed 的工程事实是“配置驱动”:显存分片、offload、通信重叠与一些优化器实现由 JSON 配置决定。下列示例是可落地的起点,常见改动集中在 ZeRO stage 与 offload 选项。

ds_config.json
JSON
1
2
3
4
5
6
7
8
9
10
{
  "train_micro_batch_size_per_gpu": 1,
  "gradient_accumulation_steps": 8,
  "fp16": { "enabled": true },
  "zero_optimization": {
    "stage": 2,
    "overlap_comm": true,
    "contiguous_gradients": true
  }
}
ZeRO-3 / Offload 配置骨架
ds_config_zero3.json
JSON
1
2
3
4
5
6
7
8
9
10
{
  "train_micro_batch_size_per_gpu": 1,
  "gradient_accumulation_steps": 8,
  "bf16": { "enabled": true },
  "zero_optimization": {
    "stage": 3,
    "offload_param": { "device": "cpu", "pin_memory": true },
    "offload_optimizer": { "device": "cpu", "pin_memory": true }
  }
}
Python 接入(deepspeed.initialize)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import deepspeed
model = ...
# 把参数迭代器单独拿出来,是为了让 initialize 能接管优化器与 ZeRO 分片。
params = model.parameters()
engine, optimizer, _, lr_scheduler = deepspeed.initialize(
    model=model,
    model_parameters=params,
    config="ds_config.json",
)
for batch in dataloader:
    # 这里假设模型 forward 直接返回 loss;真实工程里也常写成 outputs.loss。
    loss = engine(batch)
    # backward/step 必须走 engine,才能正确触发 ZeRO、AMP 和梯度累积逻辑。
    engine.backward(loss)
    engine.step()
Megatron

Megatron-LM 是面向 Transformer 预训练的参考实现体系,内置张量并行(TP)、流水并行(PP)、以及与 NVIDIA 生态加速库的集成。它更像“训练系统工程模板”:直接复用仓库里的预训练脚本与并行参数,然后在其上叠加数据、模型结构与实验约束。

安装入口(推荐容器路径)

Megatron-LM/Megatron Core 对 CUDA、PyTorch、Transformer Engine、通信库的版本组合敏感。工程上通常优先使用 NGC 的 PyTorch 容器作为基线,再在容器内安装/开发 Megatron 相关代码,减少 ABI 与编译链不一致带来的问题。

Shell
1
2
3
4
5
6
7
# 直接进入官方 NGC PyTorch 容器,把 CUDA、cuDNN 与编译链基线固定住。
docker run --runtime=nvidia --gpus all -it --rm \
  -v /path/to/megatron:/workspace/megatron \
  -v /path/to/dataset:/workspace/dataset \
  -v /path/to/checkpoints:/workspace/checkpoints \
  -e PIP_CONSTRAINT= \
  nvcr.io/nvidia/pytorch:25.04-py3
启动模式(示意)

Megatron-LM 的典型启动方式仍是 torchrun,但会显式配置 TP/PP 规模,并把全局 batch 拆成 micro-batch + accumulation。下列命令展示“最少参数框架”,具体模型/数据参数由脚本与配置决定。

Shell
1
2
3
4
5
6
7
8
9
# 用 torchrun 启动分布式训练。
torchrun --nproc-per-node=8 pretrain_gpt.py \
  --tensor-model-parallel-size 2 \
  --pipeline-model-parallel-size 2 \
  --micro-batch-size 1 \
  --global-batch-size 128 \
  --sequence-length 4096 \
  --train-iters 1000 \
  ...
Megatron Core

Megatron Core 是可组合库形态:把训练大型 Transformer 所需的关键模块与系统优化能力封装成 API,供自定义训练框架调用。它提供 pip 安装与示例训练循环,工程上适合“需要 Megatron 的并行与算子能力,但不想完全使用 Megatron-LM 全栈脚本”的团队。

Shell
1
2
uv pip install megatron-core
torchrun --nproc-per-node=2 examples/run_simple_mcore_train_loop.py
NeMo

NeMo 是 NVIDIA 把“大模型训练配方、集群启动、Megatron Core 并行策略、checkpoint 管理”打包成体系化工作流后的结果。它更适合“训练系统本身就是长期资产”的团队:数据准备、预训练、继续预训练、SFT、PEFT、恢复训练、导出部署都沿着同一套 recipe 与 launcher 组织,而非靠零散脚本拼起来。2025 年之后的文档主线已经明显转向 NeMo 2.0 / AutoModel / NeMo-Run:本地工作站、Slurm、Kubernetes、Docker、SkyPilot 这些执行后端都被统一到 launcher 抽象里。

NeMo 当前最值得单独理解的两个点是 recipe 与 distributed checkpoint。recipe 负责把模型、数据、训练器、并行度、日志与恢复策略写成一份配置;distributed checkpoint 则允许在不同 TP/PP 规模之间恢复训练,这一点在“白天 8 卡调通、夜里 64 卡正式跑”或“先做 PEFT、再切 full finetune”时非常实用。

命令/API/函数

automodel config.yaml

说明

NeMo AutoModel 官方把 CLI 作为首选入口。它要求 YAML 里包含 recipe._target_,并用统一命令同时覆盖单卡、多卡、以及后续的集群扩展。对日常微调来说,这比直接记住底层 Python 脚本路径更稳,因为 recipe 升级后 CLI 兼容面通常更好。

示例

Shell
1
2
automodel examples/llm_finetune/llama3_2/llama3_2_1b_squad.yaml
automodel --nproc-per-node 2 examples/llm_finetune/llama3_2/llama3_2_1b_squad.yaml

命令/API/函数

run.run(...) / NEMORUN_HOME

说明

NeMo-Run 负责把同一份训练 recipe 投递到不同执行后端。它支持 local、Docker、Slurm、Kubernetes 等执行器,因此特别适合“本地先跑通,再推到集群”的工作流。和裸写 sbatch 相比,这层抽象把代码打包、环境镜像、实验元数据目录与远端执行方式收敛成一套配置面;实验元数据默认落在 ~/.run,可通过 NEMORUN_HOME 改写。

示例

Python
1
2
3
4
5
6
import nemo_run as run
 
task = ...  # recipe 在别处配置好;这里不重复混入模型细节。
# 本地调试先走 LocalExecutor;切集群时再换 Slurm/K8s/Docker executor。
executor = run.LocalExecutor()
run.run(task, executor=executor)  # 运行记录会写进 NEMORUN_HOME,方便回溯和恢复。

命令/API/函数

distributed_fused_adam

说明

这是 NeMo 的 distributed optimizer 入口。它把 Adam 的优化器状态与 master parameters 在 data-parallel 组内分片,解决“大模型还没算起来,优化器状态先把显存吃满”的常见问题。对预训练和 full finetune,这个开关通常比再抠一点 activation 更早见效。

示例

YAML
1
2
3
4
model:
  optim:
    # 让优化器状态在 data-parallel 组内分片,而非每卡都完整复制。
    name: distributed_fused_adam

命令/API/函数

distributed checkpoint

说明

NeMo 文档明确支持用不同张量并行与流水并行规模恢复训练。工程含义是:checkpoint 不再死绑某一套并行拓扑,集群资源变化或实验阶段切换时更容易继续跑。对于多节点预训练,这通常比单文件权重更重要,因为真正昂贵的是“恢复后的继续迭代能力”,而非一次性导出。NeMo 1.x 常见的是 .nemo 归档;NeMo 2.x 则更强调分布式 checkpoint 格式与并行 save/load。

示例

Python
1
2
3
4
import nemo.collections.asr as nemo_asr
 
# 经典入口:从 .nemo 归档恢复“权重 + 配置”。
model = nemo_asr.models.EncDecCTCModel.restore_from("asr.nemo")

1
2
3
4
典型迁移流程:
1. 用小规模 TP/PP 在开发环境调通 recipe。
2. 保存 sharded checkpoint。
3. 夜间切到更大的 TP/PP 或更多节点继续训练。
MindSpeed

MindSpeed 是面向昇腾(Ascend)训练栈的大模型训练加速组件,常和 MindSpore、MindFormers、CANN、以及 Megatron 风格的并行训练脚本一起出现。它的核心价值在于把大模型训练里高频的并行、显存、通信和算子适配工作收敛到 Ascend 生态内,减少从 NVIDIA/CUDA 训练脚本迁移到 NPU 集群时的系统改造成本。

选 MindSpeed 的首要条件是硬件环境。目标集群以 Ascend NPU 为主时,MindSpeed 负责承接大模型训练里的加速与适配;目标集群以 NVIDIA GPU 为主时,DeepSpeed、Megatron、FSDP、NeMo 通常是摩擦更小的路线。这个边界非常关键,因为分布式训练系统的很多能力最终受通信库、编译器、设备运行时和算子覆盖度约束。

安装与版本边界

MindSpeed 的安装需要和 Ascend 驱动、CANN、MindSpore 或 PyTorch NPU 版本矩阵对齐。真实项目里应先固定设备驱动与 CANN 版本,再选择对应的 MindSpeed / MindFormers / torch_npu 组合。

Shell
1
2
3
4
# 只展示源码安装入口;生产环境必须按 Ascend/CANN/框架版本矩阵固定分支。
git clone https://gitee.com/ascend/MindSpeed.git
cd MindSpeed
pip install -e .
Megatron 风格脚本的适配入口

MindSpeed 文档和示例常见的工程形态是:保留 Megatron 风格训练脚本的外形,在启动脚本中引入 MindSpeed 的 adaptor 或训练参数,让底层并行、通信和优化逻辑切到 Ascend 适配实现。下面是概念骨架,具体导入路径与参数需要以当前版本文档为准。

Python
1
2
3
4
5
6
7
8
# adaptor 的作用是把 Megatron 风格训练脚本接入 Ascend/MindSpeed 适配层。
# 真实项目中应以当前 MindSpeed 版本的示例脚本为准。
import mindspeed.megatron_adaptor
 
from pretrain_gpt import train
 
# train 仍然沿用 Megatron 风格入口;硬件适配由 adaptor 和启动参数接管。
train()
ColossalAI

ColossalAI 的工程定位和 DeepSpeed、Megatron、FSDP 都不完全相同。它更像一层“并行与显存优化注入器”:保留普通 PyTorch 训练循环的外形,再通过 Booster + Plugin 把数据并行、ZeRO、Gemini、Hybrid Parallel、混合精度等能力装进去。对于已经有自己训练脚本、又不想完全改写成另一套框架的人,这种接入方式很有吸引力。

官方文档把插件的适用区间讲得很清楚:Torch DDP 更适合小模型;Torch FSDP 与 LowLevelZeroPlugin 适合中等规模;GeminiPlugin 面向更大的模型与异构内存管理;HybridParallelPlugin 则面向超大模型或长序列场景,把 TP、PP、DDP/ZeRO、Shardformer 与 pipeline manager 统一起来。这个“按插件选路径”的设计,是 ColossalAI 最值得掌握的入口。

命令/API/函数

Booster(plugin=...)

说明

Booster 是当前主线 API,用来接管模型、优化器、criterion、dataloader 与 lr scheduler。它替代了旧时代的 colossalai.initialize 思路,把“并行策略选择”前置成一个显式对象,训练循环本身则尽量保持普通 PyTorch 形态。

示例

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
 
plugin = GeminiPlugin()  # 这里选 Gemini,是因为它直接提供 Zero-3 + chunk 化异构内存管理。
# bf16 是大模型训练更常见的稳定精度起点。
booster = Booster(mixed_precision='bf16', plugin=plugin)
 
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
    model=model,                 # 模型在这里被包进并行/显存优化层,而非手工逐段改写。
    optimizer=optimizer,         # 优化器也一起接管,避免参数分片后状态不同步。
    # loss 计算入口保留,但反向传播改由 booster.backward 管理。
    criterion=criterion,
    dataloader=dataloader,       # 某些插件会顺手处理 sampler / device placement。
    lr_scheduler=lr_scheduler,   # 调度器一并纳入,减少“模型被改写后调度器失配”的风险。
)
 
for batch in dataloader:
    outputs = model(**batch)
    loss = outputs.loss
    booster.backward(loss, optimizer)  # 这一层统一处理 AMP、梯度缩放与分布式同步细节。
    optimizer.step()
    optimizer.zero_grad()

命令/API/函数

GeminiPlugin

说明

GeminiPlugin 封装的是 chunk-based、heterogeneous memory management 风格的 Zero-3 路线。它适合显存已经逼近边界、但又不想立刻把整个项目迁到更重训练系统里的场景。官方推荐区间是 10B 以上模型和中小规模集群,这与它对跨节点带宽的要求相匹配。

示例

Python
1
2
3
4
5
6
7
from colossalai.booster.plugin import GeminiPlugin
 
plugin = GeminiPlugin(
    precision='bf16',   # 让参数与计算都走 bf16,兼顾吞吐和数值稳定性。
    # 由插件自动决定张量放在 GPU 还是 host memory,减少手调负担。
    placement_policy='auto',
)

命令/API/函数

HybridParallelPlugin

说明

HybridParallelPlugin 面向“需要 TP + PP + DP/ZeRO 组合”的超大模型训练。它把 Shardformer、pipeline manager、mixed precision 与并行策略绑成一体,适合超长序列、大词表或 60B 以上模型。这个插件的价值不在于多一层包装,而在于它把原本彼此独立的并行配置收拢成一个对象,减少 TP/PP/Zero 参数彼此打架的概率。

示例

Python
1
2
3
4
5
6
7
8
from colossalai.booster.plugin import HybridParallelPlugin
 
plugin = HybridParallelPlugin(
    tp_size=2,                 # 张量并行切两份,先解决单卡放不下的问题。
    pp_size=2,                 # 流水并行切两段,降低单卡激活与参数峰值。
    zero_stage=1,              # 在数据并行维度继续压缩优化器状态。
    precision='bf16',          # 混合精度由插件统一接管,避免和外层 AMP 重复配置。
)

命令/API/函数

colossalai.launch_from_torch() / colossalai run

说明

这是 ColossalAI 最低成本的接入路径。训练脚本里调用 launch_from_torch() 读取 rank/world_size 等环境变量;命令行则用 colossalai run 或 torchrun 起多进程。这样可以先把训练 loop 保持不变,再逐步引入 Booster 与 Plugin,而非一开始就全面重构。

示例

Python
1
2
3
import colossalai
 
colossalai.launch_from_torch()  # 从 launcher 注入的环境变量建立默认进程组。

Shell
1
2
colossalai run --nproc_per_node 4 train.py
colossalai run --nproc_per_node 4 --hostfile ./hostfile --master_addr host1 train.py

命令/API/函数

booster.save_model / booster.load_model

说明

Booster 不只接管训练,也接管 checkpoint I/O。对大模型训练,这一点非常重要,因为分片策略、并行包装与 safetensors 格式都会影响保存和恢复路径。 shard=True 可以直接写成 Hugging Face 风格分片目录; low_cpu_mem_mode 则是在恢复阶段用更低 CPU 内存换取更慢加载。

示例

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
save_dir = "./ckpt"
 
booster.save_model(
    model,
    checkpoint=save_dir,
    shard=True,               # 写出分片目录,便于大模型落盘与迁移。
    size_per_shard=1024,      # 每个 shard 上限 1024 MB,减少单文件过大带来的 I/O 问题。
    use_safetensors=True,     # 用 safetensors 提高安全性与加载稳定性。
)
 
booster.load_model(
    model,
    checkpoint=save_dir,
    low_cpu_mem_mode=True,    # 恢复时优先压低 CPU 峰值内存,适合大模型重载。
)
大模型分布式系统怎么选
系统 自己的特色 优先选择场景
DeepSpeed ZeRO、Offload、DeepSpeedEngine、配置驱动训练循环,和 Transformers / Accelerate / OpenRLHF 生态集成成熟。 显存主要卡在 optimizer/gradient/parameter state,项目已经基于 PyTorch/HF,或需要 RLHF 框架直接复用 ZeRO 配置。
Megatron-LM / Megatron Core TP、PP、context parallel、MoE 与 NVIDIA 大模型预训练范式结合紧,适合从训练脚本层面控制并行策略。 从零预训练或继续预训练超大 Transformer,团队能维护复杂训练脚本和并行拓扑。
NeMo 把 recipe、launcher、分布式 checkpoint、Megatron Core 与集群执行器组合成体系化工作流。 训练系统要长期维护,需要从本地调试平滑迁移到 Slurm/Kubernetes/多节点集群。
ColossalAI Booster + Plugin 接入方式清晰,可在普通 PyTorch loop 外围注入 ZeRO、Gemini、Hybrid Parallel。 已有自定义训练循环,希望保留代码外形,同时加入并行与显存优化。
MindSpeed 面向 Ascend/CANN/NPU 的大模型训练适配与加速,常用于昇腾集群上的 Megatron 风格训练迁移。 目标硬件是 Ascend NPU,需要和 MindSpore / MindFormers / torch_npu / CANN 生态协同。
PyTorch FSDP PyTorch 原生分片数据并行,接入点在模型 wrapper 或 FSDP2 的 sharding API,减少外部框架依赖。 希望保持 PyTorch-native 栈,配合 Accelerate/Trainer 或自定义训练循环做参数分片。
分片与状态管理
ZeRO

ZeRO(Zero Redundancy Optimizer)的核心思想是:把数据并行中本来每卡都复制一份的三类状态(优化器状态、梯度、参数)分片到不同 rank,从而把“模型状态显存”从 O(N) 降到 O(N/world_size)。DeepSpeed 的 ZeRO Stage 1/2/3 分别对应分片优化器状态、再分片梯度、再分片参数;Stage 3 需要在前向/反向时做参数聚合与再分片。

参数分片

参数分片常见两条路线:DeepSpeed ZeRO-3 与 PyTorch FSDP。两者目标一致,但接入点与约束不同;工程选型通常取决于现有训练栈(Transformers/Accelerate 生态 vs 自定义训练框架)、offload 需求,以及 checkpoint 与集群拓扑迁移的要求。

优化器状态分片

优化器状态(例如 Adam 的一阶/二阶动量)往往占据巨大的显存/内存。ZeRO-1/2 对这部分的分片收益很直接;当显存仍不足时,DeepSpeed 还支持把状态 offload 到 CPU/NVMe(ZeRO-Offload/ZeRO-Infinity),但带宽会成为新的瓶颈,必须依赖重叠与流水化来降低代价。

CUDA 软件栈
CUDA

CUDA 版本与驱动版本的不匹配是训练系统最常见的部署故障源。最小的工程实践是区分两个事实:nvidia-smi 反映的是驱动能力,nvcc 反映的是 toolkit;二者不一致并不必然是错误,但 toolkit 版本若高于驱动支持上限就无法正常工作。部署时以 NVIDIA 的 CUDA Compatibility 文档为准,并用 PyTorch 的 torch.version.cuda 与运行时实际 driver 做交叉验证。

Shell
1
2
nvidia-smi
nvcc --version

Python
1
2
3
import torch
print(torch.version.cuda)
print(torch.cuda.get_device_name(0))
cuDNN

cuDNN 是卷积、归一化、注意力等基础算子的关键实现来源之一。训练部署阶段更重要的工作是保证:驱动、CUDA toolkit、cuDNN 版本与 GPU 架构落在官方支持矩阵内,并与 PyTorch 及扩展库(FlashAttention、xFormers、bitsandbytes)的编译参数保持一致。

NCCL

NCCL 是 NVIDIA GPU 场景下最常用的分布式通信后端。大规模训练里,通信问题往往表现为:hang、极慢、或者跨机带宽只有理论值的一小部分。排障的第一入口是 NCCL 环境变量日志与网络接口选择。

NCCL 常用环境变量
变量 作用 示例
NCCL_DEBUG 开启 NCCL 日志(INFO/WARN)
Shell
1
export NCCL_DEBUG=INFO
NCCL_DEBUG_SUBSYS 按子系统过滤 NCCL_DEBUG 输出
Shell
1
export NCCL_DEBUG_SUBSYS=INIT,NET
NCCL_SOCKET_IFNAME 指定/过滤用于通信的网卡接口(支持 include/exclude 语法)
Shell
1
2
3
export NCCL_SOCKET_IFNAME=eth
export NCCL_SOCKET_IFNAME==eth0,eth1
export NCCL_SOCKET_IFNAME=^docker
NCCL_IB_DISABLE 显式禁用 InfiniBand(在 IB 配置不完整时可用于快速隔离问题)
Shell
1
export NCCL_IB_DISABLE=1
NCCL_P2P_DISABLE 禁用 GPU P2P(用于排查 P2P/拓扑相关问题,性能通常会下降)
Shell
1
export NCCL_P2P_DISABLE=1
kernel 与算子级优化
PyTorch SDPA(内置注意力后端选择)

在不引入额外依赖的情况下,优先使用 PyTorch 的 torch.nn.functional.scaled_dot_product_attention。它会在支持时选择更高性能的 attention 后端(例如 FlashAttention / memory-efficient / cuDNN / math 实现),并将“后端差异”收敛到同一 API 上。后端选择也可通过上下文管理器显式控制。

Python
1
2
3
4
5
6
7
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend
 
# 显式指定后端有助于调试“为什么没有走 FlashAttention”这类性能问题。
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
Triton

Triton 是面向 GPU kernel 的 Python DSL:通过 @triton.jit 与 triton.language(tl.*)API,把常见内核模式写成可编译的 Python 函数。工程上 Triton 常作为“自定义融合算子”的落地点:当 PyTorch 原生算子组合产生大量中间张量或访存瓶颈时,用 Triton 把多步计算融合成一个 kernel。

安装
Shell
1
pip install triton
最小 Triton kernel 骨架
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import triton
import triton.language as tl
 
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements: tl.constexpr, BLOCK: tl.constexpr):
    # pid 表示当前 program 实例编号;它决定本次 kernel 处理输入向量的哪一段。
    pid = tl.program_id(axis=0)
    # offs 是本 block 对应的全局元素下标。
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    # 最后一个 block 往往不满;mask 防止尾部越界 load/store。
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    y = tl.load(y_ptr + offs, mask=mask, other=0.0)
    # Triton 的核心价值是把“读 x + 读 y + 写 out”融合成单个 kernel。
    tl.store(out_ptr + offs, x + y, mask=mask)
FlashAttention

FlashAttention 是“精确 softmax attention”的高性能实现:通过 IO-aware 的分块与融合,把注意力的访存与中间张量开销显著压低。工程上常见三条接入路径:

  • 直接使用 PyTorch SDPA,让 PyTorch 在运行时选择 FlashAttention 后端(不额外引入 Python 包)。
  • 引入 flash-attn 包,显式调用算子。
  • 通过 xFormers 的 attention ops 或上层框架开关间接启用。

显式安装 FlashAttention 涉及 CUDA 扩展编译,部署约束主要集中在:CUDA toolkit、GPU 架构与编译工具链一致性。

安装
Shell
1
pip install flash-attn
  • 编译型依赖:安装过程可能会编译 CUDA 扩展,通常需要可用的 CUDA toolkit、以及可工作的编译链(例如 ninja)。
  • 版本约束:FlashAttention 的不同分支/包对 PyTorch 与 CUDA 版本有明确要求;环境固定时以官方 README 的支持矩阵为准。
  • 平台差异:Linux 是最常见的稳定路径;Windows/非常规组合通常更容易落到源码编译与 ABI 问题上。
使用(算子级接入)
Python
1
2
3
4
from flash_attn import flash_attn_func
 
# q,k,v: (batch, seqlen, nheads, headdim) 等布局依赖具体函数签名
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
flash-linear-attention(fla)

flash-linear-attention(常见简称 fla)提供线性注意力与相关模块的高性能实现,核心依赖是 PyTorch 与 Triton。工程上它更像一个“可插拔的层/算子库”:只有当模型架构实际使用了这些层(例如某些线性注意力/SSM/hybrid 模块)时,训练与推理才会受益。

安装
Shell
1
2
3
4
pip install flash-linear-attention
 
# 仅安装核心 kernel/ops(更轻依赖)
pip install fla-core
升级约束
Shell
1
2
3
# 升级前,先卸载两个包避免版本冲突
pip uninstall fla-core flash-linear-attention -y
pip install -U flash-linear-attention fla-core
其他加速组件
xFormers

xFormers 提供一组可组合的 Transformer 组件与优化算子。其中最常见的工程入口是 memory-efficient attention:通过统一接口选择不同高性能后端。安装上通常优先用预编译 wheel;当 PyTorch 版本或 CUDA 组合偏离主流时,才会退回到从源码编译。

Shell
1
pip install xformers

Python
1
2
3
4
from xformers.ops import memory_efficient_attention
 
# q,k,v 的形状与布局取决于 xFormers 版本与具体 backend
out = memory_efficient_attention(q, k, v)
bitsandbytes

bitsandbytes 提供低比特量化算子与 8-bit/4-bit 训练组件,常见用途是:QLoRA 的 4-bit Linear 层,以及 8-bit Adam 优化器状态以降低显存/内存占用。工程上它的关键点是:安装与平台兼容(CUDA/ROCm/CPU 路径)、以及把模型中的 Linear/Embedding 替换为 bnb 对应模块。

Shell
1
pip install bitsandbytes
常用模块速查

命令/API/函数
bitsandbytes.nn.Linear4bit

说明
QLoRA 4-bit Linear

示例

Python
1
2
3
4
5
6
7
8
import torch.nn as nn
from bitsandbytes.nn import Linear4bit
# 先准备一层普通全精度 Linear,模拟“已有 FP16/FP32 权重如何迁到 4bit 模块”。
fp16_model = nn.Linear(64, 64)
q_model = Linear4bit(64, 64)
# load_state_dict 只负责搬运权重数值;量化通常在模块迁移到 CUDA 时真正发生。
q_model.load_state_dict(fp16_model.state_dict())
q_model = q_model.to(0)  # 量化通常在 .to("cuda") 触发

命令/API/函数
bitsandbytes.nn.Linear8bitLt

说明
8-bit Linear

示例

Python
1
2
from bitsandbytes.nn import Linear8bitLt
layer = Linear8bitLt(4096, 4096).to(0)

命令/API/函数
bitsandbytes.optim.Adam8bit

说明
8-bit Adam 优化器

示例

Python
1
2
import bitsandbytes as bnb
opt = bnb.optim.Adam8bit(model.parameters(), lr=1e-4, min_8bit_size=16384)
数值精度与显存策略
数值精度

混合精度(AMP)是现代训练的默认手段:用 torch.amp.autocast 让部分算子在低精度执行,同时在需要数值范围的地方保留 FP32。对于 FP16 训练,通常需要 torch.amp.GradScaler 做梯度缩放;对于 BF16 训练,很多场景只用 autocast 即可。

Python
1
2
3
4
5
6
7
8
9
10
11
import torch
scaler = torch.amp.GradScaler("cuda")
for batch in loader:
    opt.zero_grad(set_to_none=True)
    with torch.amp.autocast("cuda", dtype=torch.float16):
        # 这里假设模型 forward 返回带 .loss 的对象;HF/TRL 训练栈经常如此约定。
        loss = model(batch).loss
    scaler.scale(loss).backward()
    scaler.step(opt)
    # update 会根据本步是否溢出动态调整下一轮的缩放因子。
    scaler.update()
TF32(矩阵乘加路径)

TF32 只影响部分 FP32 矩阵乘(matmul/conv)在 Tensor Core 上的执行路径。它属于“性能换数值精度”的系统开关,通常通过 PyTorch 的 backend 选项控制。

Python
1
2
3
4
5
6
7
import torch
 
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
 
# 新版本也可用更高层的 matmul 精度策略
torch.set_float32_matmul_precision("high")
训练稳定性策略
重算策略

Activation checkpointing(重算)用计算换显存:前向不保存中间激活,反向时按需重新执行前向片段。PyTorch 的直接入口是 torch.utils.checkpoint.checkpoint。工程上需要明确它的副作用:重算会改变“前向执行次数”,涉及 RNG 或跨设备拷贝的代码必须被审计,否则可能出现非确定性或性能退化。

Python
1
2
3
4
5
6
import torch
from torch.utils.checkpoint import checkpoint
def block(x):
    # 被 checkpoint 包住的子图不会保存完整激活,反向时会重算这一段前向。
    return layer2(layer1(x))
y = checkpoint(block, x, use_reentrant=False)
量化训练与低比特微调

低比特微调常见路线是 QLoRA:权重以 4-bit 存储,计算与梯度在更高精度上进行,训练的增量参数由 LoRA 承担。工程落地通常依赖 bitsandbytes 的 4-bit Linear 与上层微调框架(PEFT/Transformers);底层约束集中在:GPU/驱动/CUDA 兼容、量化算子是否可用、以及与 FSDP/ZeRO 的组合边界。

工程组合边界
  • 4-bit 权重 + 分片:参数分片(FSDP/ZeRO-3)与 4-bit 权重量化都在改写参数表示,组合时需要确认框架对“量化权重的 all-gather/重分片”路径是否支持。
  • offload:把模型状态 offload 到 CPU/NVMe 会引入额外带宽瓶颈,必须配合 micro-batch、重算与通信重叠,否则吞吐会显著下降。
  • 验证方式:先在单机单卡确认量化算子可用,再扩展到单机多卡,最后扩展到多机,逐层隔离问题源。
模型交换、导出与部署格式

“能训练”与“能部署”之间隔着一条很长的工程链路:模型从训练框架导出成某种中间表示,再由运行时加载并在特定硬件上执行。真正决定链路质量的是导出语义是否稳定、后端算子是否覆盖、部署环境是否可复现。

这一节按部署侧常见的三条路线组织:

  • 通用交换:PyTorch → ONNX → ONNX Runtime
  • NVIDIA GPU 推理:PyTorch/ONNX → TensorRT 或 TensorRT-LLM
  • Intel CPU/iGPU 推理:ONNX → OpenVINO(可选 IR 转换)→ OpenVINO Model Server

最后补齐两类“本地模型分发格式”:safetensors(安全权重)、GGUF/GGML(llama.cpp 系列推理栈),以及 Hugging Face Hub 的下载与离线部署。

ONNX

ONNX(Open Neural Network Exchange)是交换层:把训练框架里的前向计算图与权重,以跨框架可读的形式表达出来。部署侧更关心两个版本概念:IR version(中间表示版本)与 opset version(算子集合版本)。opset 变更意味着算子语义或签名变化,直接影响“模型能否被某个 runtime 正确执行”。

安装
Shell
1
pip install onnx
导出:PyTorch → ONNX(推荐 torch.export / TorchDynamo 路线)

PyTorch 的 ONNX 导出路线在持续演进。工程上更推荐走 torch.export/TorchDynamo 为基础的导出路径(例如 torch.onnx.export(..., dynamo=True)),以获得更稳定的图捕获与更好的算子覆盖。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch
 
class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)
    def forward(self, x):
        # ONNX 导出只关心前向图;这里故意保持 forward 为纯张量运算。
        return torch.relu(self.linear(x))
model = MyModel().eval()
x = torch.randn(1, 8)
torch.onnx.export(
    model,
    (x,),
    "my_model.onnx",
    input_names=["x"],
    output_names=["y"],
    dynamo=True,  # 推荐的新导出逻辑
)

模型权重很大时,需要考虑外部权重(external data):单个 ONNX 文件存在体积限制,导出时可以把权重拆到额外文件中,并让 ONNX 图引用它们。

ONNX 基础验证

部署前至少做两步静态检查:载入与 checker。最常见的失败来自 opset 不匹配、导出遗漏常量折叠、或动态控制流无法被捕获。

Python
1
2
3
4
import onnx
 
m = onnx.load("my_model.onnx")
onnx.checker.check_model(m)
常见坑
  • opset 不匹配:导出使用了较新的 opset,但 runtime/后端只支持较旧 opset,表现为“能 load 但执行时报不支持的算子/属性”。
  • 动态形状:ONNX 本身可以表达动态维度,但后端是否支持、以及是否需要 shape inference/优化,是另一回事。实践里建议先固定 batch/seq 长度跑通,再逐步放开。
  • 大模型外部权重:超过单文件限制时,ONNX 可能以 external data 形式拆成多文件。部署与转换时必须保证目录结构完整,并确保 runtime 能在正确的 base_dir 下加载外部权重文件。
ONNX Runtime

ONNX Runtime(ORT)是执行层:加载 ONNX 图并在不同硬件后端上执行。它通过 Execution Providers(EP)把算子下沉到不同加速库(CPU、CUDA、TensorRT 等)。部署编程上,核心对象是 InferenceSession。

安装(CPU / GPU)

实践里同一个 Python 环境通常只安装一个 ORT 包(CPU 或 GPU)。GPU 包覆盖大部分 CPU 功能,但仍需要关注 CUDA/cuDNN 与驱动版本匹配。

Shell
1
2
3
4
5
# CPU
pip install onnxruntime
 
# GPU(默认 CUDA 12.x)
pip install onnxruntime-gpu
最小可用推理代码
Python
1
2
3
4
5
6
7
8
9
import numpy as np
import onnxruntime as ort
 
# 不传 providers 时通常先走 CPU provider;语义先跑通,再切 GPU/TensorRT provider。
sess = ort.InferenceSession("my_model.onnx")
 
# 输入 key 必须和导出时的 input_names 对齐;ORT 不会替你猜字段名。
x = np.random.randn(1, 8).astype(np.float32)
y = sess.run(None, {"x": x})
Execution Provider 选择与回退

服务端通常需要显式选择 EP,并提供“失败回退到 CPU”的策略。最常见的做法是按优先级传入 providers 列表。

Python
1
2
3
4
5
6
7
8
import onnxruntime as ort
 
# provider 列表的顺序就是回退顺序:CUDA 失败时,再落到 CPU。
providers = [
    "CUDAExecutionProvider",
    "CPUExecutionProvider",
]
sess = ort.InferenceSession("my_model.onnx", providers=providers)
常用API

命令/API/函数
ort.InferenceSession

说明
加载模型、选择 EP、执行推理

示例

Python
1
2
3
4
sess = ort.InferenceSession(
    "my_model.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)

命令/API/函数
sess.get_inputs()

说明
枚举输入名、dtype、shape(用于接入层校验)

示例

Python
1
2
3
inputs = sess.get_inputs()
for i in inputs:
    print(i.name, i.type, i.shape)

命令/API/函数
sess.run

说明
执行推理

示例

Python
1
outputs = sess.run(None, {"x": x_np})
常见坑
  • CUDA 版本对齐: onnxruntime-gpu 与本机 CUDA/cuDNN/驱动组合必须匹配,否则会出现 provider 初始化失败或动态库缺失。
  • 输入 dtype:推理时 numpy 的 dtype 必须与模型输入一致(例如 fp32),否则会报类型不匹配。
  • EP 支持度:同一个模型在不同 EP 上算子覆盖不同,部署前需要用真实模型做冒烟测试,遇到不支持算子时要么回退到 CPU,要么调整导出图/替换算子。
TensorRT

TensorRT 是 NVIDIA GPU 上的推理优化与运行时:它把 ONNX 模型解析到网络图,再由 Builder 构建优化后的 engine(plan)。构建通常离线完成,线上只加载 engine 并执行。

安装

TensorRT 提供多种安装方式(容器、Debian、pip wheel)。工程上常见的两条路径是:

  • 开发环境:用 pip 安装 tensorrt(或精简运行时变体),配合本机 CUDA 与驱动。
  • 生产环境:以容器为主,把驱动与 CUDA 依赖固化在镜像与运行时约束里。
Shell
1
2
3
4
5
6
7
8
9
# pip 安装(示例:按实际平台与版本选择合适包名)
# 先把 pip 升到足够新,避免解析不到当前平台对应的 TensorRT wheel。
python -m pip install -U pip
 
# 标准 Python 运行时包,适合本机开发或最小接入。
pip install tensorrt
 
# 只需要加载现成 engine 的瘦镜像可考虑 lean runtime。
pip install tensorrt_lean
ONNX → TensorRT engine(两种入口)

TensorRT 有两个常见入口:命令行工具(快速验证)与 Python/C++ API(集成到服务)。典型流程是用 TensorRT ONNX parser 导入 ONNX,再由 Builder 生成 engine。

1) 命令行(trtexec)
Shell
1
2
3
4
5
6
7
8
9
10
11
12
# 典型用法:从 ONNX 构建 engine
# 构建或测试 TensorRT engine。
trtexec --onnx=my_model.onnx --saveEngine=my_model.plan
 
# 若需要 FP16
# 构建或测试 TensorRT engine。
trtexec --onnx=my_model.onnx --saveEngine=my_model.plan --fp16
 
# 动态 shape(示例:按你的真实输入名与维度填写)
# 构建或测试 TensorRT engine。
trtexec --onnx=my_model.onnx --saveEngine=my_model.plan \
  --minShapes=x:1x8 --optShapes=x:16x8 --maxShapes=x:64x8
2) Python API(OnnxParser + Builder)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import tensorrt as trt
 
# WARNING 足够看到 parser/build 失败,但不会被 INFO 级日志刷屏。
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
# EXPLICIT_BATCH 让 batch 维进入网络定义;现代 ONNX/TensorRT 路线基本都依赖它。
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
 
with open("my_model.onnx", "rb") as f:
    ok = parser.parse(f.read())
    if not ok:
        # parser.num_errors / parser.get_error(i) 可用于定位不支持算子
        raise RuntimeError("ONNX parse failed")
常见坑
  • engine 可移植性:TensorRT engine 受平台、GPU 架构、TensorRT 版本与构建参数影响。需要跨版本或跨架构复用时,必须显式开启对应的兼容模式;否则默认情况下不具备可移植性。
  • 动态形状与 profile:动态 shape 通常需要显式设置 optimization profile,否则构建或运行会失败。
  • 算子覆盖:ONNX parser 报错时,优先从“导出是否落到标准算子”排查;其次考虑 TRT 插件或改写模型。
TensorRT-LLM

TensorRT-LLM 是面向 LLM 的 TensorRT 构建与运行时栈:提供 Python API 和服务端组件,把 LLM 的 KV cache、注意力优化、量化与服务化接口封装成一条更完整的部署链。快速落地通常走官方容器路线,然后用 trtllm-serve 启动 OpenAI-compatible server。

安装与启动(推荐容器路线)

TensorRT-LLM 更偏“完整部署栈”,依赖 CUDA、TensorRT、编译链与模型支持矩阵。工程上优先选择官方预构建容器,在容器内完成转换、build 与 serve。

在线部署:trtllm-serve(OpenAI-compatible)
Shell
1
2
# 容器内启动服务(示例)
trtllm-serve "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

启动后可访问标准 OpenAI 端点(例如 /v1/chat/completions)。

Shell
1
2
3
4
5
6
7
8
9
# 用最小请求验证 TensorRT-LLM 服务是否已经能完成一次完整生成。
curl -X POST http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "messages": [{"role": "user", "content": "Where is New York?"}],
    "max_tokens": 32,
    "temperature": 0
  }'

真实部署里, trtllm-serve 通常还要补一层配置:并行规模、batch/token 上限、KV cache 预算、served model name、日志与 tracing 端点。它更接近“一个完整的服务进程”,而非单纯把 Python 模型对象暴露成 HTTP。

离线推理:LLM API

TensorRT-LLM 同时提供 Python 侧的 LLM API:给定 Hugging Face repo 或 checkpoint,API 负责加载、优化与推理编排。对工程团队而言,这条路径适合把“推理服务”嵌入到现有 Python 服务栈中,但需要更细致的版本与环境锁定。

Python
1
2
3
4
5
6
7
8
9
10
11
from tensorrt_llm import LLM, SamplingParams
 
llm = LLM(
    # 直接把模型身份交给 TensorRT-LLM,由其接管构建与执行细节
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
)
params = SamplingParams(
    temperature=0.0,  # 这里显式关采样,得到稳定的离线回归结果
    max_tokens=32,    # 离线批处理同样需要 token 上限,否则容易把单次任务拖得过长
)
outputs = llm.generate(["Where is New York?"], sampling_params=params)
高频部署参数:model_dir / output_dir / tp_size / pp_size

从工程脚本分布看,TensorRT-LLM 的高频入口明显集中在“模型输入目录、engine 输出目录、以及并行切分规模”。这说明它的主工作流核心是“先准备构建输入,再显式产出 engine 制品,然后按并行规模上线”。

命令/API/函数
--model_dir

说明
指定待转换/待构建的模型目录。它通常指向 Hugging Face checkpoint、本地导出目录或经过预处理的权重目录,是整个构建链路的输入端。

示例

Shell
1
2
3
trtllm-build \
  --model_dir /models/llama3-hf \
  --output_dir /engines/llama3_tp2

命令/API/函数
--output_dir

说明
指定 TensorRT-LLM 产出的 engine 制品目录。它核心是后续 serve/load 真正要消费的部署产物。

示例

Shell
1
2
3
trtllm-build \
  --model_dir /models/llama3-hf \
  --output_dir /engines/llama3_tp2

命令/API/函数
--tp_size

说明
张量并行规模。它决定一个模型副本会被切到多少张 GPU 上,也直接影响 engine 构建结果是否能在目标机器上落地。

示例

Shell
1
2
3
4
trtllm-build \
  --model_dir /models/llama3-hf \
  --output_dir /engines/llama3_tp4 \
  --tp_size 4

命令/API/函数
--pp_size

说明
流水并行规模。模型超过单机或需要按层切分时,会和 tp_size 一起决定 engine 的并行拓扑。

示例

Shell
1
2
3
4
5
trtllm-build \
  --model_dir /models/llama3-hf \
  --output_dir /engines/llama3_tp4_pp2 \
  --tp_size 4 \
  --pp_size 2

这类参数和 vLLM 的 tensor_parallel_size/ pipeline_parallel_size 在工程意图上是同一类东西:描述“一个模型副本如何切到多卡”。区别在于,TensorRT-LLM 往往把这件事更早固化到 engine 构建产物里,因此部署时必须让“构建时并行拓扑”和“上线时硬件拓扑”保持一致。

OpenVINO

OpenVINO 面向 Intel CPU/iGPU/加速器的推理栈。它既能直接加载 ONNX,也能把 ONNX 转换成 OpenVINO IR(xml+bin)。如果关注加载延迟或希望提前做图优化,通常会先把 ONNX 转成 IR。

安装
Shell
1
pip install openvino
ONNX → OpenVINO IR(Python API)
Python
1
2
3
import openvino as ov
 
ov_model = ov.convert_model("your_model_file.onnx")
OpenVINO Model Server:LLM QuickStart

OpenVINO Model Server(OVMS)提供服务化部署。它支持以 Docker 启动,指定 --source_model 从 Hugging Face 拉取已转换的 OpenVINO 模型,并暴露 OpenAI 风格 API。

Shell
1
2
3
4
5
6
7
8
9
10
11
# OVMS 会把下载/转换后的模型文件写入这个目录,便于持久化和排错。
mkdir -p models
# 这里直接拉起模型服务进程;source_model 指向可被 OVMS 识别的模型仓库。
docker run -d --rm -p 8000:8000 \
  -v $(pwd)/models:/models:rw \
  openvino/model_server:2026.1-gpu \
  --source_model OpenVINO/Qwen3-8B-int4-ov \
  --model_repository_path models \
  --task text_generation \
  --rest_port 8000 \
  --target_device GPU

OVMS 也支持用 OpenAI Python client 直接调用(base_url 指向 OVMS)。

常见坑
  • 外部权重 ONNX:若 ONNX 以 external data 拆成多文件,必须保持主 onnx 与外部权重文件的目录关系可发现。
  • 硬件选择:OVMS 需要正确设置 target_device,并保证容器或主机具备对应设备节点与驱动。
本地推理权重格式
safetensors

safetensors 的工程定位是“替代 pickle 的安全权重格式”:可做 zero-copy 加载,并显式避免任意代码执行风险。它主要服务于权重分发与加载,不承担跨硬件优化执行这一层职责。

Shell
1
pip install safetensors

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
from safetensors import safe_open
from safetensors.torch import save_file
import torch
 
# safetensors 约定传入“张量名 -> 张量值”的普通 dict。
tensors = {"w": torch.zeros((2, 2))}
save_file(tensors, "model.safetensors")
 
# safe_open 允许按 key 惰性读取,不必一次性 materialize 整个文件。
loaded = {}
with safe_open("model.safetensors", framework="pt", device=0) as f:
    for k in f.keys():
        loaded[k] = f.get_tensor(k)

按 key 读取(或切片读取)是 safetensors 的常见用法:多 GPU 分片加载、按需加载 embedding 表等场景,会用它降低峰值内存。

GGML

GGML 是 llama.cpp 早期使用的权重格式/生态名词之一。当前工程实践中,GGUF 更常作为“可分发的最终产物”。GGML 更适合理解为历史兼容路径:遇到旧模型时需要能识别与迁移。

GGUF

GGUF 是 llama.cpp 生态的主流分发格式。典型链路是:从 Hugging Face 模型(safetensors/pytorch)转换到 GGUF,再选择量化方案,最后交给 llama.cpp 或 Ollama 运行。llama.cpp 仓库提供了 convert_hf_to_gguf.py 等脚本作为转换入口。

Shell
1
2
3
4
5
6
7
8
9
# 转 GGUF 最常见的起点仍然是 llama.cpp 仓库自带的转换脚本。
git clone https://github.com/ggml-org/llama.cpp
cd llama.cpp
 
# requirements.txt 主要提供 tokenizer 与权重转换依赖。
pip install -r requirements.txt
 
# 先看 --help 再决定模型目录、量化格式和 tokenizer 文件来源。
python convert_hf_to_gguf.py --help

GGUF 的常见坑是 tokenizer 与特殊 token:转换时必须保证 tokenizer 文件齐全(例如 sentencepiece model 或 BPE merges/vocab),否则会出现“能推理但输出严重异常”的隐蔽故障。

模型获取与分发工具
huggingface_hub

huggingface_hub 是下载与缓存的编程入口:它把“下载”变成版本化缓存,并返回本地路径。缓存路径指向的文件不应该被修改,否则会污染缓存并产生难以排查的线上问题。

Shell
1
pip install huggingface_hub
常用API

命令/API/函数
hf_hub_download

说明
下载单个文件(带缓存与 revision)

示例

Python
1
2
3
4
5
6
7
8
from huggingface_hub import hf_hub_download
 
# revision 最好固定到 tag 或 commit,避免同一个 repo 名字在不同时间解析到不同内容。
path = hf_hub_download(
    repo_id="lysandre/arxiv-nlp",
    filename="config.json",
    revision="main",
)

命令/API/函数
snapshot_download

说明
下载整个仓库(支持 allow/ignore patterns)

示例

Python
1
2
3
4
5
6
7
8
from huggingface_hub import snapshot_download
 
# 只把部署真正需要的文件拉下来,避免把不必要的大 bin 文件一起缓存。
local_path = snapshot_download(
    repo_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    revision="main",
    allow_patterns=["*.safetensors", "*.json"],
)

命令/API/函数
HfApi(endpoint=...)

说明
对接私有 Hub/镜像,显式指定 endpoint

示例

Python
1
2
3
4
from huggingface_hub import HfApi
 
api = HfApi(endpoint="https://huggingface.co")
models = api.list_models(search="bert")

命令/API/函数
hf_hub_url

说明
构造下载 URL(用于调试/审计)

示例

Python
1
2
from huggingface_hub import hf_hub_url
url = hf_hub_url("lysandre/arxiv-nlp", "config.json")

命令/API/函数
hf CLI

说明
登录、下载、缓存管理

示例

Shell
1
2
3
4
5
6
7
8
# 安装与查看
hf --help
 
# 按 revision 固定下载
# include/exclude 让下载集合显式可控,便于做离线部署清单。
hf download TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
  --revision main \
  --include "*.safetensors" --exclude "*.bin"

huggingface_hub 的环境变量用于把缓存与认证变成可运维配置: HF_HOME、 HF_HUB_CACHE、 HF_TOKEN 等。它们通常在 import 时读取,生产环境必须保证“进程启动前配置好”。

离线模式与版本固定

离线部署首先要把所有文件按 revision 固定并落到可控目录,然后再在离线环境启用 offline 开关。仅仅让 from_pretrained 在离线环境里可调用,并不足以保证整条部署链稳定。

Shell
1
2
3
4
5
6
7
8
9
10
11
12
# 1) 先下载到可控缓存目录(示例)
# 先把 Hub 缓存放到大磁盘或共享卷,避免默认写满系统盘。
export HF_HOME=/data/hf
export HF_HUB_CACHE=/data/hf/hub
# 第一次必须在线下载完整模型,离线模式只负责“后续不再触网”。
hf download TinyLlama/TinyLlama-1.1B-Chat-v1.0 --revision main
 
# 2) 再把运行环境切到离线(示例:以你实际依赖版本为准)
# 这三层开关分别作用于 hub、transformers 和 datasets。
export HF_HUB_OFFLINE=1
export TRANSFORMERS_OFFLINE=1
export HF_DATASETS_OFFLINE=1
  • 对关键模型,revision 固定到 tag/commit hash,并把下载产物做 manifest(文件列表 + 哈希)。
  • 离线开关:Hugging Face 生态里存在多层 offline 变量(例如 hub、transformers、datasets 各自的 offline 模式)。团队需要以“实际依赖版本”为准做一次演练,确认每个库在离线环境的行为一致。
hf-mirror 与中国大陆下载

中国大陆环境常见问题是“访问 Hugging Face 资源不稳定”。工程上有两条可控路径:

  • 内部镜像:自建 mirror 服务,把下载变成内网依赖。
  • 第三方镜像:通过环境变量把下载 base url 指向镜像站点。

对镜像/私有 Hub 的对接,优先使用显式 endpoint:Python 侧使用 HfApi(endpoint=...);命令行侧则使用统一的网络出口与缓存目录策略。部分旧版本 API 说明也记录了 HF_ENDPOINT 这类环境变量用法,但它是否生效取决于你实际安装的 huggingface_hub 版本。

Shell
1
2
# 把 Hugging Face 的下载端点切到镜像(示例)
export HF_ENDPOINT=https://hf-mirror.com

需要评估:hf-mirror 属于第三方服务,可靠性与合规性需要业务自行评估。能自建镜像或把模型产物纳入制品库(artifact repository)的团队,应优先选择可控方案。

典型部署链路速查
目标 链路 适用场景 常见坑
跨框架推理 PyTorch → ONNX → ONNX Runtime 多语言客户端、跨平台部署 opset/EP 不匹配,动态形状处理
NVIDIA GPU 高性能推理 PyTorch/ONNX → TensorRT engine 低延迟/高吞吐服务 engine 不可跨 GPU/版本复用,profile/插件
NVIDIA LLM 服务化 HF checkpoint → TensorRT-LLM → trtllm-serve OpenAI-compatible LLM server 模型支持矩阵、量化/精度要求、容器化依赖
Intel LLM 服务化 ONNX → OpenVINO(可选 IR)→ OVMS CPU/iGPU 部署、边缘与本地服务 外部权重、设备映射与驱动
本地量化推理 HF → GGUF → llama.cpp/Ollama 本地开发、边缘设备 tokenizer 文件缺失、量化质量与兼容性
推理引擎与服务系统

推理引擎与服务系统把“模型权重 + 推理优化”交付为“可稳定承载并发请求的 API”。服务端需要长期管理 prefill/decode 调度、KV cache 生命周期、批处理策略、流式输出、并发隔离、模型加载与热更新 等工程问题。

LLM 推理引擎

常见推理栈可以按落地点分为三类:面向 GPU 的在线推理引擎(vLLM、SGLang、LMDeploy、TGI)、面向本地/边缘的运行时(llama.cpp、Ollama),以及更贴近硬件厂商优化栈的服务框架(如 TensorRT-LLM / Triton 一类)。工程选型通常先定两件事:服务端是否提供 OpenAI-compatible API,以及是否需要多 GPU/多节点的原生支持。

推理栈 自己的特色 优先选择场景
vLLM PagedAttention、continuous batching、OpenAI-compatible server、离线批推理与服务端参数面成熟。 通用高吞吐在线服务,尤其是多租户 chat/completions、批量生成、RAG 后端推理。
SGLang 推理控制流、结构化输出、tool parser、reasoning parser、RL rollout 控制接口更突出。 需要多步推理编排、Agent/工具调用、在线 RLHF rollout 或复杂结构化生成。
LMDeploy TurboMind/PyTorch 双后端、量化与中文开源模型部署生态完整,离线 pipeline 与在线 API server 都覆盖。 国产模型、VLM、多量化路线验证,以及希望在同一工具链里完成量化、部署和批推理。
TGI Hugging Face 官方服务端路线,和 Hub、tokenizer、模型配置、Prometheus 指标体系结合紧。 团队已经重度使用 Hugging Face Hub,希望推理服务与 HF 生态保持一致。
TensorRT-LLM / Triton 面向 NVIDIA GPU 的深度优化路线,强调 engine 构建、算子融合、低延迟和生产部署治理。 延迟/吞吐指标极端敏感,团队有能力维护模型编译、engine 版本和 Triton 部署链路。
llama.cpp / Ollama 本地与边缘运行友好,CPU/GPU 混合、GGUF、桌面/开发机部署成本低。 个人开发、边缘设备、内网原型、小模型本地服务,或需要极低运维成本的演示系统。
vLLM

vLLM 是面向高吞吐服务端推理的引擎。它通过 PagedAttention 管理 KV cache,并采用 continuous batching 处理变长请求,从而在并发场景下维持 GPU 利用率。生产系统里最常用的入口是 OpenAI-compatible server(HTTP)。

安装与启动

安装时优先按官方指引为目标平台准备匹配的 PyTorch(CUDA/ROCm/CPU),再安装 vLLM。常见启动方式如下:

vLLM: run OpenAI-compatible server
Shell
1
2
3
4
5
6
7
# pip install vllm
# 启动 vLLM 服务。
vllm serve meta-llama/Meta-Llama-3-8B-Instruct \
  --host 0.0.0.0 \
  --port 8000 \
  --dtype auto \
  --api-key token-abc123

容器化部署通常使用官方镜像,并把 Hugging Face 缓存目录挂载到容器内,避免重复下载权重:

vLLM: run with Docker image (pattern)
Shell
1
2
3
4
5
6
7
# 把模型缓存目录挂到宿主机,避免容器重启后重新下载整套权重。
docker run --gpus all \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  -p 8000:8000 \
  --ipc=host \
  vllm/vllm-openai:latest \
  --model meta-llama/Meta-Llama-3-8B-Instruct
OpenAI-compatible API(调用与流式输出)

OpenAI-compatible server 的目标是复用现有的 OpenAI SDK。调用方式只需要把 base_url 指向自托管服务即可:

Call vLLM with OpenAI Python SDK
Python
1
2
3
4
5
6
7
8
9
10
11
from openai import OpenAI
 
# 只需要改 base_url,就能复用现成的 OpenAI SDK 与上层业务封装。
client = OpenAI(base_url="http://localhost:8000/v1", api_key="token-abc123")
resp = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    messages=[{"role": "user", "content": "Hello!"}],
    temperature=0.2,
    max_tokens=128,
)
print(resp.choices[0].message)

流式输出通常使用 SSE;请求里把 stream 设为 true 即可:

OpenAI-compatible streaming (generic curl)
Shell
1
2
3
4
5
6
7
8
9
# 这条请求验证的是 SSE 流式协议,不代表服务已经调到最佳吞吐。
curl http://localhost:8000/v1/chat/completions \
  -H 'Content-Type: application/json' \
  -H 'Authorization: Bearer token-abc123' \
  -d '{
    "model": "meta-llama/Meta-Llama-3-8B-Instruct",
    "messages": [{"role":"user","content":"Hello!"}],
    "stream": true
  }'
关键服务参数(常用 Engine Args)

vLLM 的参数很多,但多数服务化场景只需要围绕“显存预算、并发上限、上下文长度、缓存开关”做控制。

参数 含义 典型影响
--max-model-len 最大上下文长度 上限越大,KV cache 预算越高;并发上限通常随之下降
--gpu-memory-utilization 显存预算比例 控制 KV cache 可用空间,影响 OOM 风险与吞吐
--max-num-batched-tokens 每步调度的 token 预算上限 增大可提升吞吐,但可能增加尾延迟
--max-num-seqs 并发序列数上限 控制并发度与资源争用,影响延迟与稳定性
--kv-cache-dtype KV cache 存储精度 更激进的 KV 精度可降低显存/带宽,但需要评估质量影响
--enable-prefix-caching 启用前缀缓存(Prompt Caching) 前缀重复多的业务可显著减少 prefill 成本
--generation-config generation_config 的优先级策略 影响默认采样参数来源,避免线上采样行为“悄悄变了”
多 GPU 与多节点部署

服务化推理的常见拓扑是“单实例多 GPU(TP/PP)”与“多副本横向扩展”。横向扩展更依赖网关/LB 做副本路由;多数引擎把 KV cache 作为进程内状态,因此跨副本共享缓存并不常见。

Serving topology (concept)
1
2
3
4
5
6
7
8
single instance:
  [client] -> [inference server] -> [1 GPU]
 
single instance, multi-GPU:
  [client] -> [inference server] -> [TP/PP over N GPUs]
 
replicas:
  [client] -> [LB / gateway] -> [replica-1] / [replica-2] / ...
SGLang

SGLang 以“推理编排能力 + OpenAI-compatible API”为核心卖点。工程上它常用于需要多步推理控制流、工具调用编排、以及对思维链/推理输出有结构化处理的在线系统。

安装与启动

SGLang 的启动入口可以是 sglang serve 或 python -m sglang.launch_server,两者本质上都是启动一个 OpenAI-compatible server。

SGLang: install & run (single node)
Shell
1
2
3
4
5
6
7
8
9
# 先装 SGLang 本体;更复杂的推理优化与多卡参数再逐步叠加。
pip install -U sglang
 
# launch_server 会直接启动 OpenAI-compatible HTTP 服务。
python -m sglang.launch_server \
  --model-path qwen/qwen2.5-0.5b-instruct \
  --host 0.0.0.0 \
  --port 30000 \
  --log-level warning
OpenAI-compatible API 调用
Call SGLang with OpenAI Python SDK
Python
1
2
3
4
5
6
7
8
9
10
11
from openai import OpenAI
 
# base_url 改成 SGLang 地址后,应用侧调用代码可以与 OpenAI/vLLM 基本保持一致。
client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="qwen/qwen2.5-0.5b-instruct",
    messages=[{"role": "user", "content": "List 3 countries and their capitals."}],
    temperature=0,
    max_tokens=64,
)
print(resp.choices[0].message)
多节点与并行参数(概念)

SGLang 支持多 GPU 与多节点部署。多节点部署通常需要显式指定节点数量、节点 rank、以及通信初始化地址;并行规模与模型大小共同决定权重切分与显存预算。

RL 控制面:暂停生成、释放显存与热更新权重

SGLang 的一个实用能力是把 rollout 训练需要的控制接口直接暴露成服务端 API。在线 RLHF 或 agentic training 中,训练进程常常需要临时暂停生成、更新服务权重,再继续 rollout,而非每轮都整机重启服务。

SGLang RL control-plane endpoints (concept)
1
2
3
4
5
6
7
POST /pause_generation
POST /continue_generation
POST /release_memory_occupation
POST /resume_memory_occupation
POST /update_weights_from_disk
POST /update_weights_from_tensor
POST /update_weights_from_distributed

这类接口的工程意义是把“推理服务”和“训练进程”解耦成两个组件:训练侧负责产出新权重或张量切片,服务侧负责在不中断整个进程生命周期的前提下完成切换。

Prefill / Decode 解耦与服务参数面

SGLang 已经把 prefill/decode 解耦、请求时延统计、grammar parser、tool-call parser 这些生产特性收进了同一套参数面。长上下文和高并发系统里,这些参数往往比“选哪一个底座模型”更决定最终吞吐。

SGLang: PD disaggregation (pattern)
Shell
1
2
3
4
5
6
7
8
9
10
11
12
13
python -m sglang.launch_server \
  --model-path qwen/qwen2.5-7b-instruct \
  --disaggregation-mode prefill \
  --disaggregation-transfer-backend mooncake \
  --disaggregation-bootstrap-port 25000
 
python -m sglang.launch_server \
  --model-path qwen/qwen2.5-7b-instruct \
  --disaggregation-mode decode \
  --disaggregation-bootstrap-port 25000 \
  --enable-metrics \
  --reasoning-parser deepseek-r1 \
  --tool-call-parser hermes
LMDeploy

LMDeploy 是近两年在中文开源生态里非常值得单独掌握的一条推理主线。它既能做离线 pipeline() 批推理,也能直接起 OpenAI-compatible API server;同时把量化、KV cache、prefix caching、结构化输出、多模型分发这些部署细节都做进了同一套工具链。它和 vLLM 的关系更像“并列路线”而非简单替代:vLLM 更强调通用高吞吐服务;LMDeploy 则把 TurboMind、PyTorch 双后端、AWQ/GPTQ、KV cache quant、VLM 与多机分发整合得更紧。

这套栈的关键知识点有三个。第一,离线与在线入口统一,很多参数在 pipeline 与 serve api_server 之间可以一一映射。第二,后端并不只有一个:默认会优先选择 TurboMind,但也可以显式切换到 PyTorch backend。第三, cache_max_entry_count、 session_len、 --tp 这几个参数几乎直接决定 GPU 显存水位与上下文能力,是部署时最先要调的旋钮。

命令/API/函数

lmdeploy check_env

说明

部署前的环境探针。它的价值核心是先确认当前机器到底具不具备 TurboMind/PyTorch backend 所需的依赖与设备能力。多机服务里,这一步应该写进部署前检查,而非靠第一次起服务时报错再回头排查。

示例

Shell
1
lmdeploy check_env

命令/API/函数

pipeline(..., backend_config=...)

说明

离线推理入口。若不显式指定后端,LMDeploy 会按能力自动选引擎,默认优先 TurboMind。对批量生成、离线评测和服务前 smoke test,这个入口通常比先起 HTTP 服务更省事。官方文档特别强调 cache_max_entry_count 会直接控制加载权重后可用于 KV cache 的空闲显存比例,很多“OOM 看起来像模型太大”的问题,本质上都是这里设得太激进。

示例

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig
 
pipe = pipeline(
    'internlm/internlm2_5-7b-chat',
    backend_config=TurbomindEngineConfig(
        # 把上下文上限显式定住,避免默认值和业务预期不一致。
        session_len=8192,
        max_batch_size=32,             # 离线批推理需要先定一个安全 batch 上限。
        # 多条 prompt 前缀重复时,可以直接省掉大量 prefill。
        enable_prefix_caching=True,
        cache_max_entry_count=0.6,     # 预留更多显存给权重与中间缓冲,先换取稳定不 OOM。
    ),
)
 
resp = pipe(
    ['Hi, please introduce yourself.'],
    gen_config=GenerationConfig(
        # 离线评测里要限制最长输出,避免单条坏样本拖垮整批任务。
        max_new_tokens=256,
        top_p=0.8,
        temperature=0.6,
    ),
)

命令/API/函数

lmdeploy serve api_server

说明

在线服务入口。它直接暴露 OpenAI-compatible API,因此可以被 OpenAI SDK、LangChain、OpenCompass 等上层系统直接接入。参数面和离线后端配置几乎一致,最常调的是 --tp、 --session-len、 --cache-max-entry-count 与端口。

示例

Shell
1
2
3
4
5
6
lmdeploy serve api_server internlm/internlm2_5-7b-chat \
  --server-port 23333 \
  --backend turbomind \
  --tp 2 \
  --session-len 8192 \
  --cache-max-entry-count 0.6

命令/API/函数

TurbomindEngineConfig / PytorchEngineConfig

说明

这两个配置对象决定底层执行后端。TurboMind 是偏性能导向的 C++/CUDA 引擎;PyTorch backend 更接近 Python 生态,扩展门槛更低。工程上常见做法是:先用 PyTorch backend 验证新模型与新模板,再切 TurboMind 做正式服务,或者在某些结构化输出/兼容性场景保留 PyTorch backend。

示例

Python
1
2
3
4
5
6
7
8
9
10
from lmdeploy import pipeline, PytorchEngineConfig
 
pipe = pipeline(
    'internlm/internlm2-chat-1_8b',
    backend_config=PytorchEngineConfig(
        session_len=4096,             # 开发验证阶段先缩上下文,减少调试时的显存占用。
        cache_max_entry_count=0.5,    # 给 kernel 编译和其他进程留出缓冲。
        enable_prefix_caching=True,
    ),
)

命令/API/函数

structured output

说明

LMDeploy 现在在 TurboMind 与 PyTorch backend 上都支持 schema-constrained generation,可以直接约束成 JSON schema、grammar 或 regex。对工具调用、结构化抽取、RAG 结果写回数据库,这类能力比“生成后再修 JSON”稳定得多。

示例

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from lmdeploy import pipeline, GenerationConfig, PytorchEngineConfig
 
schema = {
    'type': 'object',
    'properties': {
        'name': {'type': 'string'},
        'skills': {'type': 'array', 'items': {'type': 'string'}},
    },
    'required': ['name', 'skills'],
}
 
pipe = pipeline('internlm/internlm2-chat-1_8b', backend_config=PytorchEngineConfig())
resp = pipe(
    ['Generate a short profile.'],
    gen_config=GenerationConfig(
        response_format=dict(type='json_schema', json_schema=dict(name='profile', schema=schema)),
    ),
)

命令/API/函数

lmdeploy serve proxy

说明

proxy 用来把多个 api_server 汇聚成一个统一入口,并根据 routing strategy 把请求分发到不同节点。对“多机多卡统一服务地址”或“多个模型共享外部网关”的部署,这一层非常关键。LMDeploy 文档把 Hybrid 与 DistServe 两种 serving strategy 分开,后者会显式区分 prefill 与 decode 节点。

示例

Shell
1
2
3
4
5
6
7
8
9
10
lmdeploy serve proxy \
  --server-name 0.0.0.0 \
  --server-port 8000 \
  --routing-strategy min_expected_latency \
  --serving-strategy Hybrid
 
lmdeploy serve api_server internlm/internlm2_5-7b-chat \
  --proxy-url http://0.0.0.0:8000 \
  --server-port 23333 \
  --backend turbomind

命令/API/函数

AWQ / SmoothQuant / KV cache quant

说明

LMDeploy 的量化核心是把量化产物直接接回 pipeline 与服务接口。工程上要把三件事连起来看:量化生成目录、推理后端、服务启动参数。AWQ 主要走 TurboMind 的 W4A16 路线;SmoothQuant 常走 PyTorch backend;KV cache quant 则直接作用于服务阶段的显存与带宽。

示例

Shell
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# AWQ:产出 4bit work_dir,后续由 TurboMind 读取。
lmdeploy lite auto_awq internlm/internlm2_5-7b-chat \
  --work-dir ./internlm2_5-7b-chat-4bit
 
lmdeploy serve api_server ./internlm2_5-7b-chat-4bit \
  --backend turbomind \
  --model-format awq
 
# SmoothQuant:更常配 PyTorch backend。
lmdeploy lite smooth_quant internlm/internlm2_5-7b-chat \
  --work-dir ./internlm2_5-7b-chat-int8 \
  --quant-dtype int8
 
lmdeploy serve api_server ./internlm2_5-7b-chat-int8 \
  --backend pytorch
 
# KV cache 量化:直接在服务侧压缩 KV。
lmdeploy serve api_server internlm/internlm2_5-7b-chat \
  --backend pytorch \
  --quant-policy 8
TGI(Text Generation Inference)

TGI 是 Hugging Face 的推理服务栈。它在“用 Docker 把 Transformers 模型服务化”上体验成熟,仍然适合存量系统维护与兼容性部署;但该项目在 2026-03-21 被归档为只读,新增特性与生态协同通常不如更活跃的推理引擎。

Router / Launcher / Model Server 拓扑

TGI 的工程结构更像三层:launcher 负责起模型分片进程,router 负责 HTTP 接入、请求排队与 token budget 控制,model server 负责真正执行推理。理解这层拓扑有助于读日志与排障,因为“请求进不来”“排队太久”“模型侧 OOM”通常分别落在不同组件里。

Docker Quickstart

最常见启动方式是官方 Docker 镜像,容器内默认在 80 端口提供服务,常见映射是主机 8080 → 容器 80:

TGI: Docker quickstart (pattern)
Shell
1
2
3
4
5
6
7
8
9
model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data
 
# 把模型缓存目录挂出来,避免容器每次重启都重新下载权重。
docker run --gpus all --shm-size 1g \
  -p 8080:80 \
  -v $volume:/data \
  ghcr.io/huggingface/text-generation-inference:3.3.5 \
  --model-id $model
关键参数(text-generation-launcher)

TGI 的核心入口是 text-generation-launcher。生产里最常调整的参数集中在模型来源、多 GPU 分片与量化。

参数 含义 典型用途
--model-id 模型 ID 或本地目录 指定权重来源(Hub 或本地)
--sharded 启用多 GPU 分片 模型需要多卡容纳时
--num-shard 分片数量 控制使用多少张 GPU
--quantize 量化模式 降低显存占用与带宽压力
--max-concurrent-requests 同时受理的请求上限 直接决定入口层背压策略;值过大时容易把队列和 KV 预算一起推爆。
--max-input-tokens / --max-total-tokens 输入长度与总 token 上限 控制 prompt 和 completion 的最坏情况,属于 TGI 容量治理的核心参数。
--max-batch-prefill-tokens 一次 prefill 的 token 预算 影响长 prompt 请求能否被及时吸纳,以及 prefill 尾延迟。
--prometheus-port / --otlp-endpoint 指标与 tracing 导出端口 生产服务通常需要把性能指标和 trace 单独接到监控系统,而非只看容器 stdout。
--waiting-served-ratio / --max-batch-total-tokens / --max-waiting-tokens 队列调度与批处理预算 控制“先继续服务已在跑的请求”还是“尽快吸纳等待队列”,属于 TGI 调度器的核心吞吐/尾延迟旋钮。
--speculate / --lora-adapters 投机解码与 LoRA 服务化 前者用于降 decode 延迟,后者用于在同一服务实例上挂多份适配器。
队列预算调优

TGI 的 Router 主要关心“当前等待队列和正在服务请求各自占掉了多少 token 预算”。因此, --waiting-served-ratio、 --max-batch-total-tokens 与 --max-waiting-tokens 常常需要一起调,而非只盯住并发请求数。

TGI: scheduler-budget tuning (pattern)
Shell
1
2
3
4
5
6
text-generation-launcher \
  --model-id $model \
  --max-total-tokens 8192 \
  --max-batch-total-tokens 65536 \
  --max-waiting-tokens 2048 \
  --waiting-served-ratio 1.2
llama.cpp(本地与边缘推理)

llama.cpp 面向本地/边缘推理,围绕 GGUF 权重与多后端(CPU/CUDA/Metal 等)提供统一 runtime,并包含 OpenAI-compatible HTTP server( llama-server)。它常用于离线环境、边缘设备、或需要把推理能力分发到开发机的场景。

编译与启动
llama.cpp: build (CUDA pattern)
Shell
1
2
3
4
git clone https://github.com/ggml-org/llama.cpp
cd llama.cpp
cmake -B build -DGGML_CUDA=ON
cmake --build build -j

llama-server: start (OpenAI-compatible)
Shell
1
2
3
4
./build/bin/llama-server \
  -m /path/to/model.gguf \
  --host 0.0.0.0 \
  --port 8081
常用运行参数
参数 含义 典型影响
-m, --model GGUF 模型路径 决定权重来源与量化格式
-c, --ctx-size 上下文长度 影响 KV cache 与吞吐
-t, --threads 生成阶段 CPU 线程数 CPU 推理吞吐与尾延迟
-tb, --threads-batch 批处理/预填充阶段线程数 prefill 性能
-ngl, --n-gpu-layers offload 到 GPU 的层数 GPU/CPU 负载比例与显存占用
-n, --n-predict 单次请求最多生成多少 token 直接控制最坏情况下的 decode 时长;服务端通常要和业务超时一起设计。
--parallel 并行处理请求的槽位数 影响本地服务同时处理多少条会话;开太大时容易把 CPU 与 KV cache 一起打满。
--cont-batching 持续批处理开关 让请求不用等整批凑齐再进入推理,服务化场景下吞吐与尾延迟更稳。
--metrics 暴露监控指标 便于接 Prometheus 之类的采集系统做容量规划与告警。
--embedding 启用 embedding 接口 让同一套本地 runtime 既能做生成,也能直接提供向量化服务。
--grammar-file 基于 GBNF 的约束解码 需要 JSON/结构化输出时很有用,但会额外影响吞吐与解码路径。
--tensor-split 多 GPU 权重切分比例 桌面多卡部署时比“平均切分”更灵活,适合显存大小不完全一致的设备。
Ollama(本地模型分发 + 运行时)

Ollama 把本地模型拉取、版本管理与本地 API 封装成工具链。它提供原生 API( /api/generate、 /api/chat、 /api/embed),并提供 OpenAI-compatible API 作为兼容层。

安装与启动
Ollama: install on Linux (pattern)
Shell
1
curl -fsSL https://ollama.com/install.sh | sh
原生 API(/api/*)
Ollama: generate
Shell
1
2
3
4
curl http://localhost:11434/api/generate -d '{
  "model": "gemma3",
  "prompt": "Why is the sky blue?"
}'

Ollama: chat
Shell
1
2
3
4
curl http://localhost:11434/api/chat -d '{
  "model": "gemma3",
  "messages": [{"role":"user","content":"why is the sky blue?"}]
}'
OpenAI compatibility(概念)

兼容层的目标是让应用侧复用 OpenAI SDK 与中间件。工程上需要明确两类差异:端点语义的“兼容程度”(是否实现同等字段/行为),以及模型侧默认值(chat template、默认采样参数)是否与业务一致。

国产大模型部署(以 ChatGLM 为例)

国产大模型的部署风险集中在“推理栈适配”与“默认行为一致性”。常见问题包括:chat template 与 tokenizer 不一致导致对话格式错乱,上下文长度预算误判,以及 generation_config 的默认值覆盖导致采样行为偏离预期。

ChatGLM-6B:仓库自带的最小 API 服务

ChatGLM-6B 仓库提供了一个最小 FastAPI 服务端作为 API 部署入口,用于本地验证模型与服务链路:

ChatGLM-6B: minimal API server (from official repo)
Shell
1
2
3
# in THUDM/ChatGLM-6B repo
pip install fastapi uvicorn
python api.py
在线服务与接口兼容层

接口兼容层是推理系统工程化的关键:当服务端尽可能实现 OpenAI 的请求/响应形状,应用侧可以复用 SDK、网关与观测链路,只需把 base_url 指向自托管服务即可迁移。兼容层并不自动保证“行为一致”,上线前需要在同一套 prompts 与 sampling 参数下对齐输出分布与稳定性。

Triton Inference Server

Triton 是通用推理服务系统,核心抽象是模型仓库(Model Repository):服务端通过 --model-repository 指定一个或多个仓库路径,并按固定目录布局加载模型版本。它常用于集中托管 ONNX/TensorRT/自定义后端模型,并提供 HTTP/gRPC 与监控端点。

Triton: start server (pattern)
Shell
1
2
3
4
docker run --gpus all --rm -p8000:8000 -p8001:8001 -p8002:8002 \
  -v /path/to/model-repo:/models \
  nvcr.io/nvidia/tritonserver:<tag> \
  tritonserver --model-repository=/models

Triton model repository layout (concept)
1
2
3
4
5
6
7
model-repo/
  my_model/
    config.pbtxt
    1/
      model.onnx
    2/
      model.onnx

LLM 场景下,Triton 常作为“统一 Serving 平台”,在其上集中部署 embedding、reranker、ASR/TTS 等子模型,或加载 TensorRT-LLM backend 承载生成模型。

调度、缓存与性能机制

推理系统的吞吐、延迟与成本与服务端调度强相关。LLM 在线推理通常由两个阶段组成:prefill(处理输入 prompt 并建立 KV)与 decode(逐 token 生成)。多数性能机制都在优化这两段的计算与内存路径。

请求准入与 Token Budget

线上推理系统真正的硬约束通常是 token budget:当前显存最多容纳多少上下文 token、多少 decode 中的序列、以及每个 batch 的 prefill token 上限。vLLM、TGI、SGLang、TensorRT-LLM 虽然参数名不同,但控制的都是同一件事:别让队列里同时进入的请求把 KV cache 和调度器压垮。

概念 常见参数形态 工程意义
单请求输入上限 max-input-tokens / context-length 限制超长 prompt 直接吃光显存;通常和业务侧请求校验一起生效。
单请求总 token 上限 max-total-tokens / max-model-len 同时约束 prompt + completion,决定 KV cache 最坏情况大小。
批级 prefill 上限 max-num-batched-tokens / max-batch-prefill-tokens 控制一次 prefill 能吸进多少 token,直接影响尾延迟与抖动。
并发序列上限 max-num-seqs / max-concurrent-requests 控制 decode 阶段有多少条活跃序列同时占用 KV cache。

容量规划时,最稳的做法核心是留出一部分显存余量给突发长请求、临时 batch 波动和后台管理线程。

Batching 与 Continuous Batching

batching 把多个请求合并成一次 forward 调用,提升 GPU 利用率。静态 batching 会被最长请求拖住;服务化推理通常采用 continuous batching,在 token 级别持续吸纳新请求并淘汰已完成请求,从而减少尾延迟并提升吞吐。

KV cache

KV cache 缓存历史 token 的 Key/Value,使 decode 每一步只需计算当前 token 的 Query 并与历史 KV 做注意力。代价是显存占用随上下文长度与并发序列数近似线性增长;服务端因此需要在“最大上下文长度”和“最大并发序列数”之间做显存预算分配。

Prompt Caching(前缀缓存)

前缀缓存复用“已 prefill 的前缀 KV”。当新请求与历史请求共享前缀时,服务端可以跳过重复的 prefill 计算。该机制对“系统 prompt 固定、RAG 模板固定、长前缀重复”的业务收益显著。

Speculative decoding

speculative decoding 用“草稿模型提出多个候选 token + 目标模型验证并接收其中一部分”的方式减少目标模型的 decode 步数,从而降低解码延迟。服务端通常需要同时加载目标模型与草稿模型,并为草稿模型设置独立的资源预算。

vLLM: speculative decoding (CLI pattern)
Shell
1
2
3
4
5
6
7
# 启动 vLLM 服务。
vllm serve <target-model> \
  --speculative-config '{
    "method": "draft_model",
    "model": "<draft-model>",
    "num_speculative_tokens": 5
  }'

llama-server: speculative decoding (pattern)
Shell
1
2
3
4
./build/bin/llama-server \
  -m /path/to/target.gguf \
  --model-draft /path/to/draft.gguf \
  --host 0.0.0.0 --port 8081
检索、向量与 RAG 支撑组件

RAG 工程的核心约束来自“把非结构化知识变成可检索的结构化索引”。这一层组件主要解决:如何把文档分块、生成 embedding、写入索引或向量数据库、在查询时做 ANN 召回与元数据过滤、用 reranker 提升答案相关性、以及在部署层面控制延迟、吞吐与成本。

RAG 的最小工程闭环

一条可维护的检索链路通常分成两条 pipeline:离线入库(ingestion)与在线查询(retrieval)。离线阶段负责把文档标准化、分块、embedding、建索引并落盘;在线阶段负责 query embedding、ANN 召回、过滤、重排与返回候选片段。

离线入库(Ingestion)
Python
1
2
3
4
5
# 1) 规范化文本(去掉无意义空白、统一编码、可选:去重)
# 2) chunking:长文档切成 chunk(带 overlap)
# 3) embedding:把每个 chunk 编码成 float32 向量
# 4) upsert:写入向量索引/向量数据库(同时写入 metadata)
# 5) build index:IVF/HNSW 等(有些系统是插入即维护索引)
在线检索(Retrieval)
Python
1
2
3
4
5
# 1) query -> embedding
# 2) ANN search:向量相似度召回 topK candidates
# 3) filter:按 tenant / language / source / time / ACL 等元数据过滤
# 4) rerank:Cross-Encoder / LLM rerank(可选但通常能显著提升相关性)
# 5) 返回 chunks(以及 doc_id / offsets / urls 等可追溯信息)
RAG 编排框架

RAG 编排框架负责把文档入库、检索、重排、生成和评估组织成可测试的 pipeline。它和向量数据库的关系类似“应用层工作流”和“存储/检索后端”:向量库保存 chunk 与向量,编排框架决定数据如何进入向量库、查询如何路由、结果如何被 reranker 和 generator 消费。

Haystack

Haystack 是 deepset 开源的 RAG / Agent pipeline 框架。它把 DocumentStore、Retriever、Ranker、PromptBuilder、Generator、Evaluator 等组件组织成有向 pipeline,适合工程团队把“搜索 + 生成”拆成可替换、可观测、可单测的阶段。与 LangChain 相比,Haystack 的搜索/RAG 流水线感更强;与 LlamaIndex 相比,它更强调组件图与端到端 pipeline 执行。

Shell
1
pip install -U haystack-ai

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from haystack import Document, Pipeline
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
from haystack.components.writers import DocumentWriter
from haystack.document_stores.in_memory import InMemoryDocumentStore
 
store = InMemoryDocumentStore()
 
indexing = Pipeline()
# writer 负责把标准化后的 Document 写入 document store。
indexing.add_component("writer", DocumentWriter(document_store=store))
indexing.run(
    {
        "writer": {
            "documents": [
                Document(content="ZeRO 把优化器状态、梯度和参数分片到不同 rank。"),
                Document(content="vLLM 通过 PagedAttention 管理 KV cache。"),
            ]
        }
    }
)
 
query = Pipeline()
# 这里先用 BM25 做词法检索;生产 RAG 可替换成 dense retriever 或 hybrid retriever。
query.add_component("retriever", InMemoryBM25Retriever(document_store=store))
result = query.run({"retriever": {"query": "ZeRO 分片", "top_k": 3}})
print(result["retriever"]["documents"])

Haystack 适合把 RAG 当成“搜索系统 + 生成系统”的工程项目来做:每个组件有明确输入输出,离线入库和在线检索可以分别测试,DocumentStore 后端也可以替换为 Elasticsearch、OpenSearch、Qdrant、Weaviate 等。若项目重点是多 agent 状态机,LangGraph 更合适;若重点是企业知识索引与 query engine,LlamaIndex 更直接;若重点是 pipeline 可测试性与搜索组件替换,Haystack 的结构更清晰。

词法检索、混合检索与结果融合

向量检索并非检索系统的唯一主线。很多生产 RAG 系统会把词法检索(例如 BM25)保留下来,再与 dense retrieval 做融合。原因很直接:关键词、实体名、版本号、报错码这类“精确词面匹配”信号,BM25 往往比 embedding 更稳。

BM25:保留关键词与实体名的强信号

BM25 属于词法检索(lexical retrieval)。它按 query 词项在候选文本里的出现频率、逆文档频率与长度归一化来打分。工程直觉是:如果 query 里有非常关键的精确词面,例如函数名、报错字符串、SKU、药名或版本号,BM25 往往是第一道不能丢的 baseline。

Hybrid Retrieval:dense 和 lexical 同时保留

混合检索的常见形态是:先分别跑向量召回与 BM25,再把两路结果做融合,然后交给 reranker。这样做虽然多了一次召回,但整体稳定性通常高于“只用 dense”或“只用 lexical”。

Python
1
2
3
4
5
# 伪代码:两路召回,再统一重排
dense_hits = dense_retriever.search(query, top_k=50)   # embedding 负责语义相近召回
bm25_hits = bm25_retriever.search(query, top_k=50)     # BM25 负责关键词/实体名强匹配
# 先做融合,再交给 reranker 缩到最终 topN
merged = fuse(dense_hits, bm25_hits)
RRF 与 MMR:一个做融合,一个做去冗余

RRF(Reciprocal Rank Fusion)适合把多路排序结果合并成一条稳定的候选列表;MMR(Maximal Marginal Relevance)适合在候选已经够相关时,进一步压制重复信息、提升上下文覆盖面。两者经常一起出现,但解决的问题不同。

方法 作用点 工程意义
RRF 多路召回结果融合 对不同打分尺度不敏感;即使 BM25 和 dense 分数不能直接比较,也能按 rank 做稳健融合。
MMR 候选结果去冗余 减少 topN 里多个 chunk 都在重复同一段内容,让最终上下文覆盖更多信息面。
Chunking、Embedding、Reranker、缓存
Chunking(分块)

分块的目标是让每个向量对应一段“语义足够集中、可作为检索单位”的文本。工程上需要同时满足:检索召回稳定、下游生成可引用、以及入库成本可控。

策略 适用场景 工程要点
固定窗口(按 token/字符计数) 通用文本、日志、论坛等结构弱的语料 用 overlap 降低“切断引用”的概率;chunk_id 需要稳定(便于增量更新与去重)。
结构感知(按标题/段落/代码块) Markdown、HTML、技术文档、论文 保留层级路径(h1/h2/h3)作为 metadata;能显著改善可解释性与定位能力。
语义分段(按句子/主题边界) 长文档、语义跳跃频繁的内容 实现复杂;通常与结构感知结合更稳。
Embeddings(向量化接入)

embedding 既可以在进程内完成(直接加载模型编码),也可以通过独立服务完成(HTTP/gRPC embedding endpoint)。工程关注点包括:向量维度、向量归一化(cosine vs inner product)、批量化、以及版本管理(embedding 模型升级带来的全量重算成本)。

关注点 建议 原因
向量 dtype 索引侧统一 float32(或按系统支持使用 fp16/binary/sparse) 多数 ANN 实现以 float32 为主;混用 dtype 容易导致精度与兼容性问题。
cosine 相似度 向量 L2 归一化后用 inner product(IP) cosine(a,b) = a·b(当 ||a||=||b||=1);索引实现更统一。
版本管理 把 embedding_model_id 写入 metadata;变更时双写或重建 避免“同一集合混入不同 embedding space”导致召回不可解释。
批量化 离线入库用 batch encode + shard 写入 embedding 计算与写入都更容易成为瓶颈,批量化是最有效的吞吐优化。
Reranker(重排)

向量召回的 topK 通常是“相关但不够精确”的集合,reranker 用更强的匹配器(Cross-Encoder/LLM)对候选做二次排序。工程上常见的做法是:召回 topK=50~200,然后重排取 topN=5~20 作为上下文。

Python
1
2
3
4
# 伪代码:Cross-Encoder rerank
# pairs = [(query, chunk_text) for chunk_text in candidates]
# scores = reranker.predict(pairs)
# candidates = sort_by(scores)[:topN]
缓存与幂等(Cache & Idempotency)

RAG 系统的成本通常被 embedding 与 ANN 搜索吞掉。缓存应该围绕“纯函数”构建:相同输入应产生相同输出。

缓存点 Key 设计 注意事项
embedding 缓存 sha256(normalized_text) + model_id 同一文本在不同模型下向量不同;必须把 model_id 纳入 key。
检索结果缓存 sha256(query) + model_id + filters 适合高频固定 query;对实时性强(例如新闻)的库需要设置短 TTL 或禁用。
chunk 幂等 upsert doc_id + chunk_id 确保增量更新不会产生重复点;chunk_id 推荐来自稳定切分策略。
索引生命周期:回填、删除、重嵌入

很多 RAG 系统上线后真正复杂的部分是后续的索引生命周期管理:老文档如何重切块、新 embedding 模型如何迁移、文档删除怎样同步到索引、以及历史 chunk 如何避免脏数据残留。把索引生命周期当成单独子系统来设计,通常比在主检索链路里临时打补丁更稳。

动作 工程目标 常见策略
回填(backfill) 把历史文档补齐到新索引或新字段 按文档分片批量重跑 ingestion;用稳定的 doc_id + chunk_id 保证幂等。
删除(delete/tombstone) 让检索结果及时反映源文档删除 为文档维护删除标记或直接删点;强一致场景要把源存储与索引更新放进同一条任务链路。
重嵌入(re-embedding) 切换到新 embedding 模型 新旧索引双写一段时间;metadata 里显式记录 embedding_model_id,避免混库。
本地索引与嵌入式检索

本地索引适合“单机/单租户/读多写少”的场景:部署简单、延迟低、没有额外网络 hop。代价是分片、扩容与高可用需要自行实现。

Chroma

Chroma 是 Python-first 的嵌入式向量数据库,常用于 RAG 原型、单机服务和本地开发环境。它的特色是上手成本低:collection、metadata、embedding function、持久化目录都可以在一个 Python 进程里完成。工程边界也很明确:当业务进入多租户、高 QPS、复杂权限、高可用和跨节点扩容阶段,应评估 Qdrant、Milvus、Weaviate、pgvector 或云托管向量库。

Shell
1
pip install -U chromadb

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import chromadb
 
# PersistentClient 把索引与元数据写入本地目录,适合开发机和小型单机服务。
client = chromadb.PersistentClient(path="./chroma_store")
collection = client.get_or_create_collection(name="chunks")
 
collection.add(
    ids=["c1", "c2"],
    documents=[
        "ZeRO 通过分片减少数据并行的重复训练状态。",
        "Haystack 可以把检索、重排和生成组织成 pipeline。",
    ],
    metadatas=[{"doc_id": "d1"}, {"doc_id": "d2"}],
)
 
# query_texts 会走 collection 配置的 embedding function;生产系统应显式固定 embedding 模型。
hits = collection.query(query_texts=["ZeRO 的作用"], n_results=2)
print(hits["documents"])

Chroma 的选择理由通常是“先把 RAG 闭环跑起来”:它适合验证 chunking、prompt、rerank、引用回链等业务逻辑。上线后若数据规模和并发继续增长,需要提前设计迁移层,避免业务代码直接依赖 Chroma 的 collection API。

FAISS
安装
Shell
1
2
# CPU
pip install -U faiss-cpu
常用API

命令/API/函数
faiss.IndexFlatIP

说明
精确检索(inner product);常配合归一化实现 cosine

示例

Python
1
2
3
import faiss, numpy as np
d = 768
index = faiss.IndexFlatIP(d)

命令/API/函数
faiss.IndexIVFFlat

说明
IVF 近似检索(聚类倒排 + 精确扫描)

示例

Python
1
2
quantizer = faiss.IndexFlatIP(d)
index = faiss.IndexIVFFlat(quantizer, d, 4096, faiss.METRIC_INNER_PRODUCT)

命令/API/函数
faiss.normalize_L2

说明
向量 L2 归一化

示例

Python
1
faiss.normalize_L2(x)  # x: float32 [n, d]

命令/API/函数
write_index / read_index

说明
索引落盘与加载

示例

Python
1
2
faiss.write_index(index, "docs.faiss")
index = faiss.read_index("docs.faiss")
建库与查询(最小示例)

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import faiss
import numpy as np
 
# FAISS 对 ndarray 的典型要求是 float32、二维、按行存储。
xb = np.random.randn(10000, 768).astype("float32")
xq = np.random.randn(10, 768).astype("float32")
 
# cosine 检索最常见的做法是“先归一化,再用内积索引”。
faiss.normalize_L2(xb)
faiss.normalize_L2(xq)
index = faiss.IndexFlatIP(768)
# add 会把底库向量写进索引;真实系统还要单独维护向量行号到文档主键的映射。
index.add(xb)
 
scores, ids = index.search(xq, k=5)  # ids: [m, k]
工程权衡

FAISS 的“强项”是速度与可控性:你可以精确控制索引结构、nprobe/efSearch 等参数,并把索引作为本地文件交付。它的“短板”是系统能力:过滤、权限、在线扩缩容、分片与持久化策略都需要额外工程。

pgvector(PostgreSQL 向量扩展)
安装与启用
Shell
1
2
3
# 方式很多(源码/包管理器/Docker/托管服务)
# 核心步骤是:在目标数据库里启用扩展
psql -d your_db -c "CREATE EXTENSION IF NOT EXISTS vector;"
建表、建索引与查询(SQL)
SQL (pgvector)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
CREATE TABLE doc_chunks (
  id bigserial PRIMARY KEY,
  doc_id text NOT NULL,
  chunk_id int NOT NULL,
  chunk text NOT NULL,
  embedding vector(768) NOT NULL
);
 
-- HNSW(适合低延迟近似检索,内存开销更高)
CREATE INDEX ON doc_chunks USING hnsw (embedding vector_cosine_ops);
 
-- 查询 topK(cosine distance)
SELECT doc_id, chunk_id, chunk
FROM doc_chunks
ORDER BY embedding <=> '[0.01, -0.02, ...]'
LIMIT 10;
工程权衡

pgvector 的优势是“把向量检索融进现有 OLTP/OLAP 体系”:事务、JOIN、权限、备份与监控都沿用 Postgres。代价是向量检索性能上限通常低于专用向量数据库,尤其是高维大规模与高 QPS 场景;此外还需要对索引参数、VACUUM/ANALYZE、以及冷热数据分层有明确策略。

专用向量数据库

专用向量数据库把“高维 ANN 检索 + 元数据过滤 + 持久化 + 分布式扩展”做成标准能力,适合多租户、数据规模持续增长、需要高可用与可观测性的场景。

向量库抽象层:不要把业务代码写死在某一个后端上

工程上一个高频设计是先定义统一的向量检索接口,再按环境切换后端:开发环境用 Chroma / FAISS,生产环境用 Qdrant / Milvus / OpenSearch / Pinecone。这样做的价值在于:检索参数、metadata 过滤和索引生命周期逻辑都能集中在一处维护,而非散落在业务代码里。

Python
1
2
3
4
5
6
class VectorStore:
    def upsert(self, records): ...
    def search(self, vector, top_k, filters=None): ...
    def delete(self, ids): ...
 
store = make_vector_store(cfg)  # 通过配置决定走本地索引、Qdrant、Milvus 还是云托管后端

如果系统同时需要 sparse、dense 与 rerank,多数团队还会在这一层之上再包一个 retriever abstraction,把“召回策略组合”也做成可配置项。

Qdrant
部署(本地 Docker)
Shell
1
2
3
4
docker pull qdrant/qdrant
docker run -p 6333:6333 -p 6334:6334 \
  -v "$(pwd)/qdrant_storage:/qdrant/storage:z" \
  qdrant/qdrant
安装(Python Client)
Shell
1
2
# 可选 fastembed:在 client 侧直接做 text -> embedding(适合快速验证)
pip install -U "qdrant-client[fastembed]"
建库、写入与查询(Python)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue
 
client = QdrantClient(url="http://localhost:6333")
if not client.collection_exists("chunks"):
  # 向量维度与距离度量一旦写入 collection,就成为这组数据的长期契约。
  client.create_collection(
    collection_name="chunks",
    vectors_config=VectorParams(size=768, distance=Distance.COSINE),
  )
 
# payload 保存 doc_id、语言等结构化字段;后续过滤和回表都依赖它。
points = [
  PointStruct(id=1, vector=[0.0] * 768, payload={"doc_id": "d1", "lang": "zh"}),
]
client.upsert(collection_name="chunks", points=points)
 
# 先按 payload 过滤,再在候选集上做向量检索,是 Qdrant 最常见的在线查询模式。
f = Filter(must=[FieldCondition(key="lang", match=MatchValue(value="zh"))])
hits = client.search(collection_name="chunks", query_vector=[0.0] * 768, limit=5, query_filter=f)
工程权衡

Qdrant 的工程特点是:数据模型清晰(point + payload)、过滤与索引能力成熟、部署路径明确(本地 Docker / Helm / Cloud)。在安全层面需要显式启用鉴权与网络隔离,默认容器配置通常是无认证的开发模式。

Milvus
部署(Docker Compose)
Shell
1
2
curl -sfL https://raw.githubusercontent.com/milvus-io/milvus/master/scripts/standalone_embed.sh -o standalone_embed.sh
bash standalone_embed.sh start
安装(Python SDK)
Shell
1
pip install -U pymilvus
建库、建索引与查询(Python)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from pymilvus import MilvusClient, DataType
 
client = MilvusClient(uri="http://localhost:19530", token="root:Milvus")
# schema 决定字段类型与主键规则;Milvus 不建议把所有结构化字段都当作无约束 JSON。
schema = MilvusClient.create_schema(auto_id=False, enable_dynamic_field=True)
schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=768)
schema.add_field(field_name="doc_id", datatype=DataType.VARCHAR, max_length=256)
index_params = MilvusClient.prepare_index_params()
index_params.add_index(
  field_name="vector",
  index_type="HNSW",
  metric_type="COSINE",
  params={"M": 16, "efConstruction": 200},
)
 
client.create_collection(collection_name="chunks", schema=schema, index_params=index_params)
 
# data 是“字段名 -> 值”的记录列表;SDK 会按 schema 做类型校验。
client.insert(
  collection_name="chunks",
  data=[{"id": 1, "vector": [0.0] * 768, "doc_id": "d1", "lang": "zh"}],
)
hits = client.search(
  collection_name="chunks",
  data=[[0.0] * 768],
  limit=5,
  filter='doc_id == "d1"',
  output_fields=["doc_id"],
)
工程权衡

Milvus 的定位是面向大规模 ANN 检索的工程化系统:索引类型丰富、可分布式扩展、并提供集合(collection)层的 schema 与字段能力。它对部署与运维的要求也更高,适合“数据规模持续增长且需要系统化治理”的团队。

Weaviate

Weaviate 是对象 schema + 向量检索结合得比较紧的向量数据库。它把 collection、property、vectorizer、hybrid search、多租户和模块化能力放在同一套数据模型里,适合希望同时管理“对象字段、向量、关键词检索和生成式查询”的 RAG 系统。和 Milvus 相比,Weaviate 的数据对象模型更靠前;和 Qdrant 相比,它的 schema、vectorizer/module 与 hybrid search 入口更突出。

安装(Python Client)
Shell
1
pip install -U weaviate-client
建 collection、写入与向量查询(Python)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import weaviate
from weaviate.classes.config import Configure, DataType, Property
 
client = weaviate.connect_to_local()
try:
    # collection schema 是长期契约:字段名、类型、向量化方式会影响所有后续写入和查询。
    if not client.collections.exists("DocChunk"):
        client.collections.create(
            name="DocChunk",
            properties=[
                Property(name="text", data_type=DataType.TEXT),
                Property(name="doc_id", data_type=DataType.TEXT),
            ],
            # none 表示应用侧自己提供向量;也可以接入 Weaviate 的内置 vectorizer 模块。
            vectorizer_config=Configure.Vectorizer.none(),
        )
 
    chunks = client.collections.get("DocChunk")
    chunks.data.insert(
        properties={"text": "Weaviate 支持对象 schema 与向量检索。", "doc_id": "d1"},
        vector=[0.0] * 768,
    )
 
    # near_vector 适合应用侧已经完成 query embedding 的场景。
    result = chunks.query.near_vector(near_vector=[0.0] * 768, limit=3)
    for obj in result.objects:
        print(obj.properties["doc_id"], obj.properties["text"])
finally:
    client.close()

Weaviate 的选择理由通常是“希望 RAG 数据对象化”:除了向量本身,还要稳定维护字段 schema、混合检索、多租户、模块化 vectorizer 或 cloud/self-host 两种部署路径。若团队只需要最小向量召回服务,Qdrant 更轻;若规模和索引类型复杂度更高,Milvus 更常见;若所有业务数据已经在 Postgres,pgvector 的系统集成成本更低。

TCVectorDB(腾讯云向量数据库)

TCVectorDB 属于托管型向量数据库:实例创建、扩缩容与高可用由云平台提供,SDK 把 HTTP API 封装成 Python 类与对象模型。工程上更关注鉴权、网络连通(VPC/公网)、以及数据模型与索引类型的选择。

安装
Shell
1
pip3 install -U tcvectordb
常用API

命令/API/函数
Client

说明
SDK 主入口,负责鉴权、请求发送与资源管理;建议把 credential 放在环境变量或密钥管理系统。

示例

Python
1
client = tcvectordb.RPCVectorDBClient(url="https://<endpoint>", key="<api-key>", username="root")

命令/API/函数
Database / Collection

说明
逻辑组织层,通常按业务线或数据域拆分;collection 内部的向量维度与 metric 必须一致。

示例

Python
1
2
client.create_database_if_not_exists(database_name="rag_db")
client.create_collection_if_not_exists(database_name="rag_db", collection_name="chunks", indexes=[...])

命令/API/函数
IndexType / MetricType

说明
索引类型与相似度度量枚举,直接决定 recall、latency 与 cost 的平衡点。

示例

Python
1
2
3
4
5
6
VectorIndex(
  name="vector",
  index_type=IndexType.HNSW,
  metric_type=MetricType.COSINE,
  dimension=768,
)
写入与检索(工作流骨架)

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import tcvectordb
from tcvectordb.model.collection import Embedding
from tcvectordb.model.enum import FieldType, IndexType, MetricType, EmbeddingModel, ReadConsistency
from tcvectordb.model.index import VectorIndex, FilterIndex, HNSWParams
 
tcvectordb.debug.DebugEnable = False
 
client = tcvectordb.RPCVectorDBClient(
  url="https://<your-vdb-endpoint>",    # 控制台/文档提供
  key="<your-api-key>",
  username="root",
  read_consistency=ReadConsistency.EVENTUAL_CONSISTENCY,
  timeout=30,
)
db = "rag_db"
col = "chunks"
client.create_database_if_not_exists(database_name=db)
 
# 这里把 embedding 工作下沉到服务端,业务只需要上传原始文本字段 chunk。
ebd = Embedding(vector_field="vector", field="chunk", model=EmbeddingModel.BGE_BASE_ZH)
client.create_collection_if_not_exists(
  database_name=db,
  collection_name=col,
  shard=1,
  replicas=0,
  indexes=[
    # 主键索引负责 upsert/覆盖写,不承担向量召回。
    FilterIndex(name="id", field_type=FieldType.String, index_type=IndexType.PRIMARY_KEY),
    VectorIndex(
      name="vector",
      field_type=FieldType.Vector,
      index_type=IndexType.HNSW,
      dimension=768,
      metric_type=MetricType.COSINE,
      params=HNSWParams(m=16, efconstruction=200),
    ),
    # doc_id 单独做过滤索引,后续才能按文档、租户或业务域做精确筛选。
    FilterIndex(name="doc_id", field_type=FieldType.String, index_type=IndexType.FILTER),
  ],
  embedding=ebd,
)
 
# upsert 时不必手工写 vector;SDK 会按 embedding 配置把文本转成向量。
client.upsert(
  database_name=db,
  collection_name=col,
  documents=[{"id": "c1", "doc_id": "d1", "chunk": "向量数据库用于相似度检索…"}],
)
 
# search_by_text 会先把查询字符串 embedding 化,再走 ANN 检索与 payload 过滤。
hits = client.search_by_text(
  database_name=db,
  collection_name=col,
  embedding_items=["向量数据库"],
  output_fields=["doc_id"],
  limit=5,
)
工程权衡

托管型服务的收益来自运维外包与 SLA,代价来自云厂商绑定与成本结构(存储、QPS、流量、索引构建)。当业务对数据主权、可迁移性或自定义算子有强需求时,需要评估本地自建或可移植方案(例如 pgvector 或自建 Qdrant/Milvus)。

选型与权衡(工程视角)
方案 适合 不适合
FAISS(本地索引) 单机部署、离线构建索引、极低延迟、对索引结构控制强 需要复杂过滤/权限/多租户/高可用与在线扩缩容
Chroma(嵌入式向量库) RAG 原型、单机服务、本地开发、希望快速验证 chunking / prompt / rerank 流程 多租户、高可用、复杂权限、跨节点扩容和强运维治理
pgvector(Postgres 内嵌) 已有 Postgres 体系、需要事务与 JOIN、数据规模中等 超大规模 ANN + 高 QPS 的专用检索场景
Qdrant / Milvus(自建向量库) 需要过滤、持久化、分布式扩展与稳定运维 团队缺少运维能力、或希望把运维成本完全外包
Weaviate(对象化向量库) 希望用 schema 管理文档对象、字段、向量、hybrid search 与多租户能力 只需要极简向量召回,或团队不想维护额外 schema/module 体系
TCVectorDB(托管向量库) 希望快速上线并获得云端 SLA、对云集成友好 强可迁移性需求、或需要深度定制与自托管
Agent、工具与应用编排组件

Agent 编排层解决的是“把模型调用变成可执行系统”的工程问题:任务被拆成哪些步骤、每一步调用哪个模型、工具如何注册与授权、状态如何持久化、失败如何重试、以及如何把整条调用链暴露给可观测性系统。它位于推理引擎与训练框架之上,承担流程控制、工具集成与状态管理。

从部署视角看,Agent 系统至少包含三类进程:

  • 推理后端:提供模型推理 API(OpenAI、vLLM、SGLang、TGI、TensorRT-LLM 等)。
  • 编排运行时:实现状态机/图/循环,负责发起模型调用、路由与错误处理。
  • 工具服务:把外部能力(数据库、搜索、浏览器、业务 API、文件系统)封装为工具端点,供模型以 tool calling 方式触发。
Agent 编排框架

编排框架的差异主要体现在两点:控制流的表达能力(链式、图式、事件驱动、角色流水线),以及工具调用的边界管理(schema、权限、审批、重试、隔离)。

框架 自己的特色 优先选择场景
LangChain 组件生态广、provider 适配多,适合把模型、prompt、retriever、tool 快速拼成 pipeline。 业务需要快速接入多种模型和工具,控制流相对简单,工程重点在集成速度。
LangGraph 图执行、显式状态、checkpoint、human-in-the-loop、可恢复运行是核心优势。 Agent 有循环、分支、审批、长期状态或失败恢复需求,需要把流程当状态机维护。
LlamaIndex 以数据连接、索引、retriever、query engine、tool retrieval 为中心,RAG 侧抽象更强。 主要问题是“让模型访问企业知识库/数据库/文档系统”,并需要持续治理索引与检索工具。
Haystack 组件化 pipeline 清晰,DocumentStore、Retriever、Ranker、Generator、Evaluator 能按 DAG 组织。 搜索/RAG 工程团队需要可测试、可替换、可观测的检索流水线,而非单纯 prompt glue。
DSPy 把 LLM 应用写成可优化程序,用指标驱动 prompt/module 编译。 研发阶段需要系统性优化 prompt、检索、few-shot 示例和模块组合,且有明确评估集。
AutoGen / CrewAI 强调多 agent 通信、角色、任务流和 runtime;CrewAI 更偏角色式业务流程。 任务天然可拆成多个角色或多个服务协作,且团队能约束每个 agent 的权限与输出契约。
LangChain

LangChain 更适合把模型、提示词、检索器与工具快速组装成可运行的 pipeline。它的核心抽象是“可组合组件”,典型用法是先把模型初始化成统一接口,再用可组合表达把上下游粘起来。

Shell
1
2
pip install -U langchain langchain-openai
export OPENAI_API_KEY=sk-...

Python
1
2
3
4
5
from langchain.chat_models import init_chat_model
 
model = init_chat_model("openai:gpt-5.4")
result = model.invoke("Hello, world!")
print(result)

LangChain 负责“编排代码结构”,推理仍发生在后端(OpenAI 或 OpenAI-compatible server)。工程上常见的落地方式是把 LangChain 应用包成一个 HTTP 服务(FastAPI 等),并把工具执行封装为内部函数或外部工具服务。

关注点 LangChain 侧要做什么 推理后端要做什么 工具服务要做什么
模型选择 统一模型调用入口、管理提示词与输入结构 提供 OpenAI-compatible API 或云端 API 无
工具调用 定义工具 schema、把工具结果回注入上下文 产出 tool call 请求(函数名 + JSON 参数) 执行工具、返回结构化结果
部署 将 pipeline 打包为服务,接入鉴权/限流/日志 承载并发与延迟 SLA 治理权限与审计,控制外部副作用
LangGraph

LangGraph 的定位更靠近“可控的状态机/图执行引擎”:它擅长表达长运行、可恢复、可插入人类审核的工作流。与只把 prompt 拼起来相比,它把“循环、分支、检查点、恢复与人机介入”变成一等公民。

Shell
1
pip install -U langgraph langgraph-checkpoint-sqlite

在工程实践中,LangGraph 更适合成为编排运行时本体:把 agent 的状态设计为显式结构(state),把工具调用、模型调用、审批点等写成节点(node),并通过 checkpointer 把状态落盘,从而支持重启恢复与长时间运行。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from typing_extensions import TypedDict
from langgraph.graph import StateGraph
from langgraph.checkpoint.sqlite import SqliteSaver
 
class State(TypedDict):
    text: str
def step(state: State) -> dict:
    # 节点函数只接收显式 state,并返回“这一步要写回 state 的字段”。
    return {"text": state["text"] + " -> next"}
builder = StateGraph(State)
# 节点名会进入运行时日志与 checkpoint 元数据,应该取稳定、可读的名字。
builder.add_node("step", step)
builder.set_entry_point("step")
 
with SqliteSaver.from_conn_string("checkpoints.sqlite") as checkpointer:
    # compile 后才得到真正可执行的图;checkpointer 决定线程状态如何落盘。
    graph = builder.compile(checkpointer=checkpointer)
    out = graph.invoke(
        {"text": "start"},
        # thread_id 是恢复与回放的主键;生产里通常映射到任务 ID 或工单号。
        config={"configurable": {"thread_id": "demo-thread-1"}},
    )
    print(out["text"])

部署上,LangGraph 常见两种形态:

  • 单体服务:应用进程内执行图,工具执行也在同进程或同机。
  • 分布式工具:图在编排服务内执行,工具通过 HTTP 或 MCP 调用外部服务,工具结果写入状态。
LlamaIndex

LlamaIndex 以“数据代理(Data Agent)”和检索增强为中心,更适合把外部知识、索引、向量库能力组织成 agent 可调用工具。在工具规模变大时,它提供“工具检索(tool retrieval)”这类机制,避免把大量函数定义塞进单次 prompt。

Shell
1
pip install -U llama-index

LlamaIndex 的工具抽象强调把函数与查询引擎包装成可检索、可调用的 Tool。工程上通常把它放在“工具侧”或“检索侧”:编排运行时(LangGraph / Agents SDK)调用 LlamaIndex 的查询/工具,再把结果回注入模型上下文。

Python
1
2
3
4
5
6
from llama_index.core.tools import FunctionTool
def get_weather(location: str) -> str:
    """Useful for getting the weather for a given location."""
    # docstring 和签名都会进入 tool schema;描述越具体,模型越不容易误用。
    return f"{location}: 25C"
tool = FunctionTool.from_defaults(get_weather, name="get_weather")
DSPy

DSPy 的定位是“用程序化结构来编写 LLM 应用,并用优化器把程序编译成更有效的提示词或权重配置”。它更适合研发阶段系统性迭代(prompt/模块组合/评估驱动),而非单纯手写 prompt 字符串。

Shell
1
pip install -U dspy

在工具调用上,DSPy 提供了 Tool 原语与 ReAct 等模式,支持使用底层模型的原生 function calling 能力。

Python
1
2
3
4
5
6
7
8
9
10
11
import dspy
def search_web(query: str) -> str:
    # 这里用普通 Python 函数占位真实搜索服务;DSPy 会把它包装成可调用工具。
    return f"Search results for {query}"
agent = dspy.ReAct(
    signature="question -> answer",
    tools=[search_web],
    max_iters=5,
)
result = agent(question="What's new in vLLM?")
print(result.answer)
AutoGen

AutoGen 把多智能体协作与通信基础设施作为一等公民,强调“runtime 负责消息与生命周期,agent 负责逻辑”。它既可用于研究型多智能体协作,也可作为生产编排底座。其工程价值通常体现在:明确的 agent runtime、组件化的模型与工具实现、以及面向多进程/多机的扩展路径。

Shell
1
pip install -U "autogen-agentchat" "autogen-ext[openai,azure]"
CrewAI

CrewAI 更偏向“角色 + 任务流水线(Flow)”表达,适合把业务流程拆成岗位式分工并固定编排。它的工程落地通常依赖明确的输入输出契约与任务边界,否则会迅速滑向不可控的多轮对话。

Shell
1
pip install -U crewai
协议与托管平台

工具调用协议的关键不在“模型能不能调用工具”,而在“工具定义是否标准化、执行是否隔离、权限是否可审计”。OpenAI function/tool calling 与 MCP 分别覆盖了两条常见路径:前者提供“模型到函数”的结构化参数通道;后者提供“工具/资源/提示词”的标准化服务协议,并允许工具以独立服务器形态部署。

OpenAI-compatible function/tool calling

OpenAI-compatible tool calling 的核心是:用 JSON Schema 定义工具参数,并让模型返回结构化的 tool call(函数名 + JSON 参数)。推理后端只负责生成 tool call;真正的工具执行必须在应用侧完成,并把结果作为后续输入再发回模型。

工具 schema 需要满足“模型可理解、服务端可校验”两类约束。一个可落地的最小约束集合包括:参数类型明确、必填项清晰、默认值可推断、以及禁止额外字段(避免模型塞入无关参数)。

字段 含义 典型示例
name 工具名(函数名) "search_docs"
description 工具描述(用于让模型选择工具) 描述越具体,误调用越少
parameters JSON Schema 参数定义
JavaScript
1
2
3
4
5
6
7
8
9
{
  "type": "object",
  "properties": {
    "query": { "type": "string" },
    "top_k": { "type": "integer", "default": 5 }
  },
  "required": ["query"],
  "additionalProperties": false
}

工程上,工具 schema 需要同时满足两类消费者:模型(用于选择与填参)与服务端(用于校验与执行)。推荐在服务端做强校验(Pydantic/zod/JSON Schema validator),并把校验失败当成工具错误返回给模型进行自修复。

在 API 形态上,不同后端的 tools 字段会有轻微差异(嵌套 {"type":"function","function":{...}} 或扁平 {"type":"function","name":...})。编排层通常在入口做一次归一化,保证内部只处理一种表示。

OpenAI Responses API

Responses API 把“生成 + 工具 + 流式事件”统一成一个接口,支持 function calling、内置工具与 MCP 工具。部署形态上,它更像一个推理后端:你的编排服务负责调用 Responses、接收 tool call 事件、执行工具并把结果回注入下一轮调用。

Shell
1
pip install -U openai

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import json
from openai import OpenAI
 
client = OpenAI()
def get_city_uuid(city: str) -> str:
    # 真实系统里这里通常查内部数据库或业务 API,而非返回硬编码字符串。
    return f"{city} ID: 00000000-0000-0000-0000-000000000000"
tool_mapping = {"get_city_uuid": get_city_uuid}
tools = [
    {
        "type": "function",
        "name": "get_city_uuid",
        "description": "Retrieve the internal ID for a city from the internal database.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
            "additionalProperties": False,
        },
    }
]
response = client.responses.create(
    model="gpt-5.5",
    input="What's the internal ID for London?",
    tools=tools,
)
followup_items = []
for item in response.output:
    # Responses API 的 output 里可能混有文本、函数调用和其它事件,只处理 function_call。
    if item.type != "function_call":
        continue
    fn = tool_mapping[item.name]
    # arguments 是 JSON 字符串;应用侧必须自己反序列化并做参数校验。
    args = json.loads(item.arguments)
    tool_output = fn(**args)
    followup_items.append(
        # call_id 把本次工具输出绑定回模型刚才发起的那一次调用。
        {"type": "function_call_output", "call_id": item.call_id, "output": tool_output}
    )
if followup_items:
    response2 = client.responses.create(
        model="gpt-5.5",
        input=followup_items,
        previous_response_id=response.id,
    )
    print(response2.output_text)
Agents SDK

Agents SDK 的工程定位是“当你的应用拥有编排与工具执行权”时,提供标准化的 agent loop、handoff、guardrail、session 与 tracing。它与 LangGraph 的差异在于:前者更偏 SDK 级编排框架并与 OpenAI 生态强绑定,后者更偏通用图式编排运行时。

Shell
1
2
pip install openai-agents
export OPENAI_API_KEY=sk-...

Python
1
2
3
4
5
6
7
8
9
10
11
from agents import Agent, Runner
 
agent = Agent(
    name="Ops helper",
    # instructions 是稳定行为约束;比每次把系统提示拼进用户输入更可维护。
    instructions="Diagnose errors and suggest concrete fixes.",
    model="gpt-5.5",
)
# run_sync 更适合 CLI 工具和后台任务;Web 服务通常改用异步入口。
result = Runner.run_sync(agent, "Explain this stacktrace and propose a patch.")
print(result.final_output)
MCP

Model Context Protocol(MCP)是一套标准化协议,用于把工具、资源与提示词以“独立服务器”的方式暴露给 AI 应用。它使用 JSON-RPC 2.0,在 host/client/server 三方模型下进行能力协商与调用。MCP 的价值在于把工具系统做成可组合生态:同一个 MCP server 可以被不同 host 复用,同一个 host 也能接多个 MCP server。

FastMCP

FastMCP 以 Python 类型标注与 docstring 自动生成工具 schema,把“写工具函数”变成“发布 MCP 工具”。它适合把内部服务封装成可调用工具,并以 stdio/HTTP 形态部署。stdio 模式下必须避免向 stdout 写日志,否则会破坏 JSON-RPC 通信。

Shell
1
pip install "mcp[cli]" httpx

Python
1
2
3
4
5
6
7
8
9
10
11
12
from mcp.server.fastmcp import FastMCP
 
mcp = FastMCP("weather")
 
@mcp.tool
def add(a: int, b: int) -> int:
    """Add two numbers."""
    # 类型标注会直接进入 tool schema,客户端无需手写另一份 JSON Schema。
    return a + b
if __name__ == "__main__":
    # stdio transport 下 stdout 是协议通道;普通日志应改写到 stderr 或日志文件。
    mcp.run()

命令/API/函数
FastMCP(name)

说明
创建 MCP server 实例。通常位于 server 进程。

示例

Python
1
2
3
from mcp.server.fastmcp import FastMCP
 
mcp = FastMCP("weather")

命令/API/函数
@mcp.tool

说明
声明工具函数,自动生成 schema。通常位于 server 进程。

示例

Python
1
2
3
@mcp.tool
def add(a: int, b: int) -> int:
    return a + b

命令/API/函数
@mcp.resource

说明
暴露可读取资源(类文件数据)。通常位于 server 进程。

示例

Python
1
2
3
@mcp.resource("weather://{city}")
def get_weather(city: str) -> str:
    return f"{city}: 25C"

命令/API/函数
@mcp.prompt

说明
暴露可复用 prompt 模板。通常位于 server 进程。

示例

Python
1
2
3
@mcp.prompt
def summarize_topic(topic: str) -> str:
    return f"Summarize {topic} in one paragraph."

命令/API/函数
mcp.run()

说明
启动 server(stdio/HTTP transport)。通常位于 server 进程。

示例

Python
1
2
if __name__ == "__main__":
    mcp.run()

MCP server 开发调试常用 MCP Inspector:

Shell
1
npx -y @modelcontextprotocol/inspector npx @modelcontextprotocol/server-filesystem /path/to/dir
Workflow 与 Runtime 的边界

编排代码通常由两层构成:workflow 负责表达“步骤与依赖关系”,runtime 负责提供“可恢复、可审计、可扩展”的执行语义。把关键能力下沉到 runtime,可以减少“靠 prompt 记住状态”的不稳定性。

能力 更适合放在 workflow(图/链) 更适合放在 runtime(执行层)
状态 状态结构(state schema)、节点输入输出契约 checkpoint/线程/恢复、版本化与回放
失败处理 哪些步骤允许重试、哪些步骤必须人审 指数退避、幂等、断点续跑、死信队列
工具调用 工具集选择与路由(tool routing / retrieval) 权限、沙箱、审计日志、限流、并发隔离
观测 关键业务 span 的命名与结构化属性 trace/metrics/log 的采集、采样、落库与查询
可观测性与审计

Agent 系统的主要故障点通常落在工具调用链偏航:工具选错、参数不合法、返回值解析失败、以及重试/恢复逻辑异常。可观测性需要覆盖:每次模型调用的输入输出、每次工具调用的参数与返回值、以及每个步骤的耗时与失败原因。Langfuse 在 LangChain / LangGraph 生态中常用作 tracing 平台。

Langfuse(LangChain 回调集成)
Shell
1
2
3
4
pip install -U langfuse langchain langgraph langchain-openai
export LANGFUSE_PUBLIC_KEY=pk-lf-...
export LANGFUSE_SECRET_KEY=sk-lf-...
export LANGFUSE_HOST=https://cloud.langfuse.com

Python
1
2
3
4
5
6
7
8
9
10
11
from langfuse.langchain import CallbackHandler
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
langfuse_handler = CallbackHandler()
 
# Langfuse handler 会把 prompt、模型调用、异常与 token 用量统一写进 trace。
llm = ChatOpenAI(model_name="gpt-5.4")
prompt = ChatPromptTemplate.from_template("Tell me a joke about {topic}.")
chain = prompt | llm
resp = chain.invoke({"topic": "cats"}, config={"callbacks": [langfuse_handler]})
print(resp.content)
浏览器与工具调用组件

浏览器自动化通常以“工具”的形态接入 agent:编排层提供一个受控接口,例如 open_url、click、type、screenshot、extract_text;底层用 Playwright/Puppeteer 执行真实浏览器操作。生产部署时更常见的做法是把浏览器跑在隔离容器里,通过队列或 RPC 驱动,避免把不可信页面脚本与业务服务混跑在同一进程。

Playwright
Shell
1
2
pip install playwright
python -m playwright install

Python
1
2
3
4
5
6
7
8
9
from playwright.sync_api import sync_playwright
 
with sync_playwright() as p:
    browser = p.chromium.launch(headless=True)
    # 真实工具服务里更推荐先建 browser context,再为每个任务开独立 page。
    page = browser.new_page()
    page.goto("https://playwright.dev")
    title = page.title()
    browser.close()

如果把 Playwright 作为工具服务部署,推荐把“浏览器生命周期管理”显式化:为每个任务创建 context,任务结束后关闭 context,避免跨任务共享 cookie/session 导致串台。

Puppeteer
Shell
1
npm i puppeteer

Puppeteer 与 Puppeteer-core 的选择点在于“是否需要自动下载浏览器”。当你连接远程浏览器或自行管理浏览器镜像时,通常使用 puppeteer-core 并关闭下载。

从训练脚本到推理服务的最小闭环

工程闭环的目标是:用一套可复现的目录约定与脚本接口,把“数据准备 → 训练/微调 → checkpoint 管理 → 导出 → 启动推理服务 → 客户端调用 → 监控 → 回滚”串成一条可持续迭代的流水线。这个闭环必须满足两点:一是产物可追溯(可定位到数据版本、代码版本、超参版本),二是可回滚(任何上线问题都能在分钟级回退到上一版本)。

目录与产物约定(可回滚的最小形态)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
repo/
  data/
    raw/                       # 原始数据(不直接喂训练)
    processed/
      train.jsonl              # 训练集(SFT / 分类 / NER 等)
      eval.jsonl               # 验证集
      dataset_meta.json        # 数据摘要:hash、样本数、字段说明、生成脚本参数
  scripts/
    prepare_data.py            # raw -> processed,输出 dataset_meta.json
  train/
    train_sft.py               # 训练脚本(支持断点续训、保存 best/last)
  export/
    export_merge_lora.py       # 可选:LoRA 合并导出为“纯模型”目录
  outputs/
    runs/
      2026-05-09_210530_sft/   # 单次训练 run(可追溯、不可变)
        checkpoints/           # checkpoint-xxx
        best/                  # 指向 best checkpoint 或其导出物
        logs/                  # tensorboard / jsonl / wandb(任选其一)
        run_meta.json          # 代码版本、数据 hash、超参、环境信息
  models/
    registry/
      model_v0001/             # 可部署模型目录(merge 后或 base+adapter 信息)
      model_v0002/
    prod -> registry/model_v0002   # 生产指针(原子替换实现回滚)
  serving/
    vllm/
      serve.sh                 # 启动推理服务(读 models/prod)
      healthcheck.sh
  clients/
    call_openai.py
数据准备(raw → jsonl)

最小可维护做法是把训练数据固化成 jsonl,并明确字段语义。SFT 场景建议至少包含 prompt 与 response(或统一成 text,让样本是完整的对话模板)。验证集必须在训练前固定,避免“验证泄漏”造成误判。

prepare_data.py(最小骨架)
scripts/prepare_data.py
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import hashlib
import json
from pathlib import Path
 
def sha256_file(path: Path) -> str:
  # 用文件内容哈希固定住上游输入版本,便于之后追溯“这次训练到底吃了哪份原始数据”。
  h = hashlib.sha256()
  with path.open("rb") as f:
    for chunk in iter(lambda: f.read(1024 * 1024), b""):
      h.update(chunk)
  return h.hexdigest()
 
def write_jsonl(rows, out_path: Path) -> None:
  # 先确保输出目录存在,这样脚本可以在干净目录里直接运行。
  out_path.parent.mkdir(parents=True, exist_ok=True)
  with out_path.open("w", encoding="utf-8") as f:
    for r in rows:
      # 每条样本单独占一行,后续用 HF Datasets / 流式读取都更直接。
      f.write(json.dumps(r, ensure_ascii=False) + "\n")
 
def main():
  # 上游导出的原始数据;真实项目里通常来自标注平台、业务库导出或清洗脚本产物。
  raw_path = Path("data/raw/raw.json")  # 例:你的上游导出
  # 一次性读入原始 JSON,并假定其中已经包含 id / prompt / response 字段。
  raw = json.loads(raw_path.read_text(encoding="utf-8"))
 
  # 把上游格式规整成稳定 schema,避免训练脚本再处理多种脏格式。
  rows = []
  for x in raw:
    rows.append({
      "id": x["id"],
      "prompt": x["prompt"].strip(),
      "response": x["response"].strip(),
    })
 
  # 这里用固定比例切分;真实项目更常见的是按会话、文档或时间桶切分,降低泄漏风险。
  n = len(rows)
  train, eval_ = rows[: int(n * 0.98)], rows[int(n * 0.98):]
 
  # 把训练集和验证集落成两个独立 jsonl 文件,便于训练脚本直接按 split 读取。
  out_train = Path("data/processed/train.jsonl")
  out_eval = Path("data/processed/eval.jsonl")
  write_jsonl(train, out_train)
  write_jsonl(eval_, out_eval)
 
  # 同步写一份元信息,把原始文件哈希、切分规模和 schema 固化下来。
  meta = {
    "raw_path": str(raw_path),
    "raw_sha256": sha256_file(raw_path),
    "train_path": str(out_train),
    "eval_path": str(out_eval),
    "train_size": len(train),
    "eval_size": len(eval_),
    "schema": {"id": "str", "prompt": "str", "response": "str"},
  }
  Path("data/processed/dataset_meta.json").write_text(
    # 人类可读的缩进格式便于 code review 和回溯比对。
    json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8"
  )
 
if __name__ == "__main__":
  main()
训练与 checkpoint(可复现、可恢复)

训练脚本的最小要求:固定输入数据与超参、支持断点续训、把 checkpoint 与元信息写入 run 目录,并在训练结束后产出一个“可部署入口”(best checkpoint 或导出的模型目录)。下面示例以 TRL 的 SFTTrainer + PEFT(LoRA)为主线。

依赖安装(训练侧)
Shell
1
pip install -U transformers accelerate datasets trl peft safetensors
train_sft.py(最小骨架:SFT + LoRA)
train/train_sft.py
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import json
import os
import subprocess
from datetime import datetime
from pathlib import Path
 
from datasets import load_dataset
from transformers import AutoTokenizer
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer
 
def git_head() -> str:
  try:
    # 记录当前代码版本,之后排查“同样数据为什么结果变了”时很有用。
    return subprocess.check_output(["git", "rev-parse", "HEAD"]).decode().strip()
  except Exception:
    return "unknown"
 
def main():
  # 每次训练生成独立 run 目录,把日志、checkpoint 和元信息绑定在一起。
  run_id = datetime.now().strftime("%Y-%m-%d_%H%M%S_sft")
  run_dir = Path("outputs/runs") / run_id
  ckpt_dir = run_dir / "checkpoints"
  run_dir.mkdir(parents=True, exist_ok=True)
  ckpt_dir.mkdir(parents=True, exist_ok=True)
 
  # 基座模型优先从环境变量读取,便于同一脚本在不同实验里复用。
  base_model = os.environ.get("BASE_MODEL", "Qwen/Qwen3-0.6B")
  train_path = "data/processed/train.jsonl"
  eval_path = "data/processed/eval.jsonl"
 
  # tokenizer 和基座模型必须严格对应,否则很容易出现 token id 错位。
  tok = AutoTokenizer.from_pretrained(base_model, use_fast=True)
  if tok.pad_token is None:
    # decoder-only 模型经常没有显式 pad_token,这里用 eos_token 兜底。
    tok.pad_token = tok.eos_token
 
  # 直接从 jsonl 读训练/验证集,避免把数据切分逻辑散落在训练脚本里。
  ds_train = load_dataset("json", data_files=train_path, split="train")
  ds_eval = load_dataset("json", data_files=eval_path, split="train")
 
  def to_text(batch):
    # 把结构化 prompt/response 拼成单个 text 字段,交给 SFTTrainer 做因果语言建模。
    text = []
    for p, r in zip(batch["prompt"], batch["response"]):
      text.append(f"### Instruction\n{p}\n\n### Response\n{r}")
    return {"text": text}
 
  # 映射后移除原列,避免 dataset_text_field 和旧列同时存在导致歧义。
  ds_train = ds_train.map(to_text, batched=True, remove_columns=ds_train.column_names)
  ds_eval = ds_eval.map(to_text, batched=True, remove_columns=ds_eval.column_names)
 
  # 训练配置同时决定优化、评估、保存和 best checkpoint 选择策略。
  args = SFTConfig(
    output_dir=str(ckpt_dir),
    max_length=2048,
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    learning_rate=2e-4,
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=200,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    bf16=True,
    report_to="none",
  )
 
  # LoRA 只在注意力投影层挂 adapter,降低显存占用并保留基座模型可复用性。
  peft_cfg = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
  )
 
  # SFTTrainer 负责串起 tokenizer、dataset、PEFT 配置和底层 Trainer 循环。
  trainer = SFTTrainer(
    model=base_model,                 # TRL 支持传入模型 id
    args=args,
    train_dataset=ds_train,
    eval_dataset=ds_eval,
    dataset_text_field="text",
    tokenizer=tok,
    peft_config=peft_cfg,
  )
 
  # 真正进入训练循环;accelerate / deepspeed 等启动器会从这里接管分布式细节。
  trainer.train()
 
  # 训练结束后优先取 best checkpoint,找不到时再退回 checkpoint 根目录。
  best = getattr(trainer.state, "best_model_checkpoint", None) or str(ckpt_dir)
 
  # 把本次实验的关键上下文固定下来,后续导出和上线都从这份元信息追溯。
  meta = {
    "run_id": run_id,
    "base_model": base_model,
    "git_head": git_head(),
    "train_path": train_path,
    "eval_path": eval_path,
    "best_checkpoint": best,
    "training_args": args.to_dict() if hasattr(args, "to_dict") else vars(args),
  }
  (run_dir / "run_meta.json").write_text(json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8")
 
  # 额外写一个纯文本指针,让导出脚本不需要再解析 TrainerState JSON。
  (run_dir / "best").write_text(best, encoding="utf-8")
 
if __name__ == "__main__":
  main()

多卡启动(最小形态)
Shell
1
accelerate launch train/train_sft.py
断点续训与 checkpoint 策略

断点续训的最小实现是:训练启动时检测 output_dir 下最近的 checkpoint,并把它作为 resume_from_checkpoint 输入。线上训练任务应固定 save_total_limit,避免磁盘被历史 checkpoint 填满导致任务失败。

导出与上线包(从 checkpoint 到“可部署模型目录”)

PEFT/LoRA 训练常见有两种上线形态:

  • 形态 A:部署 base + LoRA adapter(推理侧按请求或按版本加载 adapter),上线快、存储小。
  • 形态 B:把 LoRA 合并到 base(导出为纯模型目录),推理侧只加载一个目录,上线更简单。
export_merge_lora.py(形态 B:合并导出)
export/export_merge_lora.py
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import argparse
from pathlib import Path
 
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
 
def main():
  # 导出脚本只关心三件事:base 模型、adapter 目录、最终输出目录。
  ap = argparse.ArgumentParser()
  ap.add_argument("--base_model", required=True)
  ap.add_argument("--adapter_dir", required=True)   # trainer 产出的 adapter checkpoint
  ap.add_argument("--out_dir", required=True)       # models/registry/model_vXXXX
  args = ap.parse_args()
 
  # 提前创建输出目录,保证后续 save_pretrained 可以直接写入。
  out_dir = Path(args.out_dir)
  out_dir.mkdir(parents=True, exist_ok=True)
 
  # tokenizer 必须和 base 模型一起导出,否则线上推理会出现词表不一致。
  tok = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
  # 合并时不需要把基座模型放到 GPU,CPU 路径更稳,也更适合作为离线导出任务。
  base = AutoModelForCausalLM.from_pretrained(
    args.base_model,
    torch_dtype="auto",
    device_map="cpu",
  )
  # 先把 LoRA adapter 挂到 base 上,再做 merge。
  model = PeftModel.from_pretrained(base, args.adapter_dir)
  model = model.merge_and_unload()
 
  # 用 safetensors 导出,减少 pickle 风险,并让后续服务端加载更标准。
  model.save_pretrained(out_dir, safe_serialization=True)
  tok.save_pretrained(out_dir)
 
if __name__ == "__main__":
  main()
版本化与可回滚指针

上线包放在 models/registry/model_vXXXX,生产指针 models/prod 是一个符号链接。切换版本通过“原子替换 symlink”实现回滚。

上线/回滚(示例)
Shell
1
2
3
4
5
# 上线:切到新版本
ln -sfn "$(pwd)/models/registry/model_v0002" "$(pwd)/models/prod"
 
# 回滚:切回旧版本
ln -sfn "$(pwd)/models/registry/model_v0001" "$(pwd)/models/prod"
启动推理服务(vLLM OpenAI 兼容服务)

推理服务侧的最小目标是提供稳定的 HTTP API,并将“模型路径/版本”从代码里剥离出来(通过 models/prod 指针决定)。下面示例使用 vLLM 的 OpenAI-Compatible Server。

安装(推理侧)
Shell
1
pip install -U vllm
serve.sh(最小启动脚本)
serving/vllm/serve.sh
Shell
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#!/usr/bin/env bash
# 遇到未定义变量、非零退出码或管道错误时立即失败,避免服务在半坏状态启动。
set -euo pipefail
 
# 统一通过 prod 软链接决定线上模型版本,切换版本时不需要改脚本本身。
MODEL_DIR="$(cd "$(dirname "$0")/../.." && pwd)/models/prod"
 
# 把 vLLM 日志级别显式固定,便于排查线上问题。
export VLLM_LOGGING_LEVEL=INFO
 
# 直接启动 OpenAI-compatible server;上线时通常由 systemd / supervisor / k8s 接管这个命令。
vllm serve "$MODEL_DIR" \
  --host 0.0.0.0 \
  --port 8000 \
  --dtype auto \
  --served-model-name prod \
  --api-key "token-abc123"
健康检查与 smoke test
serving/vllm/healthcheck.sh
Shell
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
#!/usr/bin/env bash
# 任何一个探针失败都直接让脚本退出非零,便于被外层编排系统感知。
set -euo pipefail
 
# 第一步先检查服务是否能正常列出模型,确认 HTTP 层和模型加载都没崩。
curl -sf http://127.0.0.1:8000/v1/models > /dev/null
 
# 第二步做一次最小生成,确认真正的推理路径、鉴权和请求体格式都可用。
curl -sf http://127.0.0.1:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer token-abc123" \
  -d '{
    "model": "prod",
    "messages": [{"role": "user", "content": "Say hi in one sentence."}],
    "temperature": 0
  }' > /dev/null
客户端调用(OpenAI SDK 对接)

OpenAI 兼容服务的价值是客户端可复用:同一套调用代码既能访问云端,也能访问本地/自建的 vLLM 服务。下面示例用 OpenAI Python SDK 走 Chat Completions。

Shell
1
pip install -U openai

clients/call_openai.py
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from openai import OpenAI
 
# base_url 指向自建 vLLM 服务;如果以后切到云端,只需要改这里而非重写调用逻辑。
client = OpenAI(
  base_url="http://127.0.0.1:8000/v1",
  api_key="token-abc123",
)
 
# 这段请求体和标准 OpenAI Chat Completions 兼容,方便直接接已有业务代码。
resp = client.chat.completions.create(
  model="prod",
  messages=[{"role": "user", "content": "Write a haiku about debugging."}],
  temperature=0,
  max_tokens=128,
)
 
# 最终只取第一候选的文本内容;业务代码通常会在这里接入重试、超时和日志。
print(resp.choices[0].message.content)
监控与回滚(最小可操作)

闭环的监控重点放在三类信号:服务可用性(health)、吞吐与延迟(QPS/TTFT/TPOT)、以及错误率(5xx/超时/OOM)。最小回滚流程必须是“切换模型指针 + 重启服务 + 运行 smoke test”。

指标采集点(建议最小集)
  • 服务级别: /v1/models 可用性,5xx 比例,请求超时比例。
  • 推理级别:首 token 延迟(TTFT)、每 token 延迟(TPOT)、生成长度分布。
  • 资源级别:GPU 显存占用、GPU utilization、CPU/内存、队列长度。
vLLM Prometheus metrics
Shell
1
curl -sf http://127.0.0.1:8000/metrics | head
回滚脚本(最小骨架)
serving/vllm/rollback.sh
Shell
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
#!/usr/bin/env bash
# 回滚脚本必须“失败即停”,否则很容易出现模型指针已切换但服务未重启成功的半完成状态。
set -euo pipefail
 
# 回滚目标从命令行传入,例如 model_v0003。
TARGET="${1:?usage: rollback.sh model_vXXXX}"
# 统一计算项目根目录,避免脚本从不同 cwd 调用时路径失效。
ROOT="$(cd "$(dirname "$0")/../.." && pwd)"
 
# 原子更新 prod 软链接,让服务下一次启动时加载目标版本。
ln -sfn "$ROOT/models/registry/$TARGET" "$ROOT/models/prod"
 
# 具体重启方式取决于你的进程管理器(systemd/docker/k8s)
# 这里仅给出最小形态:杀进程后重启
# 如果旧进程不存在,pkill 返回非零;这里显式容忍这种情况。
pkill -f "vllm serve" || true
# 用 nohup 启动一个后台服务,真实生产更推荐交给 systemd / k8s 管理。
nohup bash "$ROOT/serving/vllm/serve.sh" > "$ROOT/outputs/vllm_stdout.log" 2>&1 &
 
# 重启后立刻做 smoke test,确保回滚并非“切了指针但服务仍不可用”。
bash "$ROOT/serving/vllm/healthcheck.sh"
后续阅读

本篇到这里结束,覆盖 AI 训练与推理编程的横向工程栈:语言与数值底座、数据管线、基础训练框架、经典机器学习、语言模型训练框架、分布式训练、模型导出、推理服务、检索/RAG、Agent 编排,以及从训练脚本到推理服务的最小闭环。

下一篇 ai-knowledge-quick-ref-7 继续展开重点框架详解与代码精读,包括 PyTorch、Transformers、PEFT、语言模型强化学习、OpenRLHF、verl、DeepSpeed、vLLM,以及手写 Transformer、Claude Code 类 agent 和 NER 算法源码解读。

← 人工智能知识 - 智能体
人工智能知识 - 编程(二) →

Leave a Reply Cancel reply

Your email address will not be published. Required fields are marked *

You may use these HTML tags and attributes: <a href="" title=""> <abbr title=""> <acronym title=""> <b> <blockquote cite=""> <cite> <code class="" title="" data-url=""> <del datetime=""> <em> <i> <q cite=""> <strike> <strong> <pre class="" title="" data-url=""> <span class="" title="" data-url="">

Related Posts

  • 吴恩达机器学习笔记
  • 人工智能知识 - 数学基础
  • LangChain: Architecture, LCEL, Agents, LangGraph, Retrieval, and Production Patterns
  • 人工智能知识 - 智能体
  • 人工智能知识 - 简介

Recent Posts

  • 人工智能知识 - 编程(二)
  • 人工智能知识 - 编程(一)
  • 人工智能知识 - 智能体
  • 人工智能知识 - Transformers和大模型
  • 人工智能知识 - 主要应用领域
ABOUT ME

汪震 | Alex Wong

江苏淮安人,现居北京。目前供职于腾讯云,专注国际化和AI落地。

GitHub:gmemcc

Git:git.gmem.cc

Email:gmemjunk@gmem.cc@me.com

ABOUT GMEM

绿色记忆是我的个人网站,域名gmem.cc中G是Green的简写,MEM是Memory的简写,CC则是我的小天使彩彩名字的简写。

我在这里记录自己的工作与生活,同时和大家分享一些编程方面的知识。

GMEM HISTORY
v2.00:微风
v1.03:单车旅行
v1.02:夏日版
v1.01:未完成
v0.10:彩虹天堂
v0.01:阳光海岸
MIRROR INFO
Meta
  • Log in
  • Entries RSS
  • Comments RSS
  • WordPress.org
Recent Posts
  • 人工智能知识 - 编程(二)
    这一篇承接人工智能知识 - 编程(一)。前一篇已经梳理 AI 训练与推理编程的横向工程栈;本篇进入重点框架详解与 ...
  • 人工智能知识 - 编程(一)
    这一篇专门处理 AI 训练、微调、推理与部署中的编程栈问题。前几篇分别讲了机器学习基础、任务版图、Transfo ...
  • 人工智能知识 - 智能体
    这一篇处理模型之外的系统层问题,包括上下文工程、Harness Engineering、检索增强生成(RAG)与 ...
  • 人工智能知识 - Transformers和大模型
    这一篇聚焦现代大模型主线,内容从 Transformer 架构出发,延伸到语言模型、多模态模型、预训练与微调,以 ...
  • 人工智能知识 - 主要应用领域
    这一篇从常用算法进入机器学习基础概念、经典机器学习与神经网络,重点讨论“模型如何被构造、训练、评估与正则化”。前 ...
  • 人工智能知识 - 算法和机器学习
    这一篇从常用算法进入机器学习基础概念、经典机器学习与神经网络,重点讨论“模型如何被构造、训练、评估与正则化”。前 ...
  • 人工智能知识 - 数学基础
    这一篇整理 AI 所需的数学基础,包括基础数学、线性代数、微积分与概率论统计。它回答的核心问题是:模型里的向量、 ...
  • 人工智能知识 - 简介
    这一篇作为整套 AI 总纲的导论,先回答更根本的问题,不急于进入公式和具体模型细节:什么叫智能,人工智能究竟在试 ...
  • 多语言敏感信息检测模型训练日志
    这篇文章记录一个多语言敏感信息识别项目的完整训练日志。它关注的是工程路径本身:原始 AI 合成语料如何被清洗成可 ...
  • DevPod on Kubernetes: turning devcontainer.json into a persistent remote workspace
    DevPod is an open source workspace manager ...
  • OpenClaw: Architecture, Components, and Deployment Notes
    Four Months, 343,000 Stars On November 24, 2025, ...
  • Replacing Docker Desktop with Colima on macOS
    Colima is one of the cleanest ways ...
  • Kubernetes GPU Sharing
    GPU sharing in Kubernetes depends on what ...
  • Investigating and Solving the Issue of Failed Certificate Request with ZeroSSL and Cert-Manager
    In this blog post, I will walk ...
  • A Comprehensive Study of Kotlin for Java Developers
    Introduction Purpose of the Study Understanding the Mo ...
  • LangChain: Architecture, LCEL, Agents, LangGraph, Retrieval, and Production Patterns
    LangChain is no longer best understood as ...
  • Kubernetes Migration
    Migrating a Kubernetes cluster from one cloud ...
  • Terraform: a practical guide to infrastructure as code
    Terraform is an infrastructure-as-code tool. You describ ...
TOPLINKS
  • Zitahli's blue 91 people like this
  • 梦中的婚礼 64 people like this
  • 汪静好 61 people like this
  • 那年我一岁 36 people like this
  • 为了爱 28 people like this
  • 小绿彩 26 people like this
  • 彩虹姐姐的笑脸 24 people like this
  • 杨梅坑 6 people like this
  • 亚龙湾之旅 1 people like this
  • 汪昌博 people like this
  • 2013年11月香山 10 people like this
  • 2013年7月秦皇岛 6 people like this
  • 2013年6月蓟县盘山 5 people like this
  • 2013年2月梅花山 2 people like this
  • 2013年淮阴自贡迎春灯会 3 people like this
  • 2012年镇江金山游 1 people like this
  • 2012年徽杭古道 9 people like this
  • 2011年清明节后扬州行 1 people like this
  • 2008年十一云龙公园 5 people like this
  • 2008年之秋忆 7 people like this
  • 老照片 13 people like this
  • 火一样的六月 16 people like this
  • 发黄的相片 3 people like this
  • Cesium学习笔记 90 people like this
  • IntelliJ IDEA知识集锦 59 people like this
  • Bazel学习笔记 38 people like this
  • 基于Kurento搭建WebRTC服务器 38 people like this
  • PhoneGap学习笔记 32 people like this
  • NaCl学习笔记 32 people like this
  • 使用Oracle Java Mission Control监控JVM运行状态 29 people like this
  • Ceph学习笔记 27 people like this
  • 基于Calico的CNI 27 people like this
Tag Cloud
ActiveMQ AspectJ CDT Ceph Chrome CNI Command Cordova Coroutine CXF Cygwin DNS Docker eBPF Eclipse ExtJS F7 FAQ Groovy Hibernate HTTP IntelliJ IO编程 IPVS JacksonJSON JMS JSON JVM K8S kernel LB libvirt Linux知识 Linux编程 LOG Maven MinGW Mock Monitoring Multimedia MVC MySQL netfs Netty Nginx NIO Node.js NoSQL Oracle PDT PHP Redis RPC Scheduler ServiceMesh SNMP Spring SSL svn Tomcat TSDB Ubuntu WebGL WebRTC WebService WebSocket wxWidgets XDebug XML XPath XRM ZooKeeper 亚龙湾 单元测试 学习笔记 实时处理 并发编程 彩姐 性能剖析 性能调优 文本处理 新特性 架构模式 系统编程 网络编程 视频监控 设计模式 远程调试 配置文件 齐塔莉
Recent Comments
  • xdemo on 人工智能知识 - 编程(二)
  • 杨松涛 on snmp4j学习笔记
  • kaka on Cilium学习笔记
  • JackZhouMine on Cesium学习笔记
  • 陈黎 on 通过自定义资源扩展Kubernetes
  • qg on Istio中的透明代理问题
  • heao on 基于本地gRPC的Go插件系统
  • 黄豆豆 on Ginkgo学习笔记
  • cloud on OpenStack学习笔记
  • 5dragoncon on Cilium学习笔记
  • Archeb on 重温iptables
  • C/C++编程:WebSocketpp(Linux + Clion + boostAsio) – 源码巴士 on 基于C/C++的WebSocket库
  • jerbin on eBPF学习笔记
  • point on Istio中的透明代理问题
  • G on Istio中的透明代理问题
  • 绿色记忆:Go语言单元测试和仿冒 on Ginkgo学习笔记
  • point on Istio中的透明代理问题
  • 【Maven】maven插件开发实战 – IT汇 on Maven插件开发
  • chenlx on eBPF学习笔记
  • Alex on eBPF学习笔记
  • CFC4N on eBPF学习笔记
  • 李运田 on 念爷爷
  • yongman on 记录一次KeyDB缓慢的定位过程
©2005-2026 Gmem.cc | Powered by WordPress | 京ICP备18007345号-2