Menu

  • Home
  • Work
    • AI
    • Cloud
      • Virtualization
      • IaaS
      • PaaS
    • Architecture
    • BigData
    • Python
    • Java
    • Go
    • C
    • C++
    • JavaScript
    • PHP
    • Others
      • Assembly
      • Ruby
      • Perl
      • Lua
      • Rust
      • XML
      • Network
      • IoT
      • GIS
      • Algorithm
      • Math
      • RE
      • Graphic
    • OS
      • Linux
      • Windows
      • Mac OS X
    • Database
      • MySQL
      • Oracle
    • Mobile
      • Android
      • IOS
    • Web
      • HTML
      • CSS
  • Life
    • Cooking
    • Travel
    • Gardening
  • Gallery
  • Video
  • Music
  • Essay
  • Home
  • Work
    • AI
    • Cloud
      • Virtualization
      • IaaS
      • PaaS
    • Architecture
    • BigData
    • Python
    • Java
    • Go
    • C
    • C++
    • JavaScript
    • PHP
    • Others
      • Assembly
      • Ruby
      • Perl
      • Lua
      • Rust
      • XML
      • Network
      • IoT
      • GIS
      • Algorithm
      • Math
      • RE
      • Graphic
    • OS
      • Linux
      • Windows
      • Mac OS X
    • Database
      • MySQL
      • Oracle
    • Mobile
      • Android
      • IOS
    • Web
      • HTML
      • CSS
  • Life
    • Cooking
    • Travel
    • Gardening
  • Gallery
  • Video
  • Music
  • Essay

人工智能知识 - 编程

17
Apr
2026

人工智能知识 - 编程

By Alex
/ in Uncategorized
0 Comments

这一篇专门处理 AI 训练、微调、推理与部署中的编程栈问题。前几篇分别讲了机器学习基础、任务版图、Transformer 与上下文工程;这一篇转向“代码层面的真实系统”:从 NumPy、数据管线、训练框架、分布式组件,到推理引擎、向量检索、服务化接口与工程辅助库,梳理一条从实验脚本到线上推理系统的完整技术链。

语言与数值计算底座

训练与推理系统最终都会落到“把数据变成数组/张量,然后在有限内存和带宽下完成大量数值运算”。这一层底座看起来朴素,但它决定了三个硬指标:吞吐(Throughput)、内存占用(Memory Footprint)与拷贝次数(Copy Count)。NumPy/Arrow/Parquet 一类组件在工程上通常承担训练数据管线、离线特征、评测集加工与推理输入输出的基础角色。

数组与科学计算
NumPy

NumPy(Numerical Python)定义了 Python 生态里最通用的数组语义:shape / stride / dtype / broadcasting。训练脚本里大量“数据预处理、采样、拼接、统计、离线特征生成、指标计算”都在直接使用这些概念,即使模型训练本身在 PyTorch/JAX 上完成。

安装通常直接用 pip install numpy。生产/研究环境常用 conda-forge 统一 BLAS 与二进制依赖,以减少 ABI 问题。

训练/推理工程中的 NumPy API 速查
主题 常用对象/函数 工程意义 典型用法
创建与视图 np.array / np.asarray / np.frombuffer 控制是否拷贝,决定数据是否可被零拷贝接入后续框架
Python
1
2
3
4
5
import numpy as np
 
x = np.asarray([1, 2, 3], dtype=np.int32)   # 尽量不拷贝
buf = b"\x01\x00\x00\x00\x02\x00\x00\x00"
y = np.frombuffer(buf, dtype=np.int32)      # 直接视图到 bytes buffer
重排与拼接 reshape / transpose / concatenate / stack reshape 通常是视图;transpose 多为视图但会改变 stride;拼接常产生新拷贝
Python
1
2
3
a = np.arange(12).reshape(3, 4)   # 视图
at = a.T                          # 视图(但 stride 变化)
b = np.concatenate([a, a], axis=0)  # 新数组(拷贝)
dtype 与数值稳定 astype / np.float16 / np.float32 / np.int32 / np.uint8 数据管线 dtype 决定 IO 体积与解码成本;训练阶段通常需要与框架 dtype 对齐
Python
1
2
x = np.random.randn(1024).astype(np.float32)
x16 = x.astype(np.float16)        # 体积减半,但精度下降
拷贝控制 np.ascontiguousarray / np.copy / ndarray.flags 训练/推理 kernel 往往偏好连续内存;必要时显式转 contiguous,避免隐式拷贝
Python
1
2
3
x = np.arange(12).reshape(3, 4).T
assert not x.flags["C_CONTIGUOUS"]
y = np.ascontiguousarray(x)       # 明确拷贝一次,换取后续算子稳定与更快
广播与矢量化 np.newaxis / np.expand_dims / np.broadcast_to 避免 Python for-loop,把循环移到 C 内核,降低解释器开销
Python
1
2
3
X = np.random.randn(8, 128).astype(np.float32)
mu = X.mean(axis=0)               # (128,)
X0 = X - mu[None, :]              # broadcasting
跨框架零拷贝互操作 numpy.from_dlpack 训练/推理中常需要在 NumPy、PyTorch、JAX、CuPy 之间交换数组;DLPack 是统一协议入口
Python
1
2
3
4
5
import numpy as np
 
# x 只要实现了 __dlpack__ 协议,NumPy 就能零拷贝视图接入(通常为 view)
# x 可能来自 torch / jax / cupy 等
arr = np.from_dlpack(x)
SciPy

SciPy(Scientific Python)在深度学习训练主循环中出现频率不高,但它在三个位置仍很常见:离线优化与拟合(例如曲线拟合、数值优化)、稀疏矩阵与图算法(构图、归一化、谱方法)、统计分布与检验(评测与数据分析)。工程上,SciPy 更适合被当作“离线数值工具箱”,而不是在线训练依赖。

SciPy 模块级入口
模块 解决的问题 训练/推理场景示例
scipy.sparse 稀疏矩阵表示与运算 大规模 one-hot / 图结构 / 线性代数预处理
scipy.optimize 数值优化、拟合、求根 离线拟合校准曲线、做超参搜索的子模块
scipy.special 特殊函数 实现某些损失/分布的参考实现或数值校验
scipy.stats 统计分布与检验 实验评估、A/B test 离线分析、置信区间计算
数组元信息与计算语义

训练与推理中常见的性能与正确性问题,经常来自数组元信息被误解:隐式拷贝、错误广播或错误 dtype 会在数据规模上来后被迅速放大。下列四个概念决定了“这块数据在内存里是什么形状、如何被解释、算子如何访问”。

shape

shape 是每一维的长度。训练数据的 shape 规划通常先于模型:batch 维、序列维、通道维的放置会直接影响 broadcasting、拼接策略与 kernel 访问模式。

Python
1
2
3
import numpy as np
 
x = np.zeros((batch, seq_len, hidden), dtype=np.float32)
stride

stride 描述“沿每一维移动 1 步,需要在底层 buffer 上跳过多少字节”。它解释了为什么很多 reshape/transpose 是 O(1) 视图,以及为什么某些看似简单的切片会导致后续算子不得不拷贝成 contiguous。stride 也是 DLPack 协议定义的核心之一。

Python
1
2
3
4
5
6
7
import numpy as np
 
a = np.arange(12, dtype=np.int32).reshape(3, 4)
print(a.shape, a.strides)    # (3, 4) (16, 4)  以 int32 为例,步长单位是字节
 
at = a.T
print(at.shape, at.strides)  # (4, 3) (4, 16)  转置后 stride 对调
dtype

dtype 决定了每个元素的解释方式与字节数。训练与推理中,dtype 的作用不止“精度高低”,还包括:IO 体积、缓存命中率、向量化指令路径、以及与下游框架的类型兼容性。实践上常见的约束是:数据管线侧用更紧凑的整型/字节型存储,进入训练前再一次性转换到框架需要的 dtype。

Python
1
2
3
4
import numpy as np
 
# 例:原始 token id 通常用 int32 或更小的无符号整型存储
ids = np.array([1, 2, 3, 4], dtype=np.int32)
broadcasting

broadcasting 是“不同 shape 的数组做逐元素运算时,如何对齐维度并隐式扩展”。它是把 Python 循环消掉的关键机制,但也可能引入隐藏的大中间张量或错误对齐。广播规则的工程实践通常围绕两件事:显式插入维度(None/newaxis)与显式对齐最后几维。

Python
1
2
3
4
5
6
7
8
9
10
import numpy as np
 
# (B, T, H) - (H,) -> (B, T, H)
X = np.random.randn(2, 3, 4).astype(np.float32)
mu = X.mean(axis=(0, 1))           # (H,)
X0 = X - mu                        # broadcasting
 
# 显式插维更直观
mu2 = mu[None, None, :]            # (1, 1, H)
X1 = X - mu2
表格与列式数据
Pandas

Pandas 在训练/推理工程里的定位更接近“数据分析与小中规模表格处理”,而不是大规模训练数据读取。它擅长做数据清洗、统计分析、对齐 join、特征表合并,以及输出可审计的中间结果(CSV/Parquet)。当数据规模接近或超过内存时,工程上通常会转向 Arrow/Polars 的流式与列式路径。

Pandas 关键接口
接口 作用 备注 示例
pd.read_parquet 读取 Parquet 到 DataFrame 适合小到中规模;大规模更推荐 Polars/Arrow dataset
Python
1
2
import pandas as pd
df = pd.read_parquet("train.parquet")
DataFrame.to_parquet 写 Parquet 用于落盘中间表、特征表、评测样本
Python
1
df.to_parquet("out.parquet", index=False)
DataFrame.to_numpy 转 NumPy 数组 copy=False 不保证零拷贝,混合 dtype 会触发类型提升与拷贝
Python
1
x = df[["a", "b"]].to_numpy(dtype="float32", copy=False)
Polars

Polars 的优势在于其 lazy 执行与 streaming:把计算表达成查询计划,先做优化(projection/predicate pushdown),再以批方式执行。对训练数据预处理而言,这意味着可以在不把全量数据 materialize 到内存的情况下完成筛选、投影、采样、分桶、写回 Parquet。

Polars 常用入口
入口 意义 示例
pl.scan_parquet lazy 方式扫描 Parquet,不立即 materialize
Python
1
2
import polars as pl
lf = pl.scan_parquet("data/*.parquet")
LazyFrame.filter/select/group_by 构建查询计划
Python
1
2
3
4
5
out = (
    lf.filter(pl.col("lang") == "zh")
      .select(["text", "label"])
      .collect(streaming=True)
)
LazyFrame.sink_parquet 把 lazy 结果直接写回 Parquet
Python
1
lf.sink_parquet("out.parquet")
PyArrow

PyArrow 是 Python 对 Arrow 内存格式与生态能力的主要入口。对 AI 训练/推理工程而言,Arrow 的核心价值是统一列式内存表示,并在 Pandas、Parquet、HF Datasets 与各种 IPC 路径之间提供高效桥梁。训练数据若最终要落成 Arrow/Parquet,通常建议在预处理阶段就尽量保留 Arrow Table / RecordBatch 语义,避免频繁在 Python 对象列表与 DataFrame 之间来回转换。

PyArrow 常用入口
对象/函数 用途 示例
pa.Table 列式表,适合大规模批处理
Python
1
2
import pyarrow as pa
tbl = pa.table({"x": [1, 2, 3], "y": ["a", "b", "c"]})
pyarrow.parquet.read_table 读 Parquet 为 Table
Python
1
2
import pyarrow.parquet as pq
tbl = pq.read_table("train.parquet", columns=["text", "label"])
Table.to_pandas 转 Pandas
Python
1
df = tbl.to_pandas()
Parquet

Parquet 是面向分析与批处理的列式文件格式。训练数据落盘选择 Parquet 的理由通常是:压缩比高、列裁剪成本低、能按列读取并减少 IO、天然支持 row group 作为大文件分块单位。工程上最常见的实践是:把可训练字段放在少数列里,并显式按任务选择 columns 读取,避免把无关字段搬进内存。

Python
1
2
3
4
5
6
import pyarrow.parquet as pq
 
tbl = pq.read_table(
  "train.parquet",
  columns=["text", "label"],   # 列裁剪
)
PyArrow IPC 与 memory_map

IPC(Inter-Process Communication)格式用于把 Arrow 的内存表示序列化为文件或流,并支持高效读写。对训练数据管线而言,一个关键工程点是:若输入源支持零拷贝读取(例如 memory map),则读出来的 batch 可以保持零拷贝路径,从而显著降低 CPU 端的内存分配与拷贝开销。

Python
1
2
3
4
5
6
7
8
import pyarrow as pa
 
# 以 memory map 方式打开文件,避免额外 read() 复制
source = pa.memory_map("dataset.arrow", mode="r")
reader = pa.ipc.open_file(source)
 
batch0 = reader.get_batch(0)
tbl = pa.Table.from_batches([batch0])
序列化与权重格式
通用序列化格式

训练与推理系统里最常见的序列化对象是:配置、元数据、索引与权重。配置层常见 JSON/YAML;权重层需要关注安全性与加载速度;跨进程/跨服务通信则常用 protobuf 这类 IDL 驱动格式。

JSON

JSON 适合可读性强的元数据与小体量配置,常用于数据集 manifest、评测记录与简单索引。

YAML

YAML 常用于训练配置,但它过于通用,本文只把它视为“配置载体”。具体配置系统(Hydra/OmegaConf)在后续章节展开。

pickle

pickle 能序列化 Python 对象,但它不适合用于不可信来源的权重与模型文件,因为反序列化会执行对象构造逻辑。工程上如果需要“安全的张量权重格式”,通常会优先选择 safetensors。

protobuf

protobuf 的优势是 schema 驱动与跨语言:用 .proto 定义消息结构,由 protoc 生成多语言代码。它常用于模型服务、日志/Tracing、任务队列与数据交换协议。

1
2
3
4
# 典型流程
# 1) 定义 schema: message Foo { ... }
# 2) protoc --python_out=. foo.proto
# 3) Python 中 import foo_pb2 并读写消息
safetensors

safetensors 是面向模型权重的安全、快速格式,设计目标是替代基于 pickle 的不安全权重存储。它的工程优势主要体现在三点:加载速度、零拷贝读取路径、以及避免反序列化执行任意代码。若训练产物需要在多环境分发或上线部署,safetensors 往往是默认优先选项。

safetensors.torch API 速查
函数 用途 示例
safetensors.torch.save_file 保存 tensor dict
Python
1
2
3
4
5
from safetensors.torch import save_file
import torch
 
tensors = {"w": torch.randn(2, 3)}
save_file(tensors, "model.safetensors")
safetensors.torch.load_file 加载为 CPU tensor dict
Python
1
2
3
from safetensors.torch import load_file
 
tensors = load_file("model.safetensors")
safetensors.safe_open 按需读取,支持只取部分 key
Python
1
2
3
4
from safetensors import safe_open
 
with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    w = f.get_tensor("w")
GGUF

GGUF(GGML Universal File)是 llama.cpp 生态的权重文件格式,目标是单文件、可扩展、可 memory-map,并携带足够的 KV 元数据支持推理 runtime 直接加载。若部署路线包含 llama.cpp / Ollama 一类本地推理栈,训练产物通常需要在 Hugging Face 权重与 GGUF 之间做一次转换与量化。

Shell
1
2
# llama.cpp 仓库提供 convert_*.py 脚本把 Hugging Face 权重转换为 GGUF
python convert_hf_to_gguf.py --outfile out.gguf /path/to/hf_model_dir
数据管线与预处理组件

训练与推理系统的性能瓶颈经常不在模型前向,而在输入:数据从磁盘/对象存储进入进程、被解码与清洗、被分词与组 batch、再进入 GPU。一个可用的数据管线需要同时满足三件事:吞吐(喂满设备)、一致性(可复现、可回放)、可运维(能增量、能恢复、能追溯)。

本节把数据管线拆成四层:读取抽象(Dataset/DataLoader)、存储与后端(Arrow/Parquet/WebDataset/LMDB/HDF5/mmap)、文本入口(tokenizer 与中文预处理)、以及离线预处理模式(多进程 + shard 流式写入)。每一层都以“如何写代码把数据喂进训练/推理”为主线。

数据集抽象与读取方式
PyTorch 数据接口

PyTorch 的数据读取围绕两个 Dataset 协议展开:map-style(可随机访问)与 iterable-style(顺序流式)。DataLoader 负责把 Dataset 变成可迭代的 batch 流,并提供多进程 worker、prefetch、pin memory 等机制。实践中,Dataset 负责“怎样得到一个样本”,DataLoader 负责“怎样并行、怎样组批、怎样把样本送进设备”。

安装:

Shell
1
pip install torch
Dataset

map-style Dataset 的约束很简单:实现 __getitem__ 与 __len__。它适合“样本天然有索引”的存储,例如:一个样本一行的 Parquet/Arrow、固定条目 LMDB、按文件名索引的图像文件夹。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from torch.utils.data import Dataset
 
class MyDataset(Dataset):
    def __init__(self, paths):
        self.paths = paths
 
    def __len__(self):
        return len(self.paths)
 
    def __getitem__(self, idx):
        # 返回 dict/tuple 都可以,关键是 collate_fn 能处理
        path = self.paths[idx]
        with open(path, "rb") as f:
            blob = f.read()
        return {"path": path, "blob": blob}
IterableDataset

IterableDataset 只要求实现 __iter__,更适合训练数据远大于本地磁盘、需要顺序扫描或在线生成的场景(对象存储流式、Kafka/队列、WebDataset tar 流、动态合成数据)。使用多 worker 时,必须自行做切分(worker shard),否则每个 worker 都会重复遍历同一份流。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch
from torch.utils.data import IterableDataset
 
class LineStream(IterableDataset):
    def __init__(self, filename):
        self.filename = filename
 
    def __iter__(self):
        worker = torch.utils.data.get_worker_info()
        wid = 0 if worker is None else worker.id
        wnum = 1 if worker is None else worker.num_workers
 
        with open(self.filename, "r", encoding="utf-8") as f:
            for i, line in enumerate(f):
                if (i % wnum) != wid:
                    continue
                yield {"text": line.rstrip("\n")}
DataLoader

DataLoader 的关键价值在于把“批处理策略”和“并行读取策略”显式化。它的典型参数包括: batch_size、 shuffle、 num_workers、 collate_fn、 pin_memory、 prefetch_factor、 persistent_workers。这些参数共同决定吞吐、延迟、内存占用与稳定性。

对象 / API 用途 关键参数 典型用法
torch.utils.data.DataLoader 把 Dataset/IterableDataset 变成 batch 流 batch_size / num_workers / collate_fn / pin_memory
Python
1
2
3
4
5
6
7
8
9
10
from torch.utils.data import DataLoader
 
loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    persistent_workers=True,
)
torch.utils.data.get_worker_info 在 IterableDataset 中分片 worker.id / worker.num_workers
Python
1
2
3
worker = torch.utils.data.get_worker_info()
wid = 0 if worker is None else worker.id
wnum = 1 if worker is None else worker.num_workers
Hugging Face Datasets

Datasets 把“数据集”抽象成 Arrow-backed 的 Dataset/ DatasetDict,并提供统一的加载(Hub/本地/通用 builder)、变换(map/filter)、以及落盘(save_to_disk / parquet)的工具链。它在大模型训练管线中的典型用法是:用一次离线 map 把清洗与分词做掉,输出可 memory-map 的 Arrow/Parquet,再用 PyTorch DataLoader 做高吞吐训练。

安装:

Shell
1
pip install datasets

加载:支持 Hub 数据集、目录内 CSV/JSON/Parquet 文件、以及通用 builder(例如 json / parquet / webdataset)。

Python
1
2
3
4
5
6
7
8
9
10
from datasets import load_dataset
 
# Hub 数据集
ds = load_dataset("allenai/c4", "en", split="train")
 
# 本地 Parquet 目录
ds = load_dataset("parquet", data_files={"train": ["./data/train-*.parquet"]})["train"]
 
# streaming:不落盘,边读边迭代
ds_stream = load_dataset("json", data_files="s3://bucket/data.jsonl", streaming=True)["train"]
API 用途 关键参数 用法示例
datasets.load_dataset 加载 Hub / 本地 / 通用 builder data_files / split / streaming / num_proc
Python
1
2
3
4
5
6
ds = load_dataset(
    "json",
    data_files="train.jsonl",
    split="train",
    num_proc=8,
)
Dataset.map 清洗/分词/特征工程 batched / num_proc / remove_columns
Python
1
2
3
4
5
6
7
8
9
def tok(batch):
    return tokenizer(batch["text"])
 
ds2 = ds.map(
    tok,
    batched=True,
    num_proc=8,
    remove_columns=["text"],
)
Dataset.save_to_disk 保存 Arrow 数据集目录 输出目录
Python
1
ds2.save_to_disk("./out/ds_tok")
datasets.load_from_disk 恢复已保存数据集 输入目录
Python
1
2
from datasets import load_from_disk
ds2 = load_from_disk("./out/ds_tok")
大规模数据读取后端

当单机磁盘与单个文件格式无法满足吞吐或并行度时,训练数据会落到“更工程化的后端”。最常见的三类:tar shard(WebDataset)、KV store(LMDB)、块存储/层次结构(HDF5)。它们的核心不是“更通用”,而是“更匹配训练读取模式”。

WebDataset

WebDataset 以 tar shard 作为基本载体,强调流式读取与链式 pipeline。它常用于大规模图像/视频/音频/多模态训练:样本被打包成许多 tar 文件(shard),训练时按 shard 流式拉取、解码、组 batch。安装:

Shell
1
pip install webdataset
API 用途 用法示例
webdataset.WebDataset 构建数据管线(DataPipeline + 流式操作)
Python
1
2
3
4
5
6
7
8
9
import webdataset as wds
 
dataset = (
    wds.WebDataset("shards/data-{000000..000127}.tar")
      .shuffle(10000)
      .decode("pil")
      .to_tuple("jpg", "txt")
      .batched(64)
)
FluidInterface.with_epoch 限制一个 epoch 的样本数(类似 islice)
Python
1
dataset = dataset.with_epoch(1_000_000)
LMDB

LMDB 是 memory-mapped 的 KV store,优势是读性能稳定、并发读友好,适合“样本就是 key->value”的训练数据(尤其是大量小对象)。LMDB 的关键约束是事务:读写必须在 transaction 内进行,并且从 LMDB 返回的值可能直接指向 mmap 区域,transaction 结束后不可继续使用该指针。

安装:

Shell
1
pip install lmdb
API 用途 用法示例
lmdb.open 打开/创建环境(Environment)
Python
1
2
3
4
5
6
7
8
import lmdb
env = lmdb.open(
    "./data.lmdb",
    map_size=1024**4,  # 1TB 逻辑上限
    subdir=False,
    readonly=False,
    lock=True,
)
Environment.begin 开启事务(Transaction)
Python
1
2
with env.begin(write=True) as txn:
    txn.put(b\"k\", b\"v\")
Transaction.get / put 读写 KV
Python
1
2
with env.begin(write=False) as txn:
    v = txn.get(b\"k\")
Transaction.cursor 迭代遍历
Python
1
2
3
4
with env.begin() as txn:
    with txn.cursor() as cur:
        for k, v in cur:
            ...
HDF5

HDF5 适合块状数组与层次结构数据。训练里常见于科学计算数据、时序/医疗影像、以及需要 chunked 存储与压缩的场景。Python 侧最常用的是 h5py。安装:

Shell
1
pip install h5py
API 用途 用法示例
h5py.File 打开/创建文件
Python
1
2
import h5py
f = h5py.File(\"data.h5\", \"r\")
Group.create_dataset 创建 dataset(可设 chunks/compression)
Python
1
2
3
4
5
6
7
d = f.create_dataset(
    \"x\",
    shape=(0, 4096),
    maxshape=(None, 4096),
    chunks=(1024, 4096),
    dtype=\"int32\",
)
Dataset.resize 追加写入(配合 maxshape)
Python
1
2
3
n = d.shape[0]
d.resize((n + batch.shape[0], 4096))
d[n:] = batch
内存映射

memory map 的价值在于把“磁盘 IO + 反序列化”变成“按页缺页加载”:进程只在访问到某段数据时才触发读取,并允许多个进程共享同一份文件缓存。它常用于 Arrow IPC、NumPy 的大数组、以及只读数据集的多 worker 读取。

Python
1
2
3
4
5
import numpy as np
 
# 以 memmap 读一个巨大 float32 数组(示例)
arr = np.memmap(\"x.bin\", dtype=np.float32, mode=\"r\")
# arr[i] 的访问才会触发对应页的加载
文本分词与 tokenizer 组件

分词(Tokenization)有两种工程形态:在线分词(推理时对用户输入分词)与离线分词(训练前把语料转成 token id)。离线分词的目标是把训练阶段的 CPU 开销外移:训练时直接读取 input_ids/ attention_mask 之类张量,避免每步都做字符串处理。

tokenizers

Tokenizers 是 Rust 实现的 tokenizer 库,提供训练、编码、解码以及 padding/truncation 等预处理步骤。它面向生产:同一套 tokenizer 可以被训练脚本、离线预处理作业与线上服务复用。

安装:

Shell
1
pip install tokenizers
API 用途 用法示例
tokenizers.Tokenizer tokenizer 管线对象
Python
1
2
3
4
5
6
7
8
9
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
 
tok = Tokenizer(BPE(unk_token=\"[UNK]\"))
tok.pre_tokenizer = Whitespace()
trainer = BpeTrainer(vocab_size=32000, special_tokens=[\"[UNK]\", \"[PAD]\", \"[BOS]\", \"[EOS]\"])
tok.train([\"corpus.txt\"], trainer)
Tokenizer.encode / encode_batch 编码为 token id
Python
1
2
3
4
5
enc = tok.encode(\"hello world\")
ids = enc.ids
 
encs = tok.encode_batch([\"a\", \"b\"])
batch_ids = [e.ids for e in encs]
Tokenizer.decode / decode_batch 解码
Python
1
2
text = tok.decode(ids)
texts = tok.decode_batch(batch_ids)
SentencePiece

SentencePiece 既是一套 tokenizer 算法(BPE / Unigram LM),也是训练与推理工具链。它的工程优势在于“直接在原始文本上训练”,不依赖预分词,对无空格语言更友好。

安装:

Shell
1
pip install sentencepiece
API 用途 用法示例
sentencepiece.SentencePieceTrainer.Train 训练 spm 模型
Python
1
2
3
4
5
6
7
8
9
import sentencepiece as spm
 
spm.SentencePieceTrainer.Train(
    input=\"corpus.txt\",
    model_prefix=\"spm\",
    vocab_size=32000,
    model_type=\"bpe\",
    character_coverage=0.9995,
)
sentencepiece.SentencePieceProcessor 加载并编码/解码
Python
1
2
3
4
5
6
import sentencepiece as spm
 
sp = spm.SentencePieceProcessor()
sp.load(\"spm.model\")
ids = sp.encode(\"你好世界\", out_type=int)
text = sp.decode(ids)
tiktoken

tiktoken 是 OpenAI 开源的 BPE tokenizer 实现,常用于与 OpenAI 模型兼容的 token 计数与编码。它提供按 encoding 名称或按模型名选择 encoding 的接口。

安装:

Shell
1
pip install tiktoken
API 用途 用法示例
tiktoken.get_encoding 按 encoding 名称获取
Python
1
2
import tiktoken
enc = tiktoken.get_encoding(\"o200k_base\")
tiktoken.encoding_for_model 按模型名获取
Python
1
enc = tiktoken.encoding_for_model(\"gpt-4o\")
Encoding.encode / decode 编码/解码
Python
1
2
ids = enc.encode(\"hello world\")
text = enc.decode(ids)
中文文本预处理与分词工具

在大模型训练中,中文通常直接走子词/字节级 tokenizer;但在传统 NLP、搜索、实体抽取、以及“数据清洗与规范化”阶段,中文分词与繁简转换仍是高频工程环节。

jieba

安装:

Shell
1
pip install jieba
API 用途 用法示例
jieba.cut 分词(generator)
Python
1
2
import jieba
tokens = list(jieba.cut(\"我爱自然语言处理\"))
jieba.lcut 分词(list)
Python
1
tokens = jieba.lcut(\"我爱自然语言处理\")
jieba.cut_for_search 搜索引擎模式(更细粒度)
Python
1
tokens = list(jieba.cut_for_search(\"南京市长江大桥\"))
jieba.add_word 动态加入词典
Python
1
jieba.add_word(\"大语言模型\")
jieba.load_userdict 加载用户词典
Python
1
jieba.load_userdict(\"userdict.txt\")
opencc-python

opencc-python 是早期 OpenCC 的 Python wrapper,版本较旧。现代工程更常用维护更活跃的 OpenCC 包。这里保留两条安装路径:兼容旧 wrapper 与直接使用 OpenCC。

Shell
1
2
3
4
5
# 旧 wrapper(较旧)
pip install opencc-python
 
# 推荐:OpenCC(维护更活跃)
pip install OpenCC

Python
1
2
3
4
# OpenCC 示例:繁转简
from opencc import OpenCC
cc = OpenCC("t2s")
out = cc.convert("今天天氣不錯")
批处理构造与样本拼接

batch 构造的关键目标是把“变长样本”转成“规则张量”,并尽量减少无效计算。四个高频机制:collator(怎么把样本列表变成 batch)、padding(对齐长度)、packing(把多个短样本拼到一个长序列里)、masking(构造 loss 的可学习位置)。

batch 构造机制
collator

collator 通常以 collate_fn 形式接入 DataLoader,负责把 List[sample] 转为张量 batch。

padding

padding 的核心是把不同长度序列补齐,并同步产生 attention_mask。对于 encoder 任务,padding 的位置通常 mask 掉注意力;对于 decoder 任务,还要考虑因果 mask 与 label mask。

packing

packing 把多个短序列拼接到固定长度 block 中,减少 padding 浪费。它适合预训练与指令微调中的“很多短样本”场景,但需要正确构造 label 与分段边界(例如用 special token 分隔)。

masking

masking 用来指定 loss 只在哪些位置计算,例如 causal LM 的 label shift、MLM 的随机 mask、SFT 中把提示词部分的 label 设为 ignore。

多模态样本组织与 processor

多模态模型往往通过 processor 把文本 tokenizer 与视觉/音频预处理封装在一起。工程上需要保证:离线预处理与线上推理使用同一套 processor 配置,避免“训练时的输入分布”和“推理时的输入分布”不一致。

合成数据与数据增强工具

合成数据通常不替代真实数据,而是用来补齐边界覆盖:格式多样性、语言多样性、脏数据模式、罕见实体组合、以及隐私合规场景下的脱敏替身。合成数据要能回放:生成种子、版本、配置都要进入产物元数据。

结构化合成数据生成
Faker

安装:

Shell
1
pip install Faker
names-dataset

安装:

Shell
1
pip install names-dataset
pycountry

安装:

Shell
1
pip install pycountry
库 核心 API 用法示例
Faker Faker / Faker.seed
Python
1
2
3
4
from faker import Faker
Faker.seed(42)
fake = Faker(locale=\"zh_CN\")
row = {\"name\": fake.name(), \"addr\": fake.address()}
names-dataset NameDataset
Python
1
2
3
from names_dataset import NameDataset
nd = NameDataset()
info = nd.search(\"Zoe\")
pycountry pycountry.countries / lookup
Python
1
2
3
import pycountry
cn = pycountry.countries.lookup(\"China\")
langs = pycountry.languages.get(alpha_2=\"zh\")
语言识别与类型检测库

语言识别与类型检测经常用于预处理阶段的路由:多语言混杂数据的分桶、代码/文档/日志的分流、以及不同清洗规则的选择。工程重点是“低成本、可解释、可复现”,而不是把检测做得像论文 benchmark 一样精细。

自然语言识别
库 安装 核心 API 用法示例
langdetect
Shell
1
pip install langdetect
detect / detect_langs / DetectorFactory.seed
Python
1
2
3
4
from langdetect import detect, detect_langs, DetectorFactory
DetectorFactory.seed = 0
lang = detect(\"Hello world\")
langs = detect_langs(\"Otec matka syn.\")
fastText lid
Shell
1
pip install fasttext
fasttext.load_model / model.predict
Python
1
2
3
4
5
6
7
8
9
import fasttext
from huggingface_hub import hf_hub_download
 
model_path = hf_hub_download(
    repo_id=\"facebook/fasttext-language-identification\",
    filename=\"model.bin\",
)
model = fasttext.load_model(model_path)
labels, probs = model.predict(\"Hello, world!\")
lingua-language-detector
Shell
1
pip install lingua-language-detector
LanguageDetectorBuilder
Python
1
2
3
from lingua import LanguageDetectorBuilder
detector = LanguageDetectorBuilder.from_all_languages().build()
lang = detector.detect_language_of(\"Hello world\")
文件类型检测

文件类型检测的目标是“尽量早发现二进制/压缩/不支持格式”,避免把不可解析内容送进后续清洗链路。python-magic 依赖底层 libmagic,需要系统依赖到位。

库 / API 安装 核心函数 用法示例
python-magic
Shell
1
2
3
pip install python-magic
# Debian/Ubuntu:
sudo apt-get install libmagic1
magic.from_file / magic.from_buffer / magic.Magic
Python
1
2
3
import magic
mime = magic.from_file(\"a.pdf\", mime=True)
sig = magic.from_buffer(open(\"a.pdf\", \"rb\").read(2048))
源码语言检测

源码语言检测用于代码数据集清洗(按语言分桶、去除 vendored/generated)与语法级预处理(解析 AST、提取符号)。Pygments 偏启发式 lexer;tree-sitter 提供结构化解析(AST)。

库 安装 核心 API 用法示例
Pygments
Shell
1
pip install pygments
get_lexer_for_filename / get_lexer_by_name
Python
1
2
from pygments.lexers import get_lexer_for_filename
lexer = get_lexer_for_filename(\"a.py\")
tree-sitter
Shell
1
pip install tree-sitter tree-sitter-python
Language / Parser
Python
1
2
3
4
5
6
from tree_sitter import Language, Parser
import tree_sitter_python as tspython
 
PY_LANGUAGE = Language(tspython.language())
parser = Parser(PY_LANGUAGE)
tree = parser.parse(b\"print('hi')\\n\")

“Linguist 类方案”通常指 GitHub Linguist 的启发式:文件扩展名 + shebang + 内容特征 + 语言冲突消歧规则。实际工程里常把这类规则作为“分桶的第一步”,再对少量不确定样本做更重的解析。

大规模离线预处理模式

离线预处理的目标是把昂贵的 CPU 工作(解析、清洗、分词、格式规范化)集中到一次批处理作业中,并把结果写成稳定、可复用、可 memory-map 的 shard。一个可运维的离线预处理作业至少要具备:可重跑、可断点续跑、单 shard 失败不影响全局、输出具备 manifest。

多进程 Pool

多进程的核心收益在于绕过 Python GIL,把 CPU 密集的分词与解析并行化。进程间传输大对象会迅速放大序列化开销,因此更稳妥的做法是把输入切成可流式读取的小对象,把输出写成 shard。

shard 流式写入

shard 是离线预处理的基本单元:每个 shard 控制大小(例如 512MB~2GB)、可单独校验、可单独重跑。WebDataset 的 shard 是 tar;HF Datasets/Arrow 的 shard 是 parquet/arrow 文件集合;LMDB/HDF5 则是数据库/容器文件。

内存安全与失败恢复
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import json
import os
from multiprocessing import Pool
 
def process_line(line: str) -> str:
    # 这里放:解析、清洗、分词、规范化
    obj = json.loads(line)
    obj[\"text\"] = obj[\"text\"].strip()
    return json.dumps(obj, ensure_ascii=False)
 
def write_shards(in_path: str, out_dir: str, shard_lines: int = 200_000, workers: int = 8):
    os.makedirs(out_dir, exist_ok=True)
 
    def shard_path(i: int) -> str:
        return os.path.join(out_dir, f\"shard-{i:06d}.jsonl\")
 
    with open(in_path, \"r\", encoding=\"utf-8\") as f, Pool(processes=workers) as pool:
        shard_idx = 0
        buf = []
        for out in pool.imap(process_line, f, chunksize=256):
            buf.append(out)
            if len(buf) >= shard_lines:
                tmp = shard_path(shard_idx) + \".tmp\"
                with open(tmp, \"w\", encoding=\"utf-8\") as wf:
                    wf.write(\"\\n\".join(buf) + \"\\n\")
                os.replace(tmp, shard_path(shard_idx))  # 原子替换,避免半写文件
                buf.clear()
                shard_idx += 1
 
        if buf:
            tmp = shard_path(shard_idx) + \".tmp\"
            with open(tmp, \"w\", encoding=\"utf-8\") as wf:
                wf.write(\"\\n\".join(buf) + \"\\n\")
            os.replace(tmp, shard_path(shard_idx))
基础训练框架

基础训练框架(Foundational Training Framework)提供三类不可替代的底座能力:张量与设备执行(Tensor & Device Execution)、自动求导(Automatic Differentiation)、以及训练循环所需的基础组件(模块、优化器、数据加载、序列化)。上层训练框架可以封装流程,但底层的梯度、显存与 kernel 行为最终仍由这一层决定。

PyTorch

PyTorch 的编程模型是“Python 先行的动态图(Dynamic Graph)+ 胶带式自动求导(Tape-based Autograd)”。这使训练循环天然可调试:前向是普通 Python 代码,反向由 autograd 记录并回放。工程上更重要的是:PyTorch 生态已经把训练、分布式、编译优化与部署衔接做成了一套可组合部件,既能写研究型循环,也能写生产训练栈。

安装建议遵循官方安装页的选择器:CPU-only 与 CUDA 版本的 pip/conda 命令需要与目标机器驱动、CUDA 版本匹配。CPU 环境通常可以直接 pip install torch;GPU 环境应按官方给出的 index-url 安装对应 CUDA wheel。安装后验证建议至少覆盖两点:能创建张量以及GPU 可见(如适用)。

PyTorch installation (pattern)
Shell
1
2
3
4
5
6
# CPU (typical)
pip install -U torch
 
# CUDA: use the command generated by https://pytorch.org/get-started/locally/
# It typically looks like:
# pip install -U torch --index-url https://download.pytorch.org/whl/cu12x

PyTorch quick verify
Python
1
2
3
4
5
import torch
print(torch.__version__)
x = torch.randn(2, 3)
print(x.shape)
print("cuda_available:", torch.cuda.is_available())
核心抽象

三件事决定 PyTorch 训练代码的结构:张量(Tensor)承载数据与参数、autograd 负责梯度、 nn.Module 组织可训练子图并提供参数与缓冲区的可追踪结构。

Tensor

Tensor 是所有计算的基本载体。训练相关的关键元信息有四类:形状(shape)、数值类型(dtype)、设备(device)与梯度开关(requires_grad)。其中设备与 dtype 会直接改变 kernel 路径与显存占用;requires_grad 决定该张量是否会成为计算图的一部分。

对象/函数 用途 典型用法
torch.tensor 从 Python 对象构造张量
Python
1
x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
torch.randn 随机初始化(参数/输入常用)
Python
1
w = torch.randn(768, 768, device="cuda")
Tensor.to 设备/精度迁移
Python
1
x = x.to(device="cuda", dtype=torch.bfloat16)
Tensor.requires_grad_ 把张量标为需要梯度
Python
1
x = x.requires_grad_(True)
torch.no_grad 推理/评估时关闭梯度跟踪
Python
1
2
with torch.no_grad():
    y = model(x)
torch.inference_mode 比 no_grad 更强的推理模式(更少开销)
Python
1
2
with torch.inference_mode():
    y = model(x)
autograd

autograd 的编程要点是“图从前向构建,梯度在反向回传”。多数训练代码只用到 loss.backward() + optimizer.step(),但当需要更复杂的梯度形态(如多目标、梯度惩罚、二阶项)时,就会显式使用 torch.autograd.grad。

对象/函数 用途 典型用法
Tensor.backward 反向传播,累计梯度到 leaf 参数
Python
1
2
3
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
torch.autograd.grad 函数式求梯度,返回梯度张量而不写入 .grad
Python
1
grads = torch.autograd.grad(loss, model.parameters(), create_graph=False)
torch.autograd.Function 自定义前向/反向(自定义算子或特殊梯度)
Python
1
2
3
4
5
class MyFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x): ...
    @staticmethod
    def backward(ctx, grad_out): ...
nn.Module

nn.Module 把“可训练参数 + 前向逻辑 + 子模块拓扑”打包成可组合单元。训练与部署的关键接口是 state_dict():它让参数与 buffer 的保存/加载成为稳定约定,从而把“代码结构”与“权重文件”解耦。

Minimal nn.Module
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch
import torch.nn as nn
 
class MLP(nn.Module):
    def __init__(self, in_dim: int, hidden: int, out_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, out_dim),
        )
 
    def forward(self, x):
        return self.net(x)
执行与编译

PyTorch 的默认执行是 eager;编译与优化通过 torch.compile 把前向(以及可捕获的反向)分段转换成可优化子图,再由后端生成更高效的执行计划。实践中它最适合用于“训练循环稳定、算子形态固定”的热点路径;高度动态的 Python 控制流与形状多态会降低捕获与复用效果。

动态图

动态图让训练循环具备“逐步可观察性”:每一步前向的张量都可以被打印、断点或插桩;复杂条件分支也能自然表达。这种表达力的代价是:若不做编译或算子融合,小算子密集的模型可能被 Python 调度开销限制。

训练循环

训练循环的工程要点集中在三个地方:梯度清零策略、设备放置与数据搬运、以及异常时可恢复的 checkpoint。下面是一段最小可用的循环骨架。

Minimal training loop (PyTorch)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch
from torch.utils.data import DataLoader
 
device = "cuda" if torch.cuda.is_available() else "cpu"
model = MLP(128, 256, 10).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
loss_fn = torch.nn.CrossEntropyLoss()
 
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4, pin_memory=True)
 
model.train()
for step, (x, y) in enumerate(loader):
    x = x.to(device, non_blocking=True)
    y = y.to(device, non_blocking=True)
 
    opt.zero_grad(set_to_none=True)
    logits = model(x)
    loss = loss_fn(logits, y)
    loss.backward()
    opt.step()
torch.compile

最常见的接入方式是在训练开始前包一层: model = torch.compile(model)。在大模型场景下,编译往往与 AMP、FlashAttention、FSDP/ZeRO 等一起协作:编译负责降低算子调度与 kernel 生成开销,其它组件负责显存、带宽与通信。

torch.compile (typical)
Python
1
2
model = MLP(128, 256, 10).to(device)
model = torch.compile(model)  # compile after moving to device
训练基础设施

训练基础设施是把“能跑”变成“可长期迭代”的关键层:数据加载决定吞吐上限,AMP 决定显存与速度,checkpoint 决定可恢复性;而 CRF layer 代表一类“训练框架之外的结构化输出层”组件,常见于序列标注。

DataLoader

DataLoader 是 PyTorch 数据管线的核心抽象:负责 batch、shuffle、collate、多进程加载与 pinned memory。它与 Dataset 的分工边界清晰:Dataset 负责“样本怎样被索引/生成”,DataLoader 负责“样本如何被并发读取并组织成 batch”。

参数 作用 经验用法
batch_size 每步样本数 受显存与序列长度影响,常与 gradient accumulation 联动
num_workers 数据加载进程数 CPU 预处理重时提高;过高会导致调度/内存压力
pin_memory 将 batch 放入 pinned memory GPU 训练通常打开,配合 non_blocking=True
collate_fn 自定义 batch 拼接 NLP 里常做 padding/packing;多模态里做对齐与打包
AMP

自动混合精度(Automatic Mixed Precision, AMP)通过在合适的算子上使用 FP16/BF16,并在数值敏感的算子上保留 FP32,换取吞吐与显存的提升。PyTorch 提供 autocast 与 GradScaler 组合来降低接入成本。

AMP (torch.cuda.amp)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
scaler = torch.cuda.amp.GradScaler()
 
for x, y in loader:
    x = x.to(device, non_blocking=True)
    y = y.to(device, non_blocking=True)
    opt.zero_grad(set_to_none=True)
 
    with torch.cuda.amp.autocast(dtype=torch.float16):
        logits = model(x)
        loss = loss_fn(logits, y)
 
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()
checkpoint

PyTorch 的 checkpoint 约定是保存 state_dict:至少包含模型参数与优化器状态,必要时加上 scheduler、步数与随机种子。训练恢复的核心是把“运行状态”视为数据,而不是只保存权重文件。

Checkpoint save/load (PyTorch)
Python
1
2
3
4
5
6
7
8
9
10
11
ckpt = {
    "step": step,
    "model": model.state_dict(),
    "optimizer": opt.state_dict(),
}
torch.save(ckpt, "ckpt.pt")
 
ckpt = torch.load("ckpt.pt", map_location=device)
model.load_state_dict(ckpt["model"])
opt.load_state_dict(ckpt["optimizer"])
step = ckpt["step"]
CRF layer

线性链条件随机场(Linear-chain Conditional Random Field, CRF)是序列标注(Sequence Labeling)里常见的结构化输出层:它不改变 encoder 的表示学习方式,但把标签预测从“逐 token 独立分类”改成“序列级全局最优路径”,以显式建模相邻标签转移约束。

pytorch-crf / torchcrf

pytorch-crf 是常见的第三方 CRF layer 实现,API 以一个 CRF 模块为中心:前向返回 log-likelihood,训练时通常取负作为 loss;解码使用 decode 输出最优标签序列。

Install CRF layer
Shell
1
pip install pytorch-crf

CRF layer (torchcrf) minimal usage
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch
from torchcrf import CRF
 
num_tags = 5
crf = CRF(num_tags)
 
seq_len, batch = 3, 2
emissions = torch.randn(seq_len, batch, num_tags)          # (seq_len, batch, num_tags)
tags = torch.tensor([[0, 1], [2, 4], [3, 1]], dtype=torch.long)  # (seq_len, batch)
 
log_likelihood = crf(emissions, tags)  # summed over batch by default
loss = -log_likelihood
 
best_paths = crf.decode(emissions)     # list[list[int]]
TensorFlow

TensorFlow 的核心执行模型是 eager + graph:默认 eager 便于调试与交互, @tf.function 将 Python 函数 trace 编译成图执行以提升性能与可移植性。训练循环的两条主线分别是:Keras Model.fit 的高层接口,以及基于 tf.GradientTape 的自定义循环。

TensorFlow installation (pip)
Shell
1
2
3
4
5
# CPU-only
pip install -U tensorflow-cpu
 
# Default package (CPU/GPU depends on platform; follow the official pip install guide)
pip install -U tensorflow

TensorFlow quick verify
Python
1
2
3
import tensorflow as tf
print(tf.__version__)
print(tf.config.list_physical_devices("GPU"))
TensorFlow 计算图与 tf.data

tf.data.Dataset 是 TensorFlow 的输入管线中心:通过 map、 batch、 shuffle、 prefetch 把数据处理做成可并行、可流式的图。生产训练中,输入管线经常成为 GPU 利用率的上限瓶颈,因此应优先把可并行预处理移入 tf.data 体系内。

tf.data input pipeline
Python
1
2
3
4
5
6
7
import tensorflow as tf
 
ds = tf.data.Dataset.from_tensor_slices((features, labels))
ds = ds.shuffle(10000)
ds = ds.map(preprocess_fn, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.batch(128, drop_remainder=True)
ds = ds.prefetch(tf.data.AUTOTUNE)

Custom training step with GradientTape + tf.function
Python
1
2
3
4
5
6
7
8
9
10
import tensorflow as tf
 
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss = loss_fn(y, logits)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
Keras 3

Keras 3 是多后端(Multi-backend)深度学习框架:同一份 Keras 代码可以运行在 TensorFlow、JAX、PyTorch 后端上;后端通过 KERAS_BACKEND 环境变量或本地配置文件选择,并且必须在 import Keras 之前确定。对工程而言,这一设计把“模型代码”与“执行后端”分离:可以在同一套高层 API 下切换后端能力,例如在 JAX 上获得更强的编译与 SPMD 体系,或在 PyTorch 上复用既有生态。

Keras 3 installation
Shell
1
pip install -U keras

Select backend before importing keras
Python
1
2
3
4
5
import os
os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow" / "torch"
 
import keras
print("backend:", keras.backend.backend())

Keras 3 minimal training
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
import keras
from keras import layers
 
model = keras.Sequential([
    layers.Dense(256, activation="gelu"),
    layers.Dense(10),
])
model.compile(
    optimizer=keras.optimizers.AdamW(learning_rate=3e-4),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
model.fit(train_x, train_y, batch_size=128, epochs=3)
Keras 3 的多后端架构

多后端意味着一组实践约束:后端不能在 import 后热切换;部分底层行为(例如随机数、分布式、数值细节)仍由后端决定;性能调优最终仍会落回后端的编译器与 kernel 体系。Keras 3 更适合承担“统一建模接口”,而不是替代后端的性能工程。

模块 定位 典型用法
keras.layers 层与算子组合
Python
1
x = layers.LayerNormalization()(x)
keras.Model 可训练模型单元
Python
1
class MyModel(keras.Model): ...
keras.ops 后端无关算子层(多后端 API)
Python
1
y = keras.ops.matmul(a, b)
keras.optimizers 优化器族
Python
1
opt = keras.optimizers.AdamW(3e-4)
JAX

JAX 的训练编程模型是“纯函数(Pure Function)+ 变换(Transformation)+ 编译(XLA Compilation)”。训练代码通常写成不带副作用的函数,然后用 jit/ grad/ vmap/ pmap 把函数变成可微、可并行、可编译的高性能版本。对工程而言,这意味着:参数与 optimizer state 需要显式放入状态对象;更新步骤常用 jit(value_and_grad(...)) 组织成单一的编译热点。

JAX installation (from official docs)
Shell
1
2
3
4
5
# CPU-only
pip install -U jax
 
# NVIDIA GPU (example, CUDA 13)
pip install -U \"jax[cuda13]\"
函数变换

JAX 的高频核心 API 集中在“函数变换”上。它们本质上都是高阶函数:输入是 Python 函数,输出是新的函数;输出函数具备更强的可微、可并行或可编译特性。

函数 作用 最小用法
jax.jit 把函数编译成 XLA 可执行版本(并缓存编译结果)
Python
1
step = jax.jit(step_fn)
jax.grad 对标量输出函数求梯度(反向模式 AD)
Python
1
g = jax.grad(loss_fn)
jax.value_and_grad 一次性返回 (value, grad),减少重复前向
Python
1
vg = jax.value_and_grad(loss_fn)
jax.vmap 自动向量化,把“单样本函数”提升为“批函数”
Python
1
batched_loss = jax.vmap(loss_fn, in_axes=(None, 0, 0))
jax.pmap 跨多设备 SPMD 并行(常用于数据并行)
Python
1
pstep = jax.pmap(step_fn, axis_name=\"data\")
jit

jit 的工程要点是“可 trace 的纯函数”:Python 侧的动态分支、不可哈希的静态参数、以及频繁变化的输入形状都会导致重新 trace 或重新编译,从而出现性能抖动。训练代码通常会把 step 函数写成固定签名,并把配置项通过静态参数或闭包固定。

grad

grad 适用于标量 loss。多输出或同时需要 aux(例如 metrics)时,通常配合 value_and_grad(has_aux=True)。

vmap

vmap 把 batch 维度“隐式传播”到计算图里,常用于 per-example gradients、对比学习的 pairwise 计算、以及把 Python for-loop 从热点路径挪走。

pmap

pmap 以 named axis 组织跨设备 collectives(如 jax.lax.psum),适合以“每个设备一份副本”的方式做数据并行。更通用的分片(sharding)体系通常落在 pjit/mesh 路线,但训练代码层面仍常见以 pmap 组织最小多卡并行示例。

编译式执行
JAX 执行模型

JAX 训练 step 的典型形态是:“状态输入 → 计算 loss → 求梯度 → 用 optimizer 更新状态”。状态一般用 PyTree 组织(字典、元组、dataclass 等嵌套结构),并作为函数参数显式传递。

Minimal JAX training step (sketch)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import jax
import jax.numpy as jnp
 
def loss_fn(params, batch):
    x, y = batch
    logits = model_apply(params, x)  # pure function
    loss = cross_entropy(logits, y)
    return loss
 
@jax.jit
def step(params, opt_state, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss
XLA

XLA 是 JAX 性能的核心来源:它把 Python 层的计算表达编译成设备侧可执行程序,并做算子融合与内存规划。训练工程里,减少 recompile 与控制输入形状稳定通常比微观优化更有效。

Flax

Flax 是基于 JAX 的神经网络库,常见入口是 Linen API:以 Module 表达参数化结构,以 init/apply 显式分离“参数创建”和“前向应用”。训练循环通常围绕 TrainState 把 params 与 optimizer state 统一管理。

Install Flax (typical)
Shell
1
pip install -U flax
对象 定位 典型用法
flax.linen.Module 参数化模块定义
Python
1
2
3
4
5
6
7
8
import flax.linen as nn
 
class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(256)(x)
        x = nn.gelu(x)
        return nn.Dense(10)(x)
Module.init 给定 rng 与输入 shape,初始化参数
Python
1
params = model.init(jax.random.key(0), x)["params"]
Module.apply 给定参数执行前向
Python
1
logits = model.apply({"params": params}, x)
flax.training.train_state.TrainState 统一管理 step/params/opt_state
Python
1
2
from flax.training.train_state import TrainState
state = TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)
PaddlePaddle

PaddlePaddle 的训练编程覆盖动态图(Dynamic Graph)与静态图(Static Graph)两种执行路径:动态图强调易用与调试;静态图强调编译优化、部署与稳定性能。官方 API 通过 paddle.enable_static() 显式切换到静态图模式。

PaddlePaddle installation (follow official guide)
Shell
1
2
3
4
5
# CPU
pip install -U paddlepaddle
 
# GPU
pip install -U paddlepaddle-gpu

Paddle verify
Python
1
2
3
import paddle
print(paddle.__version__)
paddle.utils.run_check()
执行模式与工具链
动态图

动态图是默认模式,常见训练代码以 paddle.nn.Layer 组织模型,以 paddle.optimizer 更新参数。

静态图

静态图用于追求更强的图级优化与更稳定的部署路径。切换到静态图通常意味着:需要显式构建 program/graph,并使用对应的 executor/engine 执行。实践里更常见的策略是:训练仍以动态图为主,部署阶段再导出静态图或使用官方推理引擎。

产业化工具链

工程体系上,Paddle 生态通常把“训练、推理、部署、端侧”做成一套配套工具链。若目标是快速把模型落到生产场景(OCR、CV、NLP 服务或端侧),这条生态链会显著降低工程摩擦。

Fleet

Fleet 是 Paddle 的分布式训练统一 API:通过 fleet.init 初始化分布式环境,并用 fleet.distributed_optimizer 把普通 optimizer 包装成分布式 optimizer。工程上最常见的是 collective 路线。

Fleet (collective) minimal pattern
Python
1
2
3
4
5
6
7
8
9
import paddle
import paddle.distributed.fleet as fleet
 
fleet.init(is_collective=True)
 
model = paddle.nn.Linear(10, 10)
optimizer = paddle.optimizer.SGD(learning_rate=1e-3, parameters=model.parameters())
strategy = fleet.DistributedStrategy()
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
高层训练与微调框架

这一层关注“训练与推理流程如何被组织成可复用的工程入口”。典型职责包括:模型与 tokenizer 的统一加载接口、训练循环模板、分布式启动封装、参数高效微调(PEFT)挂接、对齐训练(SFT/DPO/GRPO 等)流程化、以及面向特定任务的端到端工具链。

Hugging Face 主线

Transformers / Accelerate / PEFT / TRL 组成了一条高度耦合的工程主线:Transformers 提供模型与任务入口,Accelerate 提供设备与分布式抽象,PEFT 提供适配器挂接,TRL 提供后训练(Post-training)与偏好优化的 Trainer。工程上把它们当作一个整体来装配,比孤立使用更稳定。

Transformers

Transformers 的价值不在“又一个模型库”,而在统一了四类接口:模型(Model)、配置(Config)、预处理器(Tokenizer/Processor)与训练器(Trainer)。它同时覆盖 Encoder-only、Decoder-only 与 Encoder-Decoder 三类架构的微调流程,并把模型权重的下载、缓存与离线加载整理成可编程的 API。

安装
Shell
1
pip install -U transformers accelerate datasets evaluate
核心 API 速查
对象/函数 用途 最小用法(示例)
AutoTokenizer.from_pretrained 加载 tokenizer,并绑定模型同款词表与规范化规则
Python
1
2
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
AutoModel* .from_pretrained 加载模型权重;支持 dtype 与 device_map 等加载策略
Python
1
2
3
4
5
6
from transformers import AutoModelForCausalLM
m = AutoModelForCausalLM.from_pretrained(
  "meta-llama/Llama-2-7b-hf",
  torch_dtype="auto",
  device_map="auto",
)
pipeline 快速推理入口;适合验证模型与任务头是否可用
Python
1
2
3
from transformers import pipeline
clf = pipeline("sentiment-analysis")
clf("hugging face is the best")
Trainer PyTorch 训练循环封装:训练、评估、保存、日志、分布式协同
Python
1
2
3
4
5
6
7
8
from transformers import Trainer, TrainingArguments
args = TrainingArguments(
  output_dir="out",
  per_device_train_batch_size=2,
  num_train_epochs=1,
)
trainer = Trainer(model=m, args=args, train_dataset=ds)
trainer.train()
generate Decoder-only / Seq2Seq 推理生成入口
Python
1
2
3
inputs = tok(["The secret is"], return_tensors="pt").to(m.device)
out_ids = m.generate(**inputs, max_length=64)
print(tok.batch_decode(out_ids, skip_special_tokens=True)[0])
典型微调工作流(Trainer 路线)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from datasets import load_dataset
from transformers import (
  AutoTokenizer,
  AutoModelForSequenceClassification,
  DataCollatorWithPadding,
  Trainer,
  TrainingArguments,
)
 
ds = load_dataset("glue", "sst2")
tok = AutoTokenizer.from_pretrained("distilbert-base-uncased")
m = AutoModelForSequenceClassification.from_pretrained(
  "distilbert-base-uncased", num_labels=2
)
 
def tokenize(batch):
  return tok(batch["sentence"], truncation=True)
 
ds = ds.map(tokenize, batched=True)
collator = DataCollatorWithPadding(tokenizer=tok)
 
args = TrainingArguments(
  output_dir="out_sst2",
  evaluation_strategy="steps",
  eval_steps=200,
  save_steps=200,
  logging_steps=50,
  per_device_train_batch_size=32,
  per_device_eval_batch_size=64,
  num_train_epochs=1,
)
 
trainer = Trainer(
  model=m,
  args=args,
  train_dataset=ds["train"],
  eval_dataset=ds["validation"],
  data_collator=collator,
  tokenizer=tok,
)
trainer.train()
Accelerate

Accelerate 把“同一份 PyTorch 训练代码”映射到单卡、多卡、TPU、DeepSpeed、FSDP 等不同执行环境。它有两条入口:一条是 CLI( accelerate config/ accelerate launch)负责启动与进程编排;另一条是 Accelerator 类负责在代码层包裹模型、优化器、dataloader 与 backward。

安装与 CLI 快速落地
Shell
1
2
3
4
5
6
7
8
9
10
pip install -U accelerate
 
# 生成运行环境配置(会写入 default_config.yaml)
accelerate config
 
# 验证分布式环境是否可用
accelerate test
 
# 启动训练脚本(会按配置自动选择 DDP/FSDP/DeepSpeed 等后端)
accelerate launch train.py --arg1 v1
核心 API 速查
对象/函数 用途 最小用法(示例)
Accelerator() 统一管理 device、分布式通信、混合精度与梯度累积
Python
1
2
3
from accelerate import Accelerator
acc = Accelerator()
device = acc.device
prepare() 包裹 model/optimizer/dataloader/scheduler
Python
1
model, opt, dl = acc.prepare(model, opt, dl)
backward() 替代 loss.backward(),适配不同后端
Python
1
acc.backward(loss)
gather_for_metrics() 评估阶段收集分布式预测,避免只看本 rank
Python
1
preds, labels = acc.gather_for_metrics((preds, labels))
PEFT

PEFT(Parameter-Efficient Fine-Tuning)把“修改基座模型权重”改为“在基座旁边挂接可训练的低秩/提示/适配器参数”。在工程上,它主要解决三个问题:一是显存与训练成本,二是多任务/多版本适配器的管理,三是把适配器作为交付物(checkpoint)而不是整个基座。

安装
Shell
1
pip install -U peft
核心 API 速查
对象/函数 用途 最小用法(示例)
LoraConfig 定义 LoRA 的 rank、alpha、dropout 与 target modules
Python
1
2
3
4
5
6
from peft import LoraConfig, TaskType
cfg = LoraConfig(
  task_type=TaskType.CAUSAL_LM,
  r=16, lora_alpha=32, lora_dropout=0.05,
  target_modules=["q_proj","v_proj"],
)
get_peft_model 把 PEFT 配置挂到 Transformers 模型上
Python
1
2
from peft import get_peft_model
m = get_peft_model(m, cfg)
PeftModel.from_pretrained 在基座上加载已训练好的 adapter
Python
1
2
from peft import PeftModel
m = PeftModel.from_pretrained(base, "my_adapter_dir")
merge_and_unload 把 adapter 合并回基座权重,导出纯模型
Python
1
m = m.merge_and_unload()
Transformers + PEFT 微调骨架
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model, TaskType
 
base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype="auto", device_map="auto")
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
 
cfg = LoraConfig(task_type=TaskType.CAUSAL_LM, r=16, lora_alpha=32, lora_dropout=0.05,
                 target_modules=["q_proj","k_proj","v_proj","o_proj"])
model = get_peft_model(base, cfg)
 
args = TrainingArguments(output_dir="out_lora", per_device_train_batch_size=1, gradient_accumulation_steps=16)
trainer = Trainer(model=model, args=args, train_dataset=ds)
trainer.train()
 
model.save_pretrained("adapter_out")    # 只保存 adapter
tok.save_pretrained("adapter_out")
TRL

TRL 把后训练(Post-training)中常见的方法流程化:SFT、DPO、GRPO、Reward Modeling 等。它的核心接口是多种 Trainer;工程上需要关注两件事:一是数据格式(尤其是偏好对数据结构),二是运行时集成(与 Transformers/PEFT/DeepSpeed/vLLM 的协作)。

安装
Shell
1
pip install -U trl
Trainer 家族(按方法类型)
Trainer 方法类型 典型用途
SFTTrainer Offline 监督微调(指令数据、格式对齐、领域适配)
DPOTrainer Offline 偏好优化(pairwise preference)
GRPOTrainer Online 基于组相对优势的策略优化(常用于 CoT/RL 后训练)
RewardTrainer Reward modeling 训练奖励模型,为在线方法或 reranker 提供打分器
一个最小 SFTTrainer 骨架
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from trl import SFTTrainer
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
 
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype="auto", device_map="auto")
 
args = TrainingArguments(
  output_dir="out_sft",
  per_device_train_batch_size=1,
  gradient_accumulation_steps=16,
  logging_steps=10,
)
 
trainer = SFTTrainer(
  model=model,
  args=args,
  train_dataset=ds,
  tokenizer=tok,
)
trainer.train()
Sentence Transformers

Sentence Transformers 把“可用的 embedding 训练与推理”包装成稳定的 Python API。它既提供 SentenceTransformer(Bi-Encoder)路线,也提供 CrossEncoder(Cross-Encoder)路线。工程上常见分工是:Bi-Encoder 做向量召回与批量 embedding 生成;Cross-Encoder 做 reranking 或匹配打分。

安装
Shell
1
pip install -U sentence-transformers
核心 API 速查
对象/函数 用途 最小用法(示例)
SentenceTransformer 加载/训练 embedding 模型(Bi-Encoder)
Python
1
2
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("all-MiniLM-L6-v2")
encode 将文本批量编码为向量
Python
1
emb = model.encode(["hello", "world"], normalize_embeddings=True)
fit 训练入口(损失、dataloader、评估器等由库组织)
Python
1
model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1)
大模型微调工作台
Unsloth

Unsloth 的工程定位更像“本地训练与本地部署的工作台”,覆盖 UI 与代码两条路径:Unsloth Studio 偏一键安装与可视化流程;Unsloth Core 偏脚本化训练、导出与运行。它常被用于消费级显卡上的微调与导出链路(例如导出 GGUF、部署本地 API)。

安装(Studio)
Shell
1
2
# macOS / Linux / WSL
curl -fsSL https://unsloth.ai/install.sh | sh

1
2
# Windows PowerShell
irm https://unsloth.ai/install.ps1 | iex
典型工作流(面向本地训练与导出)

在工程链路上,Unsloth 经常作为“训练 + 导出 + 本地运行”的一体化入口:训练阶段对接 Transformers/PEFT/TRL 的微调流程,交付物可以是 adapter 或导出的 GGUF/16-bit 权重;推理阶段可以走本地 API endpoint,或导入到其他推理栈。

ModelScope

ModelScope(魔搭)提供类似“模型即服务(MaaS)”的统一 SDK:既能用 pipeline 做推理,也能用 Trainer 抽象做微调与评估。它的工程价值在于中文生态与多领域模型的统一入口,以及与其模型/数据 Hub 的整合。

安装
Shell
1
pip install -U modelscope
推理 pipeline 入口
Python
1
2
3
4
5
6
7
from modelscope.pipelines import pipeline
 
word_segmentation = pipeline(
  "word-segmentation",
  model="damo/nlp_structbert_word-segmentation_chinese-base",
)
print(word_segmentation("今天天气不错,适合出去游玩"))
Trainer 入口(微调骨架)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
 
train_dataset = MsDataset.load("chinese-poetry-collection", split="train").remap_columns({"text1": "src_txt"})
eval_dataset = MsDataset.load("chinese-poetry-collection", split="test").remap_columns({"text1": "src_txt"})
 
kwargs = dict(
  model="damo/nlp_gpt3_text-generation_1.3B",
  train_dataset=train_dataset,
  eval_dataset=eval_dataset,
  max_epochs=10,
  work_dir="./gpt3_poetry",
)
 
trainer = build_trainer(name=Trainers.gpt3_trainer, default_args=kwargs)
trainer.train()
任务特定训练框架
span-based NER 框架

span-based NER 把实体识别从“逐 token 标注”转为“枚举 span 并分类/打分”。这类框架通常天然更适合零样本/少样本标签扩展,并且更容易做 CPU 友好推理与轻量服务化。

GLiNER
安装
Shell
1
pip install -U gliner
推理 API(predict_entities)
Python
1
2
3
4
5
6
7
8
from gliner import GLiNER
 
model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1")
text = "John works at Google. Paris is in France."
labels = ["person", "organization", "location"]
 
entities = model.predict_entities(text, labels, threshold=0.5)
print(entities)
服务化(Ray Serve)
Shell
1
2
pip install -U "gliner[serve]"
python -m gliner.serve --model urchade/gliner_small-v2.1 --enable-flashdeberta
训练流程组织框架
PyTorch Lightning

Lightning 把训练工程样板(device、logger、checkpoint、DDP/FSDP、回调)抽象进 Trainer,把研究代码收敛到 LightningModule。它在大团队或多项目复用场景里很常见:训练方式统一,扩展点集中在 callbacks、loggers 与策略配置。

安装
Shell
1
pip install -U lightning
最小训练骨架
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import lightning as L
import torch
import torch.nn.functional as F
from torch import nn
 
class LitModel(L.LightningModule):
  def __init__(self):
    super().__init__()
    self.net = nn.Linear(10, 2)
 
  def training_step(self, batch, batch_idx):
    x, y = batch
    logits = self.net(x)
    loss = F.cross_entropy(logits, y)
    self.log("train_loss", loss)
    return loss
 
  def configure_optimizers(self):
    return torch.optim.AdamW(self.parameters(), lr=3e-4)
 
trainer = L.Trainer(max_epochs=1, accelerator="auto", devices="auto")
trainer.fit(LitModel(), train_dataloaders=train_loader)
MMEngine

MMEngine 是 OpenMMLab 全系仓库的训练引擎。它把“训练引擎(Runner)+ 配置系统 + 日志/可视化后端 + 钩子(Hook)机制”做成通用底座,并对接 DeepSpeed/FSDP 等大模型训练框架。它适合需要强配置化、统一运行入口、以及跨多个 CV 任务仓库复用的团队。

安装(openmim)
Shell
1
2
pip install -U openmim
mim install mmengine
最小示例(BaseModel + Runner 思路)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel
 
class MMResNet50(BaseModel):
  def __init__(self):
    super().__init__()
    self.resnet = torchvision.models.resnet50()
 
  def forward(self, imgs, labels, mode):
    x = self.resnet(imgs)
    if mode == "loss":
      return {"loss": F.cross_entropy(x, labels)}
    elif mode == "predict":
      return x, labels
OpenMMLab

OpenMMLab 是一个以 MMEngine 为训练底座的 CV 开源生态(检测、分割、姿态、生成等)。从工程角度看,它更像“标准化研究代码基座 + 大量可复用算法实现”。如果团队需要长期维护多类 CV 模型,并追求统一配置、统一日志、统一 checkpoint 与统一评测流程,OpenMMLab 的价值通常高于“从零搭一个训练框架”。

经典机器学习工程框架

这类库不属于大模型主线,但在特征工程、表格任务、以及小模型 baseline 中仍然高频存在。它们的接口几乎都围绕 fit/ predict/ predict_proba,工程上更强调数据清洗与特征一致性。

训练脚本的基本组成

训练脚本的价值在于把训练过程工程化:数据输入稳定、训练状态可恢复、指标可观测、实验可对比、产物可追溯。一个可维护的训练脚本通常围绕四件事组织:训练循环、状态管理、配置入口、可观测性与评估。

最小训练循环

训练循环的目标是把“损失函数关于参数的梯度”转化为“参数更新”。在 PyTorch 中,这条链路可写成:前向得到 loss,反向计算梯度,优化器 step 更新参数。工程上再叠加三类必需机制:学习率调度、数值稳定/效率策略(累积、混合精度、梯度裁剪)、训练状态的保存与恢复。

训练核心对象
模型

模型在脚本里承担两种职责:定义参数化映射,以及提供可复现的前向路径。训练脚本中最容易被忽略的细节是模式切换与设备放置:训练时必须 model.train(),评估时必须 model.eval();参数与输入必须在同一设备与兼容精度上。

Python
1
2
3
4
5
6
7
8
device = "cuda"  # or "cpu"
model = MyModel(...)
model.to(device)
 
for batch in train_loader:
    model.train()
    x, y = batch["x"].to(device), batch["y"].to(device)
    logits = model(x)
损失

损失函数是训练脚本的“唯一可优化目标”。工程上需要把损失拆成两层:第一层是数学定义(例如 CE/BCE/MSE);第二层是数据与张量形状约定(logits vs probabilities、label dtype、ignore_index、padding mask)。脚本里应当显式处理这层约定,避免模型输出与 loss 之间隐含转换。

Python
1
2
3
4
import torch.nn.functional as F
 
logits = model(x)                 # [B, C]
loss = F.cross_entropy(logits, y) # y: [B], dtype=torch.long
优化器

优化器把梯度转成参数更新。训练脚本里,优化器的“正确性”主要取决于三件事:参数组(parameter groups)是否分对、 zero_grad 是否用 set_to_none=True 清零、以及 step 的节奏是否与梯度累积/混合精度一致。

Python
1
2
3
4
5
6
7
import torch
 
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=0.1)
 
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
常用优化器 安装 典型入口 脚本要点
torch.optim.AdamW 随 PyTorch 提供 torch.optim.AdamW(...) LLM/Transformer 微调的默认选择之一;建议显式设置 weight_decay;必要时拆 parameter groups 让 bias/Norm 走 0 weight_decay。
torch.optim.SGD 随 PyTorch 提供 torch.optim.SGD(...) 常用于 CNN/视觉训练;注意 momentum、nesterov 与 weight_decay 的组合。
自定义/函数式优化器 依赖实现 optimizer.step() 若使用函数式 API(functional optimizers),需要把 grad_scale/found_inf 等 AMP 信息正确传递给优化器。
scheduler

学习率调度器的工程关键在于 step 的触发时机;调度器选型通常是第二顺位。常见两类:

  • epoch 级 step:每个 epoch 结束后调用一次。
  • step 级 step:每个 optimizer update 后调用一次(常见于 warmup、OneCycle 等)。

在 PyTorch 中,调度器通常在 optimizer.step() 之后调用,避免跳过初始学习率。恢复训练时也应保存/加载 scheduler 的 state。

Python
1
2
3
4
5
6
7
8
9
from torch.optim.lr_scheduler import CosineAnnealingLR
 
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
 
for epoch in range(num_epochs):
    train_one_epoch(...)
    optimizer.step()
    scheduler.step()
scheduler step 粒度 典型用途 注意点
StepLR epoch 分段衰减 与里程碑 epoch 对齐,通常在 epoch 末调用。
CosineAnnealingLR epoch 或 step 平滑衰减 需要明确 T_max 的含义(epoch 数或 step 数)。
OneCycleLR step warmup + 衰减的一体化策略 必须提供 total_steps 或 epochs + steps_per_epoch;在每次 optimizer update 后 step。
ReduceLROnPlateau eval 事件驱动 指标不提升就降 LR step 时需要传入监控指标(例如 val_loss)。
稳定性与效率机制
gradient accumulation

梯度累积(Gradient Accumulation)用多次 backward() 模拟更大的 batch:每个 micro-batch 只反向,不更新;累积到指定步数后再统一 optimizer.step()。工程上需要把 loss 除以累积步数,保证梯度尺度不被放大。

Python
1
2
3
4
5
6
7
8
9
10
11
grad_accum_steps = 8
optimizer.zero_grad(set_to_none=True)
 
for step, batch in enumerate(train_loader):
    loss = compute_loss(batch)
    loss = loss / grad_accum_steps
    loss.backward()
 
    if (step + 1) % grad_accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
mixed precision

混合精度(Automatic Mixed Precision, AMP)在前向与反向中对不同算子选择不同精度,提升吞吐并降低显存占用。PyTorch 推荐使用 torch.autocast 与 torch.amp.GradScaler 组合;旧的 torch.cuda.amp.autocast 已被标注为弃用入口,脚本应迁移到 torch.amp.autocast("cuda") 风格。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
 
scaler = torch.amp.GradScaler("cuda")
optimizer.zero_grad(set_to_none=True)
 
for batch in train_loader:
    with torch.amp.autocast("cuda", dtype=torch.float16):
        loss = compute_loss(batch)
 
    scaler.scale(loss).backward()
 
    # 如果要做梯度裁剪,需要先 unscale
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
 
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)
clipping

梯度裁剪(Gradient Clipping)用于抑制梯度爆炸与异常尖峰更新。脚本里常用两种方式:按范数裁剪与按值裁剪。混合精度场景下,裁剪通常发生在 scaler.unscale_(optimizer) 之后、 scaler.step(optimizer) 之前。

函数 用途 典型用法
torch.nn.utils.clip_grad_norm_ 按整体范数裁剪梯度
Python
1
2
3
4
5
torch.nn.utils.clip_grad_norm_(
    model.parameters(),
    max_norm=1.0,
    norm_type=2.0,
)
torch.nn.utils.clip_grad_value_ 按绝对值范围裁剪梯度
Python
1
2
3
4
torch.nn.utils.clip_grad_value_(
    model.parameters(),
    clip_value=0.5,
)
训练状态管理
checkpoint

checkpoint 的工程目标是可恢复性与可追溯性。推荐保存 state_dict,而不是直接 pickle 整个模型对象。一个可恢复 checkpoint 至少包含:

  • model.state_dict()
  • optimizer.state_dict()
  • scheduler.state_dict()(如果使用)
  • GradScaler.state_dict()(如果使用 AMP)
  • 当前 epoch、global_step、最佳指标与早停计数器
  • 必要时保存 RNG 状态(CPU/CUDA)以便复现实验
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch
 
ckpt = {
    "epoch": epoch,
    "global_step": global_step,
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict() if scheduler else None,
    "scaler": scaler.state_dict() if scaler else None,
    "best_metric": best_metric,
    "patience": patience_counter,
    "rng_state": torch.get_rng_state(),
}
if torch.cuda.is_available():
    ckpt["cuda_rng_state_all"] = torch.cuda.get_rng_state_all()
 
torch.save(ckpt, ckpt_path)
resume

恢复训练(resume)要求脚本严格区分两种加载:

  • 只加载权重(用于推理或 warmstart):只读 model 的 state_dict。
  • 恢复训练:除了 model,还要恢复 optimizer/scheduler/scaler 与计数器。

跨设备恢复时,应使用 map_location 控制张量落点。对于大型权重文件,PyTorch 提供了 mmap 相关建议与加载技巧,可用于降低峰值内存。

Python
1
2
3
4
5
6
7
8
9
10
11
12
ckpt = torch.load(ckpt_path, map_location="cpu")
 
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
if scheduler and ckpt.get("scheduler") is not None:
    scheduler.load_state_dict(ckpt["scheduler"])
if scaler and ckpt.get("scaler") is not None:
    scaler.load_state_dict(ckpt["scaler"])
 
start_epoch = ckpt["epoch"] + 1
global_step = ckpt["global_step"]
best_metric = ckpt.get("best_metric", None)
early stopping

早停(Early Stopping)的正确写法是“基于业务真正关心的指标触发”。分类任务通常监控 F1/Accuracy;生成任务通常监控 ROUGE/BLEU 或下游任务指标;有些任务 loss 的上升代表校准变差但决策指标仍提升,因此脚本应把 monitor 指标显式参数化,而不是写死为 val_loss。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
patience = 3
best = None
bad_epochs = 0
 
for epoch in range(num_epochs):
    train_one_epoch(...)
    metric = evaluate(...)
 
    if best is None or metric > best:
        best = metric
        bad_epochs = 0
        save_best_checkpoint(...)
    else:
        bad_epochs += 1
        if bad_epochs >= patience:
            break
配置系统

训练脚本的配置系统需要解决两类问题:参数入口(CLI/环境变量)与配置结构(分层配置、默认值、校验)。通用格式(例如 YAML)只承担“配置文件承载体”的角色,真正的工程收益来自:覆盖语法、层级合并、以及把配置变成强类型对象。

命令行构建
argparse
Python
1
2
3
4
5
6
import argparse
 
parser = argparse.ArgumentParser()
parser.add_argument("--lr", type=float, default=2e-4)
parser.add_argument("--batch-size", type=int, default=8)
args = parser.parse_args()
Click
Python
1
2
3
4
5
6
7
8
9
10
import click
 
@click.command()
@click.option("--lr", type=float, default=2e-4)
@click.option("--batch-size", type=int, default=8)
def main(lr, batch_size):
    ...
 
if __name__ == "__main__":
    main()
Typer
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from typing import Annotated
import typer
 
app = typer.Typer()
 
@app.command()
def train(
    lr: Annotated[float, typer.Option()] = 2e-4,
    batch_size: Annotated[int, typer.Option()] = 8,
):
    ...
 
if __name__ == "__main__":
    app()
配置与模式校验
Hydra

Hydra 的核心价值是“分层配置 + 命令行覆盖 + 默认输出目录管理”。训练脚本里常把超参数、数据路径、模型结构与运行参数拆分成多个 config group,再通过 overrides 组合出一次实验。

Python
1
2
3
4
5
6
7
8
9
10
import hydra
from omegaconf import DictConfig
 
@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(cfg: DictConfig):
    # cfg.lr, cfg.train.batch_size, cfg.model.name, ...
    ...
 
if __name__ == "__main__":
    main()
OmegaConf
Python
1
2
3
4
5
from omegaconf import OmegaConf
 
cfg = OmegaConf.load("conf/config.yaml")
cfg = OmegaConf.merge(cfg, {"train": {"batch_size": 8}})
cfg_dict = OmegaConf.to_container(cfg, resolve=True)
Pydantic

Pydantic 的价值是把“松散字典配置”收敛为“可验证的强类型配置对象”,在脚本启动阶段就能把拼写错误与类型错误拒之门外。

Python
1
2
3
4
5
6
7
from pydantic import BaseModel, Field
 
class TrainConfig(BaseModel):
    lr: float = Field(default=2e-4, ge=0.0)
    batch_size: int = Field(default=8, ge=1)
 
cfg = TrainConfig(lr=2e-4, batch_size=8)
日志与可视化
TensorBoard

TensorBoard 的工程用法是“在训练循环中持续写入事件文件”,再用 TensorBoard UI 查询。PyTorch 提供 torch.utils.tensorboard.SummaryWriter 作为主入口。

Shell
1
2
pip install tensorboard
tensorboard --logdir runs

Python
1
2
3
4
5
6
from torch.utils.tensorboard import SummaryWriter
 
writer = SummaryWriter(log_dir="runs/exp-001")
writer.add_scalar("train/loss", loss.item(), global_step)
writer.add_scalar("train/lr", optimizer.param_groups[0]["lr"], global_step)
writer.flush()
实验管理与跟踪
Weights & Biases

W&B 的训练脚本集成围绕三个动作:init 建立 run,log 写入指标与超参,finish 结束 run。离线环境可使用 offline 模式把日志落盘后再同步。

Shell
1
2
pip install wandb
wandb login

Python
1
2
3
4
5
6
import wandb
 
run = wandb.init(project="exp", config={"lr": 2e-4, "batch_size": 8})
for step in range(100):
    wandb.log({"train/loss": float(loss), "train/lr": optimizer.param_groups[0]["lr"]}, step=step)
run.finish()
MLflow

MLflow Tracking 的核心是 run:在 run 上记录 params、metrics 与 artifacts。最小闭环是:设置 experiment,启动 run,上报指标,必要时启动本地 tracking server 查看 UI。

Shell
1
2
pip install mlflow
mlflow server --port 5000

Python
1
2
3
4
5
6
7
import mlflow
 
mlflow.set_experiment("exp")
with mlflow.start_run():
    mlflow.log_params({"lr": 2e-4, "batch_size": 8})
    mlflow.log_metric("train_loss", float(loss), step=global_step)
    mlflow.log_artifact("checkpoints/best.pt")
LLM 可观测性
Langfuse

Langfuse 在训练脚本中的典型价值是把“训练过程中的 LLM 调用、数据生成、评测调用”以 trace/span/generation 的方式串成可查询的链路,并与指标平台形成分工:W&B/MLflow 负责 run 级指标与产物,Langfuse 负责调用链与上下文。短生命周期脚本要显式 flush 或 shutdown,确保事件被发送。

Shell
1
pip install langfuse

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import os
from langfuse import get_client
 
os.environ["LANGFUSE_PUBLIC_KEY"] = "pk-lf-..."
os.environ["LANGFUSE_SECRET_KEY"] = "sk-lf-..."
os.environ["LANGFUSE_BASE_URL"] = "https://cloud.langfuse.com"
 
langfuse = get_client()
 
with langfuse.start_as_current_observation(as_type="span", name="train-step") as span:
    # 训练逻辑
    span.update(metadata={"global_step": global_step})
 
    with langfuse.start_as_current_observation(as_type="generation", name="synth-data", model="gpt-4.1") as gen:
        # LLM 生成数据/评测逻辑
        gen.update(output="...")
 
langfuse.flush()
评估与基准
通用评估指标

训练脚本的评估模块需要满足两个工程要求:可重复(同一 checkpoint 评估一致)与可对比(同一指标口径跨 run 可比较)。分类任务的 Accuracy/F1、检索任务的 Recall@K/NDCG、生成任务的 ROUGE/BLEU 与任务自定义评分,应在脚本里拆成独立的 evaluate 函数,避免训练循环与评估逻辑互相污染。

中文评估指标
rouge-chinese

rouge-chinese 提供中文场景的 ROUGE 计算实现,针对中文标点分句与 ROUGE-L 内存占用做了工程优化。训练脚本中通常把它放在验证阶段,用于摘要、生成式问答等任务的离线评估。

Shell
1
pip install rouge-chinese

Python
1
2
3
4
5
6
7
8
from rouge_chinese import Rouge
 
rouge = Rouge()
hyps = ["模型生成的摘要。"]
refs = ["参考摘要。"]
 
scores = rouge.get_scores(hyps, refs, avg=True)
# scores["rouge-1"]["f"], scores["rouge-2"]["f"], scores["rouge-l"]["f"]
数据与标注资产管理
DVC

DVC 把“数据/模型产物”从 Git 中分离出来,同时保留可追溯版本。训练脚本常配合 DVC 使用两条链路:数据版本管理(dvc add/pull/push)与流水线复现(dvc.yaml + dvc repro)。

Shell
1
2
3
4
5
pip install dvc
dvc init
dvc add data/train.jsonl
git add data/train.jsonl.dvc data/.gitignore
git commit -m "track dataset with dvc"
Label Studio

Label Studio 是标注平台。训练脚本侧通常把它当作“数据生成与质量控制”的外部系统:标注阶段产出数据,训练阶段只消费导出的标注结果。最小可运行入口是安装并启动服务。

Shell
1
2
pip install label-studio
label-studio start
分布式训练与硬件加速组件

分布式训练与硬件加速的工程工作围绕四个入口展开:进程如何启动、通信如何建立、显存如何被切分与回收、关键算子是否落在高性能 kernel。本节以“能直接跑起来”的安装、启动、API、配置与部署约束为中心,覆盖 PyTorch distributed(DP/DDP/FSDP/torchrun)、DeepSpeed(ZeRO)、Megatron-LM/Megatron Core、CUDA/cuDNN/NCCL、Triton、FlashAttention、flash-linear-attention(fla)、xFormers、bitsandbytes,以及数值精度与重算策略。

设备与并行基础
计算设备

训练代码在多卡场景里通常遵循“一进程一 GPU”的约定:每个进程只绑定一个 GPU,并通过进程间通信完成梯度同步或参数分片。这一约定直接对应 torchrun/DDP/FSDP/DeepSpeed 的默认启动方式,也决定了日志、随机数与数据采样需要按 rank 做隔离。

并行策略

工程上最常用的并行拆分有两类:

  • 数据并行(Data Parallel, DP):每张卡持有一份模型副本,各自处理不同 batch,然后同步梯度。
  • 模型并行(Model Parallel):把一个模型拆到多张卡上。常见细分是张量并行(Tensor Parallel, TP)与流水并行(Pipeline Parallel, PP)。

对于 LLM 预训练,TP/PP 往往与数据并行同时存在;对多数微调任务,数据并行 + 参数高效微调(LoRA/QLoRA)是更常见的起点。

PyTorch 分布式主线
安装与环境校验

PyTorch 分布式训练的最低要求是:PyTorch 构建启用了 distributed,并且 GPU 通信后端可用(NVIDIA 场景通常是 NCCL)。工程上先做三类校验:CUDA 运行时可用性、distributed 模块可用性、以及当前进程可否正确枚举到 GPU。

Python
1
2
3
4
5
6
7
8
9
import torch
import torch.distributed as dist
 
print("torch:", torch.__version__)
print("cuda available:", torch.cuda.is_available())
print("torch cuda:", torch.version.cuda)
print("distributed available:", dist.is_available())
if torch.cuda.is_available():
    print("gpu0:", torch.cuda.get_device_name(0))
DataParallel (DP)

DataParallel 是单进程、多 GPU 的封装( torch.nn.DataParallel)。它的工程局限很明确:单进程会成为瓶颈、参数分散与通信控制不够细,且与现代分布式生态(torchrun/elastic/DCP)不在同一路线上。实际工程里更常把 DP 当作“快速验证多卡可跑”的临时方案,正式训练一般直接用 DDP 或 FSDP。

Python
1
2
3
4
5
import torch
import torch.nn as nn
 
model = nn.Linear(1024, 1024).cuda()
model = nn.DataParallel(model)  # 单进程,多卡
DDP

DDP(DistributedDataParallel)是 PyTorch 数据并行主线:多进程各自持有模型副本,反向传播时进行梯度 AllReduce。工程上,DDP 的三个稳定性入口是:进程组初始化、每个进程绑定本地 GPU、以及数据采样在 rank 之间的正确切分。

启动方式(torchrun)
Shell
1
2
# 单机 8 卡(每个进程绑定 1 张 GPU)
torchrun --standalone --nproc-per-node=8 train.py

Shell
1
2
3
4
5
6
# 多机(示例:2 台机器,每台 8 卡)
# node0:
torchrun --nnodes=2 --node-rank=0 --nproc-per-node=8 --master_addr=$MASTER_ADDR --master_port=29500 train.py
 
# node1:
torchrun --nnodes=2 --node-rank=1 --nproc-per-node=8 --master_addr=$MASTER_ADDR --master_port=29500 train.py
最小 DDP 训练骨架
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import os
import argparse
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
 
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--local-rank", "--local_rank", type=int, default=None)
    _ = parser.parse_args()
 
    dist.init_process_group(backend="nccl")  # NVIDIA GPU 通常用 NCCL
 
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    torch.cuda.set_device(local_rank)
 
    model = torch.nn.Linear(1024, 1024).cuda()
    model = DDP(model, device_ids=[local_rank])
 
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    x = torch.randn(8, 1024, device="cuda")
    y = model(x).sum()
    y.backward()
    opt.step()
    opt.zero_grad(set_to_none=True)
 
    dist.destroy_process_group()
 
if __name__ == "__main__":
    main()
数据切分(DistributedSampler)

DDP 下数据切分通常使用 torch.utils.data.distributed.DistributedSampler。它的工程意义是:每个 rank 只看见数据集的一个分片,并且在每个 epoch 以相同随机种子但不同偏移做 shuffle。训练循环里需要在每个 epoch 调用 sampler.set_epoch(epoch),否则多卡的 shuffle 行为容易退化为“每个 epoch 都是同一切分”。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
 
dataset = ...
sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True)
loader = DataLoader(dataset, batch_size=bs, sampler=sampler, num_workers=4, pin_memory=True)
 
for epoch in range(num_epochs):
    sampler.set_epoch(epoch)
    for batch in loader:
        ...
常见约束与调参入口
  • backend 选择:NVIDIA GPU 通常选 NCCL;CPU 场景通常用 Gloo(性能不同)。
  • find_unused_parameters:动态图/分支模型可能需要,但会引入开销;结构固定的训练尽量避免。
  • 梯度桶(bucket):DDP 会把梯度聚合成 bucket 做 AllReduce,bucket 大小与拓扑会影响吞吐与尾延迟。
FSDP

FSDP(Fully Sharded Data Parallel)把参数、梯度与优化器状态按 data-parallel rank 做分片,以显著降低“模型状态显存”。实践上分两条 API 主线:FSDP2(当前推荐)与 FSDP1(传统 wrapper 形态)。两者共享一个工程事实:优化器应在模型被分片之后创建,因为参数对象会被重映射。

最小 FSDP2 骨架(fully_shard)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard, FSDPModule
 
def main():
    dist.init_process_group("nccl")
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    torch.cuda.set_device(local_rank)
 
    model = Transformer()  # 伪代码:包含 model.layers
    for layer in model.layers:
        fully_shard(layer)
    fully_shard(model)
    assert isinstance(model, FSDPModule)
 
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    ...
最小 FSDP1 骨架(FullyShardedDataParallel)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
def main():
    dist.init_process_group("nccl")
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    torch.cuda.set_device(local_rank)
 
    model = torch.nn.Linear(1024, 1024).cuda()
    model = FSDP(model)
 
    # 注意:optimizer 要在 FSDP wrap 之后创建
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
 
    x = torch.randn(8, 1024, device="cuda")
    loss = model(x).sum()
    loss.backward()
    opt.step()
    opt.zero_grad(set_to_none=True)
 
    dist.destroy_process_group()
 
if __name__ == "__main__":
    main()
auto wrap 与分片策略

分片边界直接决定通信形态:对极小模块做分片会导致频繁 all-gather/reduce-scatter,吞吐明显下降。实践上常把分片边界放在较大的 Transformer block 级别,并在框架侧配合 activation checkpointing 来降低激活占用。

torchrun

torchrun 是 PyTorch 提供的分布式启动器,等价于 python -m torch.distributed.run。它负责为每个进程注入 rank/world_size/local_rank 等环境变量,并管理多机训练的 rendezvous。对于 GPU 训练,torchrun 的默认模型是“每进程一 GPU”。

常用启动参数
参数 含义 示例
--nproc-per-node 每台机器启动的进程数(GPU 训练通常等于每机 GPU 数)
Shell
1
torchrun --standalone --nproc-per-node=8 train.py
--nnodes / --node-rank 多机训练的节点数量与当前节点序号
Shell
1
torchrun --nnodes=2 --node-rank=0 --nproc-per-node=8 --master_addr=$MASTER --master_port=29500 train.py
--standalone 单机训练使用本地 rendezvous,省去显式配置
Shell
1
torchrun --standalone --nproc-per-node=8 train.py
launcher

在更大规模场景里,torchrun 之外通常还会叠一层集群 launcher(例如 Slurm 的 srun,或 K8s job controller),负责资源分配与节点编排。工程边界一般是:launcher 负责“分配哪些机器/卡”,torchrun 负责“每台机器上起哪些进程并建立通信”。

大模型分布式系统
DeepSpeed

DeepSpeed 把“大模型训练需要的显存管理与并行策略”产品化:通过 deepspeed.initialize 与一个配置文件,让训练脚本在不重写大量底层逻辑的情况下获得 ZeRO、offload、优化器与调度能力。它的关键工程入口是:安装、配置文件、启动命令与与现有训练循环的接入点。

安装
Shell
1
pip install deepspeed
启动
Shell
1
2
# 典型用法:用 deepspeed 作为 launcher
deepspeed --num_gpus=8 train.py --deepspeed_config ds_config.json
最小配置(ds_config.json)

DeepSpeed 的工程事实是“配置驱动”:显存分片、offload、通信重叠与一些优化器实现由 JSON 配置决定。下列示例是可落地的起点,常见改动集中在 ZeRO stage 与 offload 选项。

ds_config.json
JSON
1
2
3
4
5
6
7
8
9
10
{
  "train_micro_batch_size_per_gpu": 1,
  "gradient_accumulation_steps": 8,
  "fp16": { "enabled": true },
  "zero_optimization": {
    "stage": 2,
    "overlap_comm": true,
    "contiguous_gradients": true
  }
}
ZeRO-3 / Offload 配置骨架
ds_config_zero3.json
JSON
1
2
3
4
5
6
7
8
9
10
{
  "train_micro_batch_size_per_gpu": 1,
  "gradient_accumulation_steps": 8,
  "bf16": { "enabled": true },
  "zero_optimization": {
    "stage": 3,
    "offload_param": { "device": "cpu", "pin_memory": true },
    "offload_optimizer": { "device": "cpu", "pin_memory": true }
  }
}
Python 接入(deepspeed.initialize)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import deepspeed
 
model = ...
params = model.parameters()
 
engine, optimizer, _, lr_scheduler = deepspeed.initialize(
    model=model,
    model_parameters=params,
    config="ds_config.json",
)
 
for batch in dataloader:
    loss = engine(batch)
    engine.backward(loss)
    engine.step()
Megatron

Megatron-LM 是面向 Transformer 预训练的参考实现体系,内置张量并行(TP)、流水并行(PP)、以及与 NVIDIA 生态加速库的集成。它更像“训练系统工程模板”:直接复用仓库里的预训练脚本与并行参数,然后在其上叠加数据、模型结构与实验约束。

安装入口(推荐容器路径)

Megatron-LM/Megatron Core 对 CUDA、PyTorch、Transformer Engine、通信库的版本组合敏感。工程上通常优先使用 NGC 的 PyTorch 容器作为基线,再在容器内安装/开发 Megatron 相关代码,减少 ABI 与编译链不一致带来的问题。

Shell
1
2
3
4
5
6
docker run --runtime=nvidia --gpus all -it --rm \
  -v /path/to/megatron:/workspace/megatron \
  -v /path/to/dataset:/workspace/dataset \
  -v /path/to/checkpoints:/workspace/checkpoints \
  -e PIP_CONSTRAINT= \
  nvcr.io/nvidia/pytorch:25.04-py3
启动模式(示意)

Megatron-LM 的典型启动方式仍是 torchrun,但会显式配置 TP/PP 规模,并把全局 batch 拆成 micro-batch + accumulation。下列命令展示“最少参数框架”,具体模型/数据参数由脚本与配置决定。

Shell
1
2
3
4
5
6
7
8
torchrun --nproc-per-node=8 pretrain_gpt.py \
  --tensor-model-parallel-size 2 \
  --pipeline-model-parallel-size 2 \
  --micro-batch-size 1 \
  --global-batch-size 128 \
  --sequence-length 4096 \
  --train-iters 1000 \
  ...
Megatron Core

Megatron Core 是可组合库形态:把训练大型 Transformer 所需的关键模块与系统优化能力封装成 API,供自定义训练框架调用。它提供 pip 安装与示例训练循环,工程上适合“需要 Megatron 的并行与算子能力,但不想完全使用 Megatron-LM 全栈脚本”的团队。

Shell
1
2
uv pip install megatron-core
torchrun --nproc-per-node=2 examples/run_simple_mcore_train_loop.py
分片与状态管理
ZeRO

ZeRO(Zero Redundancy Optimizer)的核心思想是:把数据并行中本来每卡都复制一份的三类状态(优化器状态、梯度、参数)分片到不同 rank,从而把“模型状态显存”从 O(N) 降到 O(N/world_size)。DeepSpeed 的 ZeRO Stage 1/2/3 分别对应分片优化器状态、再分片梯度、再分片参数;Stage 3 需要在前向/反向时做参数聚合与再分片。

参数分片

参数分片常见两条路线:DeepSpeed ZeRO-3 与 PyTorch FSDP。两者目标一致,但接入点与约束不同;工程选型通常取决于现有训练栈(Transformers/Accelerate 生态 vs 自定义训练框架)、offload 需求,以及 checkpoint 与集群拓扑迁移的要求。

优化器状态分片

优化器状态(例如 Adam 的一阶/二阶动量)往往占据巨大的显存/内存。ZeRO-1/2 对这部分的分片收益很直接;当显存仍不足时,DeepSpeed 还支持把状态 offload 到 CPU/NVMe(ZeRO-Offload/ZeRO-Infinity),但带宽会成为新的瓶颈,必须依赖重叠与流水化来降低代价。

CUDA 软件栈
CUDA

CUDA 版本与驱动版本的不匹配是训练系统最常见的部署故障源。最小的工程实践是区分两个事实:nvidia-smi 反映的是驱动能力,nvcc 反映的是 toolkit;二者不一致并不必然是错误,但 toolkit 版本若高于驱动支持上限就无法正常工作。部署时以 NVIDIA 的 CUDA Compatibility 文档为准,并用 PyTorch 的 torch.version.cuda 与运行时实际 driver 做交叉验证。

Shell
1
2
nvidia-smi
nvcc --version

Python
1
2
3
import torch
print(torch.version.cuda)
print(torch.cuda.get_device_name(0))
cuDNN

cuDNN 是卷积、归一化、注意力等基础算子的关键实现来源之一。训练部署阶段更重要的工作是保证:驱动、CUDA toolkit、cuDNN 版本与 GPU 架构落在官方支持矩阵内,并与 PyTorch 及扩展库(FlashAttention、xFormers、bitsandbytes)的编译参数保持一致。

NCCL

NCCL 是 NVIDIA GPU 场景下最常用的分布式通信后端。大规模训练里,通信问题往往表现为:hang、极慢、或者跨机带宽只有理论值的一小部分。排障的第一入口是 NCCL 环境变量日志与网络接口选择。

NCCL 常用环境变量
变量 作用 示例
NCCL_DEBUG 开启 NCCL 日志(INFO/WARN)
Shell
1
export NCCL_DEBUG=INFO
NCCL_DEBUG_SUBSYS 按子系统过滤 NCCL_DEBUG 输出
Shell
1
export NCCL_DEBUG_SUBSYS=INIT,NET
NCCL_SOCKET_IFNAME 指定/过滤用于通信的网卡接口(支持 include/exclude 语法)
Shell
1
2
3
export NCCL_SOCKET_IFNAME=eth
export NCCL_SOCKET_IFNAME==eth0,eth1
export NCCL_SOCKET_IFNAME=^docker
NCCL_IB_DISABLE 显式禁用 InfiniBand(在 IB 配置不完整时可用于快速隔离问题)
Shell
1
export NCCL_IB_DISABLE=1
NCCL_P2P_DISABLE 禁用 GPU P2P(用于排查 P2P/拓扑相关问题,性能通常会下降)
Shell
1
export NCCL_P2P_DISABLE=1
kernel 与算子级优化
PyTorch SDPA(内置注意力后端选择)

在不引入额外依赖的情况下,优先使用 PyTorch 的 torch.nn.functional.scaled_dot_product_attention。它会在支持时选择更高性能的 attention 后端(例如 FlashAttention / memory-efficient / cuDNN / math 实现),并将“后端差异”收敛到同一 API 上。后端选择也可通过上下文管理器显式控制。

Python
1
2
3
4
5
6
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend
 
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
Triton

Triton 是面向 GPU kernel 的 Python DSL:通过 @triton.jit 与 triton.language(tl.*)API,把常见内核模式写成可编译的 Python 函数。工程上 Triton 常作为“自定义融合算子”的落地点:当 PyTorch 原生算子组合产生大量中间张量或访存瓶颈时,用 Triton 把多步计算融合成一个 kernel。

安装
Shell
1
pip install triton
最小 Triton kernel 骨架
Python
1
2
3
4
5
6
7
8
9
10
11
import triton
import triton.language as tl
 
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements: tl.constexpr, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    y = tl.load(y_ptr + offs, mask=mask, other=0.0)
    tl.store(out_ptr + offs, x + y, mask=mask)
FlashAttention

FlashAttention 是“精确 softmax attention”的高性能实现:通过 IO-aware 的分块与融合,把注意力的访存与中间张量开销显著压低。工程上常见三条接入路径:

  • 直接使用 PyTorch SDPA,让 PyTorch 在运行时选择 FlashAttention 后端(不额外引入 Python 包)。
  • 引入 flash-attn 包,显式调用算子。
  • 通过 xFormers 的 attention ops 或上层框架开关间接启用。

显式安装 FlashAttention 涉及 CUDA 扩展编译,部署约束主要集中在:CUDA toolkit、GPU 架构与编译工具链一致性。

安装
Shell
1
pip install flash-attn
  • 编译型依赖:安装过程可能会编译 CUDA 扩展,通常需要可用的 CUDA toolkit、以及可工作的编译链(例如 ninja)。
  • 版本约束:FlashAttention 的不同分支/包对 PyTorch 与 CUDA 版本有明确要求;环境固定时以官方 README 的支持矩阵为准。
  • 平台差异:Linux 是最常见的稳定路径;Windows/非常规组合通常更容易落到源码编译与 ABI 问题上。
使用(算子级接入)
Python
1
2
3
4
from flash_attn import flash_attn_func
 
# q,k,v: (batch, seqlen, nheads, headdim) 等布局依赖具体函数签名
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
flash-linear-attention(fla)

flash-linear-attention(常见简称 fla)提供线性注意力与相关模块的高性能实现,核心依赖是 PyTorch 与 Triton。工程上它更像一个“可插拔的层/算子库”:只有当模型架构实际使用了这些层(例如某些线性注意力/SSM/hybrid 模块)时,训练与推理才会受益。

安装
Shell
1
2
3
4
pip install flash-linear-attention
 
# 仅安装核心 kernel/ops(更轻依赖)
pip install fla-core
升级约束
Shell
1
2
3
# 升级前,先卸载两个包避免版本冲突
pip uninstall fla-core flash-linear-attention -y
pip install -U flash-linear-attention fla-core
其他加速组件
xFormers

xFormers 提供一组可组合的 Transformer 组件与优化算子。其中最常见的工程入口是 memory-efficient attention:通过统一接口选择不同高性能后端。安装上通常优先用预编译 wheel;当 PyTorch 版本或 CUDA 组合偏离主流时,才会退回到从源码编译。

Shell
1
pip install xformers

Python
1
2
3
4
from xformers.ops import memory_efficient_attention
 
# q,k,v 的形状与布局取决于 xFormers 版本与具体 backend
out = memory_efficient_attention(q, k, v)
bitsandbytes

bitsandbytes 提供低比特量化算子与 8-bit/4-bit 训练组件,常见用途是:QLoRA 的 4-bit Linear 层,以及 8-bit Adam 优化器状态以降低显存/内存占用。工程上它的关键点是:安装与平台兼容(CUDA/ROCm/CPU 路径)、以及把模型中的 Linear/Embedding 替换为 bnb 对应模块。

Shell
1
pip install bitsandbytes
常用模块速查
模块/类 用途 最小用法
bitsandbytes.nn.Linear4bit QLoRA 4-bit Linear
Python
1
2
3
4
5
6
7
import torch.nn as nn
from bitsandbytes.nn import Linear4bit
 
fp16_model = nn.Linear(64, 64)
q_model = Linear4bit(64, 64)
q_model.load_state_dict(fp16_model.state_dict())
q_model = q_model.to(0)  # 量化通常在 .to("cuda") 触发
bitsandbytes.nn.Linear8bitLt 8-bit Linear
Python
1
2
from bitsandbytes.nn import Linear8bitLt
layer = Linear8bitLt(4096, 4096).to(0)
bitsandbytes.optim.Adam8bit 8-bit Adam 优化器
Python
1
2
import bitsandbytes as bnb
opt = bnb.optim.Adam8bit(model.parameters(), lr=1e-4, min_8bit_size=16384)
数值精度与显存策略
数值精度

混合精度(AMP)是现代训练的默认手段:用 torch.amp.autocast 让部分算子在低精度执行,同时在需要数值范围的地方保留 FP32。对于 FP16 训练,通常需要 torch.amp.GradScaler 做梯度缩放;对于 BF16 训练,很多场景只用 autocast 即可。

Python
1
2
3
4
5
6
7
8
9
10
11
import torch
 
scaler = torch.amp.GradScaler("cuda")
 
for batch in loader:
    opt.zero_grad(set_to_none=True)
    with torch.amp.autocast("cuda", dtype=torch.float16):
        loss = model(batch).loss
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()
TF32(矩阵乘加路径)

TF32 只影响部分 FP32 矩阵乘(matmul/conv)在 Tensor Core 上的执行路径。它属于“性能换数值精度”的系统开关,通常通过 PyTorch 的 backend 选项控制。

Python
1
2
3
4
5
6
7
import torch
 
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
 
# 新版本也可用更高层的 matmul 精度策略
torch.set_float32_matmul_precision("high")
训练稳定性策略
重算策略

Activation checkpointing(重算)用计算换显存:前向不保存中间激活,反向时按需重新执行前向片段。PyTorch 的直接入口是 torch.utils.checkpoint.checkpoint。工程上需要明确它的副作用:重算会改变“前向执行次数”,涉及 RNG 或跨设备拷贝的代码必须被审计,否则可能出现非确定性或性能退化。

Python
1
2
3
4
5
6
7
import torch
from torch.utils.checkpoint import checkpoint
 
def block(x):
    return layer2(layer1(x))
 
y = checkpoint(block, x, use_reentrant=False)
量化训练与低比特微调

低比特微调常见路线是 QLoRA:权重以 4-bit 存储,计算与梯度在更高精度上进行,训练的增量参数由 LoRA 承担。工程落地通常依赖 bitsandbytes 的 4-bit Linear 与上层微调框架(PEFT/Transformers);底层约束集中在:GPU/驱动/CUDA 兼容、量化算子是否可用、以及与 FSDP/ZeRO 的组合边界。

工程组合边界
  • 4-bit 权重 + 分片:参数分片(FSDP/ZeRO-3)与 4-bit 权重量化都在改写参数表示,组合时需要确认框架对“量化权重的 all-gather/重分片”路径是否支持。
  • offload:把模型状态 offload 到 CPU/NVMe 会引入额外带宽瓶颈,必须配合 micro-batch、重算与通信重叠,否则吞吐会显著下降。
  • 验证方式:先在单机单卡确认量化算子可用,再扩展到单机多卡,最后扩展到多机,逐层隔离问题源。
模型交换、导出与部署格式

“能训练”与“能部署”之间隔着一条很长的工程链路:模型从训练框架导出成某种中间表示,再由运行时加载并在特定硬件上执行。真正决定链路质量的是导出语义是否稳定、后端算子是否覆盖、部署环境是否可复现。

这一节按部署侧常见的三条路线组织:

  • 通用交换:PyTorch → ONNX → ONNX Runtime
  • NVIDIA GPU 推理:PyTorch/ONNX → TensorRT 或 TensorRT-LLM
  • Intel CPU/iGPU 推理:ONNX → OpenVINO(可选 IR 转换)→ OpenVINO Model Server

最后补齐两类“本地模型分发格式”:safetensors(安全权重)、GGUF/GGML(llama.cpp 系列推理栈),以及 Hugging Face Hub 的下载与离线部署。

ONNX

ONNX(Open Neural Network Exchange)是交换层:把训练框架里的前向计算图与权重,以跨框架可读的形式表达出来。部署侧更关心两个版本概念:IR version(中间表示版本)与 opset version(算子集合版本)。opset 变更意味着算子语义或签名变化,直接影响“模型能否被某个 runtime 正确执行”。

安装
Shell
1
pip install onnx
导出:PyTorch → ONNX(推荐 torch.export / TorchDynamo 路线)

PyTorch 的 ONNX 导出路线在持续演进。工程上更推荐走 torch.export/TorchDynamo 为基础的导出路径(例如 torch.onnx.export(..., dynamo=True)),以获得更稳定的图捕获与更好的算子覆盖。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
 
class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)
 
    def forward(self, x):
        return torch.relu(self.linear(x))
 
model = MyModel().eval()
x = torch.randn(1, 8)
 
torch.onnx.export(
    model,
    (x,),
    "my_model.onnx",
    input_names=["x"],
    output_names=["y"],
    dynamo=True,  # 推荐的新导出逻辑
)

模型权重很大时,需要考虑外部权重(external data):单个 ONNX 文件存在体积限制,导出时可以把权重拆到额外文件中,并让 ONNX 图引用它们。

ONNX 基础验证

部署前至少做两步静态检查:载入与 checker。最常见的失败来自 opset 不匹配、导出遗漏常量折叠、或动态控制流无法被捕获。

Python
1
2
3
4
import onnx
 
m = onnx.load("my_model.onnx")
onnx.checker.check_model(m)
常见坑
  • opset 不匹配:导出使用了较新的 opset,但 runtime/后端只支持较旧 opset,表现为“能 load 但执行时报不支持的算子/属性”。
  • 动态形状:ONNX 本身可以表达动态维度,但后端是否支持、以及是否需要 shape inference/优化,是另一回事。实践里建议先固定 batch/seq 长度跑通,再逐步放开。
  • 大模型外部权重:超过单文件限制时,ONNX 可能以 external data 形式拆成多文件。部署与转换时必须保证目录结构完整,并确保 runtime 能在正确的 base_dir 下加载外部权重文件。
ONNX Runtime

ONNX Runtime(ORT)是执行层:加载 ONNX 图并在不同硬件后端上执行。它通过 Execution Providers(EP)把算子下沉到不同加速库(CPU、CUDA、TensorRT 等)。部署编程上,核心对象是 InferenceSession。

安装(CPU / GPU)

实践里同一个 Python 环境通常只安装一个 ORT 包(CPU 或 GPU)。GPU 包覆盖大部分 CPU 功能,但仍需要关注 CUDA/cuDNN 与驱动版本匹配。

Shell
1
2
3
4
5
# CPU
pip install onnxruntime
 
# GPU(默认 CUDA 12.x)
pip install onnxruntime-gpu
最小可用推理代码
Python
1
2
3
4
5
6
7
8
import numpy as np
import onnxruntime as ort
 
sess = ort.InferenceSession("my_model.onnx")
 
# 输入名来自导出时的 input_names
x = np.random.randn(1, 8).astype(np.float32)
y = sess.run(None, {"x": x})
Execution Provider 选择与回退

服务端通常需要显式选择 EP,并提供“失败回退到 CPU”的策略。最常见的做法是按优先级传入 providers 列表。

Python
1
2
3
4
5
6
7
8
import onnxruntime as ort
 
providers = [
    "CUDAExecutionProvider",
    "CPUExecutionProvider",
]
 
sess = ort.InferenceSession("my_model.onnx", providers=providers)
常用对象与接口
对象 / API 用途 典型用法
ort.InferenceSession 加载模型、选择 EP、执行推理
Python
1
2
3
4
sess = ort.InferenceSession(
    "my_model.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
sess.get_inputs() 枚举输入名、dtype、shape(用于接入层校验)
Python
1
2
3
inputs = sess.get_inputs()
for i in inputs:
    print(i.name, i.type, i.shape)
sess.run 执行推理
Python
1
outputs = sess.run(None, {"x": x_np})
常见坑
  • CUDA 版本对齐: onnxruntime-gpu 与本机 CUDA/cuDNN/驱动组合必须匹配,否则会出现 provider 初始化失败或动态库缺失。
  • 输入 dtype:推理时 numpy 的 dtype 必须与模型输入一致(例如 fp32),否则会报类型不匹配。
  • EP 支持度:同一个模型在不同 EP 上算子覆盖不同,部署前需要用真实模型做冒烟测试,遇到不支持算子时要么回退到 CPU,要么调整导出图/替换算子。
TensorRT

TensorRT 是 NVIDIA GPU 上的推理优化与运行时:它把 ONNX 模型解析到网络图,再由 Builder 构建优化后的 engine(plan)。构建通常离线完成,线上只加载 engine 并执行。

安装

TensorRT 提供多种安装方式(容器、Debian、pip wheel)。工程上常见的两条路径是:

  • 开发环境:用 pip 安装 tensorrt(或精简运行时变体),配合本机 CUDA 与驱动。
  • 生产环境:以容器为主,把驱动与 CUDA 依赖固化在镜像与运行时约束里。
Shell
1
2
3
4
5
6
7
8
# pip 安装(示例:按实际平台与版本选择合适包名)
python -m pip install -U pip
 
# 通用包名(可能会按平台自动选择合适变体)
pip install tensorrt
 
# 精简 runtime(适合镜像瘦身;示例)
pip install tensorrt_lean
ONNX → TensorRT engine(两种入口)

TensorRT 有两个常见入口:命令行工具(快速验证)与 Python/C++ API(集成到服务)。典型流程是用 TensorRT ONNX parser 导入 ONNX,再由 Builder 生成 engine。

1) 命令行(trtexec)
Shell
1
2
3
4
5
6
7
8
9
# 典型用法:从 ONNX 构建 engine
trtexec --onnx=my_model.onnx --saveEngine=my_model.plan
 
# 若需要 FP16
trtexec --onnx=my_model.onnx --saveEngine=my_model.plan --fp16
 
# 动态 shape(示例:按你的真实输入名与维度填写)
trtexec --onnx=my_model.onnx --saveEngine=my_model.plan \\
  --minShapes=x:1x8 --optShapes=x:16x8 --maxShapes=x:64x8
2) Python API(OnnxParser + Builder)
Python
1
2
3
4
5
6
7
8
9
10
11
12
import tensorrt as trt
 
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
 
with open("my_model.onnx", "rb") as f:
    ok = parser.parse(f.read())
    if not ok:
        # parser.num_errors / parser.get_error(i) 可用于定位不支持算子
        raise RuntimeError("ONNX parse failed")
常见坑
  • engine 可移植性:TensorRT engine 受平台、GPU 架构、TensorRT 版本与构建参数影响。需要跨版本或跨架构复用时,必须显式开启对应的兼容模式;否则默认情况下不具备可移植性。
  • 动态形状与 profile:动态 shape 通常需要显式设置 optimization profile,否则构建或运行会失败。
  • 算子覆盖:ONNX parser 报错时,优先从“导出是否落到标准算子”排查;其次考虑 TRT 插件或改写模型。
TensorRT-LLM

TensorRT-LLM 是面向 LLM 的 TensorRT 构建与运行时栈:提供 Python API 和服务端组件,把 LLM 的 KV cache、注意力优化、量化与服务化接口封装成一条更完整的部署链。快速落地通常走官方容器路线,然后用 trtllm-serve 启动 OpenAI-compatible server。

安装与启动(推荐容器路线)

TensorRT-LLM 更偏“完整部署栈”,依赖 CUDA、TensorRT、编译链与模型支持矩阵。工程上优先选择官方预构建容器,在容器内完成转换、build 与 serve。

在线部署:trtllm-serve(OpenAI-compatible)
Shell
1
2
# 容器内启动服务(示例)
trtllm-serve "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

启动后可访问标准 OpenAI 端点(例如 /v1/chat/completions)。

Shell
1
2
3
4
5
6
7
8
curl -X POST http://localhost:8000/v1/chat/completions \\
  -H "Content-Type: application/json" \\
  -d '{
    "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "messages": [{"role": "user", "content": "Where is New York?"}],
    "max_tokens": 32,
    "temperature": 0
  }'
离线推理:LLM API

TensorRT-LLM 同时提供 Python 侧的 LLM API:给定 Hugging Face repo 或 checkpoint,API 负责加载、优化与推理编排。对工程团队而言,这条路径适合把“推理服务”嵌入到现有 Python 服务栈中,但需要更细致的版本与环境锁定。

OpenVINO

OpenVINO 面向 Intel CPU/iGPU/加速器的推理栈。它既能直接加载 ONNX,也能把 ONNX 转换成 OpenVINO IR(xml+bin)。如果关注加载延迟或希望提前做图优化,通常会先把 ONNX 转成 IR。

安装
Shell
1
pip install openvino
ONNX → OpenVINO IR(Python API)
Python
1
2
3
import openvino as ov
 
ov_model = ov.convert_model("your_model_file.onnx")
OpenVINO Model Server:LLM QuickStart

OpenVINO Model Server(OVMS)提供服务化部署。它支持以 Docker 启动,指定 --source_model 从 Hugging Face 拉取已转换的 OpenVINO 模型,并暴露 OpenAI 风格 API。

Shell
1
2
3
4
5
6
7
8
9
mkdir -p models
docker run -d --rm -p 8000:8000 \\
  -v $(pwd)/models:/models:rw \\
  openvino/model_server:2026.1-gpu \\
  --source_model OpenVINO/Qwen3-8B-int4-ov \\
  --model_repository_path models \\
  --task text_generation \\
  --rest_port 8000 \\
  --target_device GPU

OVMS 也支持用 OpenAI Python client 直接调用(base_url 指向 OVMS)。

常见坑
  • 外部权重 ONNX:若 ONNX 以 external data 拆成多文件,必须保持主 onnx 与外部权重文件的目录关系可发现。
  • 硬件选择:OVMS 需要正确设置 target_device,并保证容器或主机具备对应设备节点与驱动。
本地推理权重格式
safetensors

safetensors 的工程定位是“替代 pickle 的安全权重格式”:可做 zero-copy 加载,并显式避免任意代码执行风险。它主要服务于权重分发与加载,不承担跨硬件优化执行这一层职责。

Shell
1
pip install safetensors

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
from safetensors import safe_open
from safetensors.torch import save_file
import torch
 
# save
tensors = {"w": torch.zeros((2, 2))}
save_file(tensors, "model.safetensors")
 
# load (zero-copy)
loaded = {}
with safe_open("model.safetensors", framework="pt", device=0) as f:
    for k in f.keys():
        loaded[k] = f.get_tensor(k)

按 key 读取(或切片读取)是 safetensors 的常见用法:多 GPU 分片加载、按需加载 embedding 表等场景,会用它降低峰值内存。

GGML

GGML 是 llama.cpp 早期使用的权重格式/生态名词之一。当前工程实践中,GGUF 更常作为“可分发的最终产物”。GGML 更适合理解为历史兼容路径:遇到旧模型时需要能识别与迁移。

GGUF

GGUF 是 llama.cpp 生态的主流分发格式。典型链路是:从 Hugging Face 模型(safetensors/pytorch)转换到 GGUF,再选择量化方案,最后交给 llama.cpp 或 Ollama 运行。llama.cpp 仓库提供了 convert_hf_to_gguf.py 等脚本作为转换入口。

Shell
1
2
3
4
5
6
7
8
9
# 典型流程:先克隆 llama.cpp
git clone https://github.com/ggml-org/llama.cpp
cd llama.cpp
 
# 按仓库要求安装转换脚本依赖(示例)
pip install -r requirements.txt
 
# 转换(脚本参数以 --help 为准)
python convert_hf_to_gguf.py --help

GGUF 的常见坑是 tokenizer 与特殊 token:转换时必须保证 tokenizer 文件齐全(例如 sentencepiece model 或 BPE merges/vocab),否则会出现“能推理但输出严重异常”的隐蔽故障。

模型获取与分发工具
huggingface_hub

huggingface_hub 是下载与缓存的编程入口:它把“下载”变成版本化缓存,并返回本地路径。缓存路径指向的文件不应该被修改,否则会污染缓存并产生难以排查的线上问题。

Shell
1
pip install huggingface_hub
API / CLI 用途 示例
hf_hub_download 下载单个文件(带缓存与 revision)
Python
1
2
3
4
5
6
from huggingface_hub import hf_hub_download
path = hf_hub_download(
    repo_id="lysandre/arxiv-nlp",
    filename="config.json",
    revision="main",
)
snapshot_download 下载整个仓库(支持 allow/ignore patterns)
Python
1
2
3
4
5
6
7
from huggingface_hub import snapshot_download
 
local_path = snapshot_download(
    repo_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    revision="main",
    allow_patterns=["*.safetensors", "*.json"],
)
HfApi(endpoint=...) 对接私有 Hub/镜像,显式指定 endpoint
Python
1
2
3
4
from huggingface_hub import HfApi
 
api = HfApi(endpoint="https://huggingface.co")
models = api.list_models(search="bert")
hf_hub_url 构造下载 URL(用于调试/审计)
Python
1
2
from huggingface_hub import hf_hub_url
url = hf_hub_url("lysandre/arxiv-nlp", "config.json")
hf CLI 登录、下载、缓存管理
Shell
1
2
3
4
5
6
7
# 安装与查看
hf --help
 
# 按 revision 固定下载
hf download TinyLlama/TinyLlama-1.1B-Chat-v1.0 \\
  --revision main \\
  --include "*.safetensors" --exclude "*.bin"

huggingface_hub 的环境变量用于把缓存与认证变成可运维配置: HF_HOME、 HF_HUB_CACHE、 HF_TOKEN 等。它们通常在 import 时读取,生产环境必须保证“进程启动前配置好”。

离线模式与版本固定

离线部署首先要把所有文件按 revision 固定并落到可控目录,然后再在离线环境启用 offline 开关。仅仅让 from_pretrained 在离线环境里可调用,并不足以保证整条部署链稳定。

Shell
1
2
3
4
5
6
7
8
9
# 1) 先下载到可控缓存目录(示例)
export HF_HOME=/data/hf
export HF_HUB_CACHE=/data/hf/hub
hf download TinyLlama/TinyLlama-1.1B-Chat-v1.0 --revision main
 
# 2) 再把运行环境切到离线(示例:以你实际依赖版本为准)
export HF_HUB_OFFLINE=1
export TRANSFORMERS_OFFLINE=1
export HF_DATASETS_OFFLINE=1
  • 对关键模型,revision 固定到 tag/commit hash,并把下载产物做 manifest(文件列表 + 哈希)。
  • 离线开关:Hugging Face 生态里存在多层 offline 变量(例如 hub、transformers、datasets 各自的 offline 模式)。团队需要以“实际依赖版本”为准做一次演练,确认每个库在离线环境的行为一致。
hf-mirror 与中国大陆下载

中国大陆环境常见问题是“访问 Hugging Face 资源不稳定”。工程上有两条可控路径:

  • 内部镜像:自建 mirror 服务,把下载变成内网依赖。
  • 第三方镜像:通过环境变量把下载 base url 指向镜像站点。

对镜像/私有 Hub 的对接,优先使用显式 endpoint:Python 侧使用 HfApi(endpoint=...);命令行侧则使用统一的网络出口与缓存目录策略。部分旧版本文档中也记录了 HF_ENDPOINT 这类环境变量用法,但它是否生效取决于你实际安装的 huggingface_hub 版本。

Shell
1
2
# 把 Hugging Face 的下载端点切到镜像(示例)
export HF_ENDPOINT=https://hf-mirror.com

需要评估:hf-mirror 属于第三方服务,可靠性与合规性需要业务自行评估。能自建镜像或把模型产物纳入制品库(artifact repository)的团队,应优先选择可控方案。

典型部署链路速查
目标 链路 适用场景 常见坑
跨框架推理 PyTorch → ONNX → ONNX Runtime 多语言客户端、跨平台部署 opset/EP 不匹配,动态形状处理
NVIDIA GPU 高性能推理 PyTorch/ONNX → TensorRT engine 低延迟/高吞吐服务 engine 不可跨 GPU/版本复用,profile/插件
NVIDIA LLM 服务化 HF checkpoint → TensorRT-LLM → trtllm-serve OpenAI-compatible LLM server 模型支持矩阵、量化/精度要求、容器化依赖
Intel LLM 服务化 ONNX → OpenVINO(可选 IR)→ OVMS CPU/iGPU 部署、边缘与本地服务 外部权重、设备映射与驱动
本地量化推理 HF → GGUF → llama.cpp/Ollama 本地开发、边缘设备 tokenizer 文件缺失、量化质量与兼容性
推理引擎与服务系统

推理引擎与服务系统把“模型权重 + 推理优化”交付为“可稳定承载并发请求的 API”。服务端需要长期管理 prefill/decode 调度、KV cache 生命周期、批处理策略、流式输出、并发隔离、模型加载与热更新 等工程问题。

LLM 推理引擎

常见推理栈可以按落地点分为两类:面向 GPU 的在线推理引擎(vLLM、SGLang、TGI)与面向本地/边缘的运行时(llama.cpp、Ollama)。工程选型通常先定两件事:服务端是否提供 OpenAI-compatible API,以及是否需要多 GPU/多节点的原生支持。

vLLM

vLLM 是面向高吞吐服务端推理的引擎。它通过 PagedAttention 管理 KV cache,并采用 continuous batching 处理变长请求,从而在并发场景下维持 GPU 利用率。生产系统里最常用的入口是 OpenAI-compatible server(HTTP)。

安装与启动

安装时优先按官方指引为目标平台准备匹配的 PyTorch(CUDA/ROCm/CPU),再安装 vLLM。常见启动方式如下:

vLLM: run OpenAI-compatible server
Shell
1
2
3
4
5
6
# pip install vllm
vllm serve meta-llama/Meta-Llama-3-8B-Instruct \\
  --host 0.0.0.0 \\
  --port 8000 \\
  --dtype auto \\
  --api-key token-abc123

容器化部署通常使用官方镜像,并把 Hugging Face 缓存目录挂载到容器内,避免重复下载权重:

vLLM: run with Docker image (pattern)
Shell
1
2
3
4
5
6
docker run --gpus all \\
  -v ~/.cache/huggingface:/root/.cache/huggingface \\
  -p 8000:8000 \\
  --ipc=host \\
  vllm/vllm-openai:latest \\
  --model meta-llama/Meta-Llama-3-8B-Instruct
OpenAI-compatible API(调用与流式输出)

OpenAI-compatible server 的目标是复用现有的 OpenAI SDK。调用方式只需要把 base_url 指向自托管服务即可:

Call vLLM with OpenAI Python SDK
Python
1
2
3
4
5
6
7
8
9
10
from openai import OpenAI
 
client = OpenAI(base_url="http://localhost:8000/v1", api_key="token-abc123")
resp = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    messages=[{"role": "user", "content": "Hello!"}],
    temperature=0.2,
    max_tokens=128,
)
print(resp.choices[0].message)

流式输出通常使用 SSE;请求里把 stream 设为 true 即可:

OpenAI-compatible streaming (generic curl)
Shell
1
2
3
4
5
6
7
8
curl http://localhost:8000/v1/chat/completions \\
  -H 'Content-Type: application/json' \\
  -H 'Authorization: Bearer token-abc123' \\
  -d '{
    "model": "meta-llama/Meta-Llama-3-8B-Instruct",
    "messages": [{"role":"user","content":"Hello!"}],
    "stream": true
  }'
关键服务参数(常用 Engine Args)

vLLM 的参数很多,但多数服务化场景只需要围绕“显存预算、并发上限、上下文长度、缓存开关”做控制。

参数 含义 典型影响
--max-model-len 最大上下文长度 上限越大,KV cache 预算越高;并发上限通常随之下降
--gpu-memory-utilization 显存预算比例 控制 KV cache 可用空间,影响 OOM 风险与吞吐
--max-num-batched-tokens 每步调度的 token 预算上限 增大可提升吞吐,但可能增加尾延迟
--max-num-seqs 并发序列数上限 控制并发度与资源争用,影响延迟与稳定性
--kv-cache-dtype KV cache 存储精度 更激进的 KV 精度可降低显存/带宽,但需要评估质量影响
--enable-prefix-caching 启用前缀缓存(Prompt Caching) 前缀重复多的业务可显著减少 prefill 成本
--generation-config generation_config 的优先级策略 影响默认采样参数来源,避免线上采样行为“悄悄变了”
多 GPU 与多节点部署

服务化推理的常见拓扑是“单实例多 GPU(TP/PP)”与“多副本横向扩展”。横向扩展更依赖网关/LB 做副本路由;多数引擎把 KV cache 作为进程内状态,因此跨副本共享缓存并不常见。

Serving topology (concept)
1
2
3
4
5
6
7
8
single instance:
  [client] -> [inference server] -> [1 GPU]
 
single instance, multi-GPU:
  [client] -> [inference server] -> [TP/PP over N GPUs]
 
replicas:
  [client] -> [LB / gateway] -> [replica-1] / [replica-2] / ...
SGLang

SGLang 以“推理编排能力 + OpenAI-compatible API”为核心卖点。工程上它常用于需要多步推理控制流、工具调用编排、以及对思维链/推理输出有结构化处理的在线系统。

安装与启动

SGLang 的启动入口可以是 sglang serve 或 python -m sglang.launch_server,两者本质上都是启动一个 OpenAI-compatible server。

SGLang: install & run (single node)
Shell
1
2
3
4
5
6
7
pip install -U sglang
 
python -m sglang.launch_server \\
  --model-path qwen/qwen2.5-0.5b-instruct \\
  --host 0.0.0.0 \\
  --port 30000 \\
  --log-level warning
OpenAI-compatible API 调用
Call SGLang with OpenAI Python SDK
Python
1
2
3
4
5
6
7
8
9
10
from openai import OpenAI
 
client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="qwen/qwen2.5-0.5b-instruct",
    messages=[{"role": "user", "content": "List 3 countries and their capitals."}],
    temperature=0,
    max_tokens=64,
)
print(resp.choices[0].message)
多节点与并行参数(概念)

SGLang 支持多 GPU 与多节点部署。多节点部署通常需要显式指定节点数量、节点 rank、以及通信初始化地址;并行规模与模型大小共同决定权重切分与显存预算。

TGI(Text Generation Inference)

TGI 是 Hugging Face 的推理服务栈。它在“用 Docker 把 Transformers 模型服务化”上体验成熟,仍然适合存量系统维护与兼容性部署;但该项目在 2026-03-21 被归档为只读,新增特性与生态协同通常不如更活跃的推理引擎。

Docker Quickstart

最常见启动方式是官方 Docker 镜像,容器内默认在 80 端口提供服务,常见映射是主机 8080 → 容器 80:

TGI: Docker quickstart (pattern)
Shell
1
2
3
4
5
6
7
8
model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data
 
docker run --gpus all --shm-size 1g \\
  -p 8080:80 \\
  -v $volume:/data \\
  ghcr.io/huggingface/text-generation-inference:3.3.5 \\
  --model-id $model
关键参数(text-generation-launcher)

TGI 的核心入口是 text-generation-launcher。生产里最常调整的参数集中在模型来源、多 GPU 分片与量化。

参数 含义 典型用途
--model-id 模型 ID 或本地目录 指定权重来源(Hub 或本地)
--sharded 启用多 GPU 分片 模型需要多卡容纳时
--num-shard 分片数量 控制使用多少张 GPU
--quantize 量化模式 降低显存占用与带宽压力
llama.cpp(本地与边缘推理)

llama.cpp 面向本地/边缘推理,围绕 GGUF 权重与多后端(CPU/CUDA/Metal 等)提供统一 runtime,并包含 OpenAI-compatible HTTP server( llama-server)。它常用于离线环境、边缘设备、或需要把推理能力分发到开发机的场景。

编译与启动
llama.cpp: build (CUDA pattern)
Shell
1
2
3
4
git clone https://github.com/ggml-org/llama.cpp
cd llama.cpp
cmake -B build -DGGML_CUDA=ON
cmake --build build -j

llama-server: start (OpenAI-compatible)
Shell
1
2
3
4
./build/bin/llama-server \\
  -m /path/to/model.gguf \\
  --host 0.0.0.0 \\
  --port 8081
常用运行参数
参数 含义 典型影响
-m, --model GGUF 模型路径 决定权重来源与量化格式
-c, --ctx-size 上下文长度 影响 KV cache 与吞吐
-t, --threads 生成阶段 CPU 线程数 CPU 推理吞吐与尾延迟
-tb, --threads-batch 批处理/预填充阶段线程数 prefill 性能
-ngl, --n-gpu-layers offload 到 GPU 的层数 GPU/CPU 负载比例与显存占用
Ollama(本地模型分发 + 运行时)

Ollama 把本地模型拉取、版本管理与本地 API 封装成工具链。它提供原生 API( /api/generate、 /api/chat、 /api/embed),并提供 OpenAI-compatible API 作为兼容层。

安装与启动
Ollama: install on Linux (pattern)
Shell
1
curl -fsSL https://ollama.com/install.sh | sh
原生 API(/api/*)
Ollama: generate
Shell
1
2
3
4
curl http://localhost:11434/api/generate -d '{
  "model": "gemma3",
  "prompt": "Why is the sky blue?"
}'

Ollama: chat
Shell
1
2
3
4
curl http://localhost:11434/api/chat -d '{
  "model": "gemma3",
  "messages": [{"role":"user","content":"why is the sky blue?"}]
}'
OpenAI compatibility(概念)

兼容层的目标是让应用侧复用 OpenAI SDK 与中间件。工程上需要明确两类差异:端点语义的“兼容程度”(是否实现同等字段/行为),以及模型侧默认值(chat template、默认采样参数)是否与业务一致。

国产大模型部署(以 ChatGLM 为例)

国产大模型的部署风险集中在“推理栈适配”与“默认行为一致性”。常见问题包括:chat template 与 tokenizer 不一致导致对话格式错乱,上下文长度预算误判,以及 generation_config 的默认值覆盖导致采样行为偏离预期。

ChatGLM-6B:仓库自带的最小 API 服务

ChatGLM-6B 仓库提供了一个最小 FastAPI 服务端作为 API 部署入口,用于本地验证模型与服务链路:

ChatGLM-6B: minimal API server (from official repo)
Shell
1
2
3
# in THUDM/ChatGLM-6B repo
pip install fastapi uvicorn
python api.py
在线服务与接口兼容层

接口兼容层是推理系统工程化的关键:当服务端尽可能实现 OpenAI 的请求/响应形状,应用侧可以复用 SDK、网关与观测链路,只需把 base_url 指向自托管服务即可迁移。兼容层并不自动保证“行为一致”,上线前需要在同一套 prompts 与 sampling 参数下对齐输出分布与稳定性。

Triton Inference Server

Triton 是通用推理服务系统,核心抽象是模型仓库(Model Repository):服务端通过 --model-repository 指定一个或多个仓库路径,并按固定目录布局加载模型版本。它常用于集中托管 ONNX/TensorRT/自定义后端模型,并提供 HTTP/gRPC 与监控端点。

Triton: start server (pattern)
Shell
1
2
3
4
docker run --gpus all --rm -p8000:8000 -p8001:8001 -p8002:8002 \\
  -v /path/to/model-repo:/models \\
  nvcr.io/nvidia/tritonserver:<tag> \\
  tritonserver --model-repository=/models

Triton model repository layout (concept)
1
2
3
4
5
6
7
model-repo/
  my_model/
    config.pbtxt
    1/
      model.onnx
    2/
      model.onnx

LLM 场景下,Triton 常作为“统一 Serving 平台”,在其上集中部署 embedding、reranker、ASR/TTS 等子模型,或加载 TensorRT-LLM backend 承载生成模型。

调度、缓存与性能机制

推理系统的吞吐、延迟与成本与服务端调度强相关。LLM 在线推理通常由两个阶段组成:prefill(处理输入 prompt 并建立 KV)与 decode(逐 token 生成)。多数性能机制都在优化这两段的计算与内存路径。

Batching 与 Continuous Batching

batching 把多个请求合并成一次 forward 调用,提升 GPU 利用率。静态 batching 会被最长请求拖住;服务化推理通常采用 continuous batching,在 token 级别持续吸纳新请求并淘汰已完成请求,从而减少尾延迟并提升吞吐。

KV cache

KV cache 缓存历史 token 的 Key/Value,使 decode 每一步只需计算当前 token 的 Query 并与历史 KV 做注意力。代价是显存占用随上下文长度与并发序列数近似线性增长;服务端因此需要在“最大上下文长度”和“最大并发序列数”之间做显存预算分配。

Prompt Caching(前缀缓存)

前缀缓存复用“已 prefill 的前缀 KV”。当新请求与历史请求共享前缀时,服务端可以跳过重复的 prefill 计算。该机制对“系统 prompt 固定、RAG 模板固定、长前缀重复”的业务收益显著。

Speculative decoding

speculative decoding 用“草稿模型提出多个候选 token + 目标模型验证并接收其中一部分”的方式减少目标模型的 decode 步数,从而降低解码延迟。服务端通常需要同时加载目标模型与草稿模型,并为草稿模型设置独立的资源预算。

vLLM: speculative decoding (CLI pattern)
Shell
1
2
3
4
5
6
vllm serve <target-model> \\
  --speculative-config '{
    "method": "draft_model",
    "model": "<draft-model>",
    "num_speculative_tokens": 5
  }'

llama-server: speculative decoding (pattern)
Shell
1
2
3
4
./build/bin/llama-server \\
  -m /path/to/target.gguf \\
  --model-draft /path/to/draft.gguf \\
  --host 0.0.0.0 --port 8081
检索、向量与 RAG 支撑组件

RAG 工程的核心约束来自“把非结构化知识变成可检索的结构化索引”。这一层组件主要解决:如何把文档分块、生成 embedding、写入索引或向量数据库、在查询时做 ANN 召回与元数据过滤、用 reranker 提升答案相关性、以及在部署层面控制延迟、吞吐与成本。

RAG 的最小工程闭环

一条可维护的检索链路通常分成两条 pipeline:离线入库(ingestion)与在线查询(retrieval)。离线阶段负责把文档标准化、分块、embedding、建索引并落盘;在线阶段负责 query embedding、ANN 召回、过滤、重排与返回候选片段。

离线入库(Ingestion)
Python
1
2
3
4
5
# 1) 规范化文本(去掉无意义空白、统一编码、可选:去重)
# 2) chunking:长文档切成 chunk(带 overlap)
# 3) embedding:把每个 chunk 编码成 float32 向量
# 4) upsert:写入向量索引/向量数据库(同时写入 metadata)
# 5) build index:IVF/HNSW 等(有些系统是插入即维护索引)
在线检索(Retrieval)
Python
1
2
3
4
5
# 1) query -> embedding
# 2) ANN search:向量相似度召回 topK candidates
# 3) filter:按 tenant / language / source / time / ACL 等元数据过滤
# 4) rerank:Cross-Encoder / LLM rerank(可选但通常能显著提升相关性)
# 5) 返回 chunks(以及 doc_id / offsets / urls 等可追溯信息)
Chunking、Embedding、Reranker、缓存
Chunking(分块)

分块的目标是让每个向量对应一段“语义足够集中、可作为检索单位”的文本。工程上需要同时满足:检索召回稳定、下游生成可引用、以及入库成本可控。

策略 适用场景 工程要点
固定窗口(按 token/字符计数) 通用文本、日志、论坛等结构弱的语料 用 overlap 降低“切断引用”的概率;chunk_id 需要稳定(便于增量更新与去重)。
结构感知(按标题/段落/代码块) Markdown、HTML、技术文档、论文 保留层级路径(h1/h2/h3)作为 metadata;能显著改善可解释性与定位能力。
语义分段(按句子/主题边界) 长文档、语义跳跃频繁的内容 实现复杂;通常与结构感知结合更稳。
Embeddings(向量化接入)

embedding 既可以在进程内完成(直接加载模型编码),也可以通过独立服务完成(HTTP/gRPC embedding endpoint)。工程关注点包括:向量维度、向量归一化(cosine vs inner product)、批量化、以及版本管理(embedding 模型升级带来的全量重算成本)。

关注点 建议 原因
向量 dtype 索引侧统一 float32(或按系统支持使用 fp16/binary/sparse) 多数 ANN 实现以 float32 为主;混用 dtype 容易导致精度与兼容性问题。
cosine 相似度 向量 L2 归一化后用 inner product(IP) cosine(a,b) = a·b(当 ||a||=||b||=1);索引实现更统一。
版本管理 把 embedding_model_id 写入 metadata;变更时双写或重建 避免“同一集合混入不同 embedding space”导致召回不可解释。
批量化 离线入库用 batch encode + shard 写入 embedding 计算与写入都更容易成为瓶颈,批量化是最有效的吞吐优化。
Reranker(重排)

向量召回的 topK 通常是“相关但不够精确”的集合,reranker 用更强的匹配器(Cross-Encoder/LLM)对候选做二次排序。工程上常见的做法是:召回 topK=50~200,然后重排取 topN=5~20 作为上下文。

Python
1
2
3
4
# 伪代码:Cross-Encoder rerank
# pairs = [(query, chunk_text) for chunk_text in candidates]
# scores = reranker.predict(pairs)
# candidates = sort_by(scores)[:topN]
缓存与幂等(Cache & Idempotency)

RAG 系统的成本通常被 embedding 与 ANN 搜索吞掉。缓存应该围绕“纯函数”构建:相同输入应产生相同输出。

缓存点 Key 设计 注意事项
embedding 缓存 sha256(normalized_text) + model_id 同一文本在不同模型下向量不同;必须把 model_id 纳入 key。
检索结果缓存 sha256(query) + model_id + filters 适合高频固定 query;对实时性强(例如新闻)的库需要设置短 TTL 或禁用。
chunk 幂等 upsert doc_id + chunk_id 确保增量更新不会产生重复点;chunk_id 推荐来自稳定切分策略。
本地索引与嵌入式检索

本地索引适合“单机/单租户/读多写少”的场景:部署简单、延迟低、没有额外网络 hop。代价是分片、扩容与高可用需要自行实现。

FAISS
安装
Shell
1
2
# CPU
pip install -U faiss-cpu
核心接口速查
对象/函数 用途 最小用法(示例)
faiss.IndexFlatIP 精确检索(inner product);常配合归一化实现 cosine
Python
1
2
3
import faiss, numpy as np
d = 768
index = faiss.IndexFlatIP(d)
faiss.IndexIVFFlat IVF 近似检索(聚类倒排 + 精确扫描)
Python
1
2
quantizer = faiss.IndexFlatIP(d)
index = faiss.IndexIVFFlat(quantizer, d, 4096, faiss.METRIC_INNER_PRODUCT)
faiss.normalize_L2 向量 L2 归一化
Python
1
faiss.normalize_L2(x)  # x: float32 [n, d]
write_index / read_index 索引落盘与加载
Python
1
2
faiss.write_index(index, "docs.faiss")
index = faiss.read_index("docs.faiss")
建库与查询(最小示例)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import faiss
import numpy as np
 
# xb: float32 [n, d],xq: float32 [m, d]
xb = np.random.randn(10000, 768).astype("float32")
xq = np.random.randn(10, 768).astype("float32")
 
# cosine: 归一化 + inner product
faiss.normalize_L2(xb)
faiss.normalize_L2(xq)
 
index = faiss.IndexFlatIP(768)
index.add(xb)
 
scores, ids = index.search(xq, k=5)  # ids: [m, k]
工程权衡

FAISS 的“强项”是速度与可控性:你可以精确控制索引结构、nprobe/efSearch 等参数,并把索引作为本地文件交付。它的“短板”是系统能力:过滤、权限、在线扩缩容、分片与持久化策略都需要额外工程。

pgvector(PostgreSQL 向量扩展)
安装与启用
Shell
1
2
3
# 方式很多(源码/包管理器/Docker/托管服务)
# 核心步骤是:在目标数据库里启用扩展
psql -d your_db -c "CREATE EXTENSION IF NOT EXISTS vector;"
建表、建索引与查询(SQL)
SQL (pgvector)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
CREATE TABLE doc_chunks (
  id bigserial PRIMARY KEY,
  doc_id text NOT NULL,
  chunk_id int NOT NULL,
  chunk text NOT NULL,
  embedding vector(768) NOT NULL
);
 
-- HNSW(适合低延迟近似检索,内存开销更高)
CREATE INDEX ON doc_chunks USING hnsw (embedding vector_cosine_ops);
 
-- 查询 topK(cosine distance)
SELECT doc_id, chunk_id, chunk
FROM doc_chunks
ORDER BY embedding <=> '[0.01, -0.02, ...]'
LIMIT 10;
工程权衡

pgvector 的优势是“把向量检索融进现有 OLTP/OLAP 体系”:事务、JOIN、权限、备份与监控都沿用 Postgres。代价是向量检索性能上限通常低于专用向量数据库,尤其是高维大规模与高 QPS 场景;此外还需要对索引参数、VACUUM/ANALYZE、以及冷热数据分层有明确策略。

专用向量数据库

专用向量数据库把“高维 ANN 检索 + 元数据过滤 + 持久化 + 分布式扩展”做成标准能力,适合多租户、数据规模持续增长、需要高可用与可观测性的场景。

Qdrant
部署(本地 Docker)
Shell
1
2
3
4
docker pull qdrant/qdrant
docker run -p 6333:6333 -p 6334:6334 \
  -v "$(pwd)/qdrant_storage:/qdrant/storage:z" \
  qdrant/qdrant
安装(Python Client)
Shell
1
2
# 可选 fastembed:在 client 侧直接做 text -> embedding(适合快速验证)
pip install -U "qdrant-client[fastembed]"
建库、写入与查询(Python)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue
 
client = QdrantClient(url="http://localhost:6333")
 
if not client.collection_exists("chunks"):
  client.create_collection(
    collection_name="chunks",
    vectors_config=VectorParams(size=768, distance=Distance.COSINE),
  )
 
# upsert points(vector + payload)
points = [
  PointStruct(id=1, vector=[0.0] * 768, payload={"doc_id": "d1", "lang": "zh"}),
]
client.upsert(collection_name="chunks", points=points)
 
# search + filter
f = Filter(must=[FieldCondition(key="lang", match=MatchValue(value="zh"))])
hits = client.search(collection_name="chunks", query_vector=[0.0] * 768, limit=5, query_filter=f)
工程权衡

Qdrant 的工程特点是:数据模型清晰(point + payload)、过滤与索引能力成熟、部署路径明确(本地 Docker / Helm / Cloud)。在安全层面需要显式启用鉴权与网络隔离,默认容器配置通常是无认证的开发模式。

Milvus
部署(Docker Compose)
Shell
1
2
curl -sfL https://raw.githubusercontent.com/milvus-io/milvus/master/scripts/standalone_embed.sh -o standalone_embed.sh
bash standalone_embed.sh start
安装(Python SDK)
Shell
1
pip install -U pymilvus
建库、建索引与查询(Python)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from pymilvus import MilvusClient, DataType
 
client = MilvusClient(uri="http://localhost:19530", token="root:Milvus")
 
schema = MilvusClient.create_schema(auto_id=False, enable_dynamic_field=True)
schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=768)
schema.add_field(field_name="doc_id", datatype=DataType.VARCHAR, max_length=256)
 
index_params = MilvusClient.prepare_index_params()
index_params.add_index(
  field_name="vector",
  index_type="HNSW",
  metric_type="COSINE",
  params={"M": 16, "efConstruction": 200},
)
 
client.create_collection(collection_name="chunks", schema=schema, index_params=index_params)
 
client.insert(
  collection_name="chunks",
  data=[{"id": 1, "vector": [0.0] * 768, "doc_id": "d1", "lang": "zh"}],
)
 
hits = client.search(
  collection_name="chunks",
  data=[[0.0] * 768],
  limit=5,
  filter='doc_id == "d1"',
  output_fields=["doc_id"],
)
工程权衡

Milvus 的定位是面向大规模 ANN 检索的工程化系统:索引类型丰富、可分布式扩展、并提供集合(collection)层的 schema 与字段能力。它对部署与运维的要求也更高,适合“数据规模持续增长且需要系统化治理”的团队。

TCVectorDB(腾讯云向量数据库)

TCVectorDB 属于托管型向量数据库:实例创建、扩缩容与高可用由云平台提供,SDK 把 HTTP API 封装成 Python 类与对象模型。工程上更关注鉴权、网络连通(VPC/公网)、以及数据模型与索引类型的选择。

安装
Shell
1
pip3 install -U tcvectordb
核心接口速查
对象/概念 用途 工程含义
Client SDK 主入口(鉴权、请求与资源管理) 通常需要配置 endpoint/region 与密钥;建议把 credential 放在环境变量或密钥管理系统。
Database / Collection 逻辑组织结构 多租户场景下可按业务线或数据域拆分;collection 内的维度与 metric 必须一致。
IndexType / MetricType 索引类型与相似度度量 决定 recall/latency/cost 的核心旋钮;写入量大时要关注索引构建与更新开销。
写入与检索(工作流骨架)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import tcvectordb
from tcvectordb.model.collection import Embedding
from tcvectordb.model.enum import FieldType, IndexType, MetricType, EmbeddingModel, ReadConsistency
from tcvectordb.model.index import VectorIndex, FilterIndex, HNSWParams
 
tcvectordb.debug.DebugEnable = False
 
client = tcvectordb.RPCVectorDBClient(
  url="https://<your-vdb-endpoint>",    # 控制台/文档提供
  key="<your-api-key>",
  username="root",
  read_consistency=ReadConsistency.EVENTUAL_CONSISTENCY,
  timeout=30,
)
 
db = "rag_db"
col = "chunks"
client.create_database_if_not_exists(database_name=db)
 
# 让服务端基于文本字段自动生成向量(EmbeddingModel 为 SDK 内置枚举)
ebd = Embedding(vector_field="vector", field="chunk", model=EmbeddingModel.BGE_BASE_ZH)
client.create_collection_if_not_exists(
  database_name=db,
  collection_name=col,
  shard=1,
  replicas=0,
  indexes=[
    FilterIndex(name="id", field_type=FieldType.String, index_type=IndexType.PRIMARY_KEY),
    VectorIndex(
      name="vector",
      field_type=FieldType.Vector,
      index_type=IndexType.HNSW,
      dimension=768,
      metric_type=MetricType.COSINE,
      params=HNSWParams(m=16, efconstruction=200),
    ),
    FilterIndex(name="doc_id", field_type=FieldType.String, index_type=IndexType.FILTER),
  ],
  embedding=ebd,
)
 
# upsert:只写文本与元数据,由 embedding 配置自动产出 vector
client.upsert(
  database_name=db,
  collection_name=col,
  documents=[{"id": "c1", "doc_id": "d1", "chunk": "向量数据库用于相似度检索…"}],
)
 
# 检索:search_by_text -> 先做 embedding,再做 ANN search
hits = client.search_by_text(
  database_name=db,
  collection_name=col,
  embedding_items=["向量数据库"],
  output_fields=["doc_id"],
  limit=5,
)
工程权衡

托管型服务的收益来自运维外包与 SLA,代价来自云厂商绑定与成本结构(存储、QPS、流量、索引构建)。当业务对数据主权、可迁移性或自定义算子有强需求时,需要评估本地自建或可移植方案(例如 pgvector 或自建 Qdrant/Milvus)。

选型与权衡(工程视角)
方案 适合 不适合
FAISS(本地索引) 单机部署、离线构建索引、极低延迟、对索引结构控制强 需要复杂过滤/权限/多租户/高可用与在线扩缩容
pgvector(Postgres 内嵌) 已有 Postgres 体系、需要事务与 JOIN、数据规模中等 超大规模 ANN + 高 QPS 的专用检索场景
Qdrant / Milvus(自建向量库) 需要过滤、持久化、分布式扩展与稳定运维 团队缺少运维能力、或希望把运维成本完全外包
TCVectorDB(托管向量库) 希望快速上线并获得云端 SLA、对云集成友好 强可迁移性需求、或需要深度定制与自托管
Agent、工具与应用编排组件

Agent 编排层解决的是“把模型调用变成可执行系统”的工程问题:任务被拆成哪些步骤、每一步调用哪个模型、工具如何注册与授权、状态如何持久化、失败如何重试、以及如何把整条调用链暴露给可观测性系统。它位于推理引擎与训练框架之上,承担流程控制、工具集成与状态管理。

从部署视角看,Agent 系统至少包含三类进程:

  • 推理后端:提供模型推理 API(OpenAI、vLLM、SGLang、TGI、TensorRT-LLM 等)。
  • 编排运行时:实现状态机/图/循环,负责发起模型调用、路由与错误处理。
  • 工具服务:把外部能力(数据库、搜索、浏览器、业务 API、文件系统)封装为工具端点,供模型以 tool calling 方式触发。
Agent 编排框架

编排框架的差异主要体现在两点:控制流的表达能力(链式、图式、事件驱动、角色流水线),以及工具调用的边界管理(schema、权限、审批、重试、隔离)。

LangChain

LangChain 更适合把模型、提示词、检索器与工具快速组装成可运行的 pipeline。它的核心抽象是“可组合组件”,典型用法是先把模型初始化成统一接口,再用可组合表达把上下游粘起来。

Shell
1
2
pip install -U langchain langchain-openai
export OPENAI_API_KEY=sk-...

Python
1
2
3
4
5
from langchain.chat_models import init_chat_model
 
model = init_chat_model("openai:gpt-5.4")
result = model.invoke("Hello, world!")
print(result)

LangChain 负责“编排代码结构”,推理仍发生在后端(OpenAI 或 OpenAI-compatible server)。工程上常见的落地方式是把 LangChain 应用包成一个 HTTP 服务(FastAPI 等),并把工具执行封装为内部函数或外部工具服务。

关注点 LangChain 侧要做什么 推理后端要做什么 工具服务要做什么
模型选择 统一模型调用入口、管理提示词与输入结构 提供 OpenAI-compatible API 或云端 API 无
工具调用 定义工具 schema、把工具结果回注入上下文 产出 tool call 请求(函数名 + JSON 参数) 执行工具、返回结构化结果
部署 将 pipeline 打包为服务,接入鉴权/限流/日志 承载并发与延迟 SLA 治理权限与审计,控制外部副作用
LangGraph

LangGraph 的定位更靠近“可控的状态机/图执行引擎”:它擅长表达长运行、可恢复、可插入人类审核的工作流。与只把 prompt 拼起来相比,它把“循环、分支、检查点、恢复与人机介入”变成一等公民。

Shell
1
pip install -U langgraph langgraph-checkpoint-sqlite

在工程实践中,LangGraph 更适合成为编排运行时本体:把 agent 的状态设计为显式结构(state),把工具调用、模型调用、审批点等写成节点(node),并通过 checkpointer 把状态落盘,从而支持重启恢复与长时间运行。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from typing_extensions import TypedDict
from langgraph.graph import StateGraph
from langgraph.checkpoint.sqlite import SqliteSaver
 
class State(TypedDict):
    text: str
 
def step(state: State) -> dict:
    return {"text": state["text"] + " -> next"}
 
builder = StateGraph(State)
builder.add_node("step", step)
builder.set_entry_point("step")
 
with SqliteSaver.from_conn_string("checkpoints.sqlite") as checkpointer:
    graph = builder.compile(checkpointer=checkpointer)
    out = graph.invoke(
        {"text": "start"},
        config={"configurable": {"thread_id": "demo-thread-1"}},
    )
    print(out["text"])

部署上,LangGraph 常见两种形态:

  • 单体服务:应用进程内执行图,工具执行也在同进程或同机。
  • 分布式工具:图在编排服务内执行,工具通过 HTTP 或 MCP 调用外部服务,工具结果写入状态。
LlamaIndex

LlamaIndex 以“数据代理(Data Agent)”和检索增强为中心,更适合把外部知识、索引、向量库能力组织成 agent 可调用工具。在工具规模变大时,它提供“工具检索(tool retrieval)”这类机制,避免把大量函数定义塞进单次 prompt。

Shell
1
pip install -U llama-index

LlamaIndex 的工具抽象强调把函数与查询引擎包装成可检索、可调用的 Tool。工程上通常把它放在“工具侧”或“检索侧”:编排运行时(LangGraph / Agents SDK)调用 LlamaIndex 的查询/工具,再把结果回注入模型上下文。

Python
1
2
3
4
5
6
7
from llama_index.core.tools import FunctionTool
 
def get_weather(location: str) -> str:
    """Useful for getting the weather for a given location."""
    return f"{location}: 25C"
 
tool = FunctionTool.from_defaults(get_weather, name="get_weather")
DSPy

DSPy 的定位是“用程序化结构来编写 LLM 应用,并用优化器把程序编译成更有效的提示词或权重配置”。它更适合研发阶段系统性迭代(prompt/模块组合/评估驱动),而不是单纯手写 prompt 字符串。

Shell
1
pip install -U dspy

在工具调用上,DSPy 提供了 Tool 原语与 ReAct 等模式,支持使用底层模型的原生 function calling 能力。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
import dspy
 
def search_web(query: str) -> str:
    return f"Search results for {query}"
 
agent = dspy.ReAct(
    signature="question -> answer",
    tools=[search_web],
    max_iters=5,
)
 
result = agent(question="What's new in vLLM?")
print(result.answer)
AutoGen

AutoGen 把多智能体协作与通信基础设施作为一等公民,强调“runtime 负责消息与生命周期,agent 负责逻辑”。它既可用于研究型多智能体协作,也可作为生产编排底座。其工程价值通常体现在:明确的 agent runtime、组件化的模型与工具实现、以及面向多进程/多机的扩展路径。

Shell
1
pip install -U "autogen-agentchat" "autogen-ext[openai,azure]"
CrewAI

CrewAI 更偏向“角色 + 任务流水线(Flow)”表达,适合把业务流程拆成岗位式分工并固定编排。它的工程落地通常依赖明确的输入输出契约与任务边界,否则会迅速滑向不可控的多轮对话。

Shell
1
pip install -U crewai
协议与托管平台

工具调用协议的关键不在“模型能不能调用工具”,而在“工具定义是否标准化、执行是否隔离、权限是否可审计”。OpenAI function/tool calling 与 MCP 分别覆盖了两条常见路径:前者提供“模型到函数”的结构化参数通道;后者提供“工具/资源/提示词”的标准化服务协议,并允许工具以独立服务器形态部署。

OpenAI-compatible function/tool calling

OpenAI-compatible tool calling 的核心是:用 JSON Schema 定义工具参数,并让模型返回结构化的 tool call(函数名 + JSON 参数)。推理后端只负责生成 tool call;真正的工具执行必须在应用侧完成,并把结果作为后续输入再发回模型。

工具 schema 需要满足“模型可理解、服务端可校验”两类约束。一个可落地的最小约束集合包括:参数类型明确、必填项清晰、默认值可推断、以及禁止额外字段(避免模型塞入无关参数)。

字段 含义 典型示例
name 工具名(函数名) "search_docs"
description 工具描述(用于让模型选择工具) 描述越具体,误调用越少
parameters JSON Schema 参数定义
JavaScript
1
2
3
4
5
6
7
8
9
{
  "type": "object",
  "properties": {
    "query": { "type": "string" },
    "top_k": { "type": "integer", "default": 5 }
  },
  "required": ["query"],
  "additionalProperties": false
}

工程上,工具 schema 需要同时满足两类消费者:模型(用于选择与填参)与服务端(用于校验与执行)。推荐在服务端做强校验(Pydantic/zod/JSON Schema validator),并把校验失败当成工具错误返回给模型进行自修复。

在 API 形态上,不同后端的 tools 字段会有轻微差异(嵌套 {"type":"function","function":{...}} 或扁平 {"type":"function","name":...})。编排层通常在入口做一次归一化,保证内部只处理一种表示。

OpenAI Responses API

Responses API 把“生成 + 工具 + 流式事件”统一成一个接口,支持 function calling、内置工具与 MCP 工具。部署形态上,它更像一个推理后端:你的编排服务负责调用 Responses、接收 tool call 事件、执行工具并把结果回注入下一轮调用。

Shell
1
pip install -U openai

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import json
from openai import OpenAI
 
client = OpenAI()
 
def get_city_uuid(city: str) -> str:
    return f"{city} ID: 00000000-0000-0000-0000-000000000000"
 
tool_mapping = {"get_city_uuid": get_city_uuid}
 
tools = [
    {
        "type": "function",
        "name": "get_city_uuid",
        "description": "Retrieve the internal ID for a city from the internal database.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
            "additionalProperties": False,
        },
    }
]
 
response = client.responses.create(
    model="gpt-5.5",
    input="What's the internal ID for London?",
    tools=tools,
)
 
followup_items = []
for item in response.output:
    if item.type != "function_call":
        continue
    fn = tool_mapping[item.name]
    args = json.loads(item.arguments)
    tool_output = fn(**args)
    followup_items.append(
        {"type": "function_call_output", "call_id": item.call_id, "output": tool_output}
    )
 
if followup_items:
    response2 = client.responses.create(
        model="gpt-5.5",
        input=followup_items,
        previous_response_id=response.id,
    )
    print(response2.output_text)
Agents SDK

Agents SDK 的工程定位是“当你的应用拥有编排与工具执行权”时,提供标准化的 agent loop、handoff、guardrail、session 与 tracing。它与 LangGraph 的差异在于:前者更偏 SDK 级编排框架并与 OpenAI 生态强绑定,后者更偏通用图式编排运行时。

Shell
1
2
pip install openai-agents
export OPENAI_API_KEY=sk-...

Python
1
2
3
4
5
6
7
8
9
10
from agents import Agent, Runner
 
agent = Agent(
    name="Ops helper",
    instructions="Diagnose errors and suggest concrete fixes.",
    model="gpt-5.5",
)
 
result = Runner.run_sync(agent, "Explain this stacktrace and propose a patch.")
print(result.final_output)
MCP

Model Context Protocol(MCP)是一套标准化协议,用于把工具、资源与提示词以“独立服务器”的方式暴露给 AI 应用。它使用 JSON-RPC 2.0,在 host/client/server 三方模型下进行能力协商与调用。MCP 的价值在于把工具系统做成可组合生态:同一个 MCP server 可以被不同 host 复用,同一个 host 也能接多个 MCP server。

FastMCP

FastMCP 以 Python 类型标注与 docstring 自动生成工具 schema,把“写工具函数”变成“发布 MCP 工具”。它适合把内部服务封装成可调用工具,并以 stdio/HTTP 形态部署。stdio 模式下必须避免向 stdout 写日志,否则会破坏 JSON-RPC 通信。

Shell
1
pip install "mcp[cli]" httpx

Python
1
2
3
4
5
6
7
8
9
10
11
from mcp.server.fastmcp import FastMCP
 
mcp = FastMCP("weather")
 
@mcp.tool
def add(a: int, b: int) -> int:
    \"\"\"Add two numbers.\"\"\"
    return a + b
 
if __name__ == "__main__":
    mcp.run()
FastMCP / MCP Python SDK API 用途 脚本位置
FastMCP(name) 创建 MCP server 实例 server 进程
@mcp.tool 声明工具函数,自动生成 schema server 进程
@mcp.resource 暴露可读取资源(类文件数据) server 进程
@mcp.prompt 暴露可复用 prompt 模板 server 进程
mcp.run() 启动 server(stdio/HTTP transport) server 进程

MCP server 开发调试常用 MCP Inspector:

Shell
1
npx -y @modelcontextprotocol/inspector npx @modelcontextprotocol/server-filesystem /path/to/dir
Workflow 与 Runtime 的边界

编排代码通常由两层构成:workflow 负责表达“步骤与依赖关系”,runtime 负责提供“可恢复、可审计、可扩展”的执行语义。把关键能力下沉到 runtime,可以减少“靠 prompt 记住状态”的不稳定性。

能力 更适合放在 workflow(图/链) 更适合放在 runtime(执行层)
状态 状态结构(state schema)、节点输入输出契约 checkpoint/线程/恢复、版本化与回放
失败处理 哪些步骤允许重试、哪些步骤必须人审 指数退避、幂等、断点续跑、死信队列
工具调用 工具集选择与路由(tool routing / retrieval) 权限、沙箱、审计日志、限流、并发隔离
观测 关键业务 span 的命名与结构化属性 trace/metrics/log 的采集、采样、落库与查询
可观测性与审计

Agent 系统的主要故障点通常落在工具调用链偏航:工具选错、参数不合法、返回值解析失败、以及重试/恢复逻辑异常。可观测性需要覆盖:每次模型调用的输入输出、每次工具调用的参数与返回值、以及每个步骤的耗时与失败原因。Langfuse 在 LangChain / LangGraph 生态中常用作 tracing 平台。

Langfuse(LangChain 回调集成)
Shell
1
2
3
4
pip install -U langfuse langchain langgraph langchain-openai
export LANGFUSE_PUBLIC_KEY=pk-lf-...
export LANGFUSE_SECRET_KEY=sk-lf-...
export LANGFUSE_HOST=https://cloud.langfuse.com

Python
1
2
3
4
5
6
7
8
9
10
11
12
from langfuse.langchain import CallbackHandler
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
 
langfuse_handler = CallbackHandler()
 
llm = ChatOpenAI(model_name="gpt-5.4")
prompt = ChatPromptTemplate.from_template("Tell me a joke about {topic}.")
chain = prompt | llm
 
resp = chain.invoke({"topic": "cats"}, config={"callbacks": [langfuse_handler]})
print(resp.content)
浏览器与工具调用组件

浏览器自动化通常以“工具”的形态接入 agent:编排层提供一个受控接口,例如 open_url、click、type、screenshot、extract_text;底层用 Playwright/Puppeteer 执行真实浏览器操作。生产部署时更常见的做法是把浏览器跑在隔离容器里,通过队列或 RPC 驱动,避免把不可信页面脚本与业务服务混跑在同一进程。

Playwright
Shell
1
2
pip install playwright
python -m playwright install

Python
1
2
3
4
5
6
7
8
from playwright.sync_api import sync_playwright
 
with sync_playwright() as p:
    browser = p.chromium.launch(headless=True)
    page = browser.new_page()
    page.goto("https://playwright.dev")
    title = page.title()
    browser.close()

如果把 Playwright 作为工具服务部署,推荐把“浏览器生命周期管理”显式化:为每个任务创建 context,任务结束后关闭 context,避免跨任务共享 cookie/session 导致串台。

Puppeteer
Shell
1
npm i puppeteer

Puppeteer 与 Puppeteer-core 的选择点在于“是否需要自动下载浏览器”。当你连接远程浏览器或自行管理浏览器镜像时,通常使用 puppeteer-core 并关闭下载。

从训练脚本到推理服务的最小闭环

工程闭环的目标是:用一套可复现的目录约定与脚本接口,把“数据准备 → 训练/微调 → checkpoint 管理 → 导出 → 启动推理服务 → 客户端调用 → 监控 → 回滚”串成一条可持续迭代的流水线。这个闭环必须满足两点:一是产物可追溯(可定位到数据版本、代码版本、超参版本),二是可回滚(任何上线问题都能在分钟级回退到上一版本)。

目录与产物约定(可回滚的最小形态)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
repo/
  data/
    raw/                       # 原始数据(不直接喂训练)
    processed/
      train.jsonl              # 训练集(SFT / 分类 / NER 等)
      eval.jsonl               # 验证集
      dataset_meta.json        # 数据摘要:hash、样本数、字段说明、生成脚本参数
  scripts/
    prepare_data.py            # raw -> processed,输出 dataset_meta.json
  train/
    train_sft.py               # 训练脚本(支持断点续训、保存 best/last)
  export/
    export_merge_lora.py       # 可选:LoRA 合并导出为“纯模型”目录
  outputs/
    runs/
      2026-05-09_210530_sft/   # 单次训练 run(可追溯、不可变)
        checkpoints/           # checkpoint-xxx
        best/                  # 指向 best checkpoint 或其导出物
        logs/                  # tensorboard / jsonl / wandb(任选其一)
        run_meta.json          # 代码版本、数据 hash、超参、环境信息
  models/
    registry/
      model_v0001/             # 可部署模型目录(merge 后或 base+adapter 信息)
      model_v0002/
    prod -> registry/model_v0002   # 生产指针(原子替换实现回滚)
  serving/
    vllm/
      serve.sh                 # 启动推理服务(读 models/prod)
      healthcheck.sh
  clients/
    call_openai.py
数据准备(raw → jsonl)

最小可维护做法是把训练数据固化成 jsonl,并明确字段语义。SFT 场景建议至少包含 prompt 与 response(或统一成 text,让样本是完整的对话模板)。验证集必须在训练前固定,避免“验证泄漏”造成误判。

prepare_data.py(最小骨架)
scripts/prepare_data.py
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import hashlib
import json
from pathlib import Path
 
def sha256_file(path: Path) -> str:
  h = hashlib.sha256()
  with path.open("rb") as f:
    for chunk in iter(lambda: f.read(1024 * 1024), b""):
      h.update(chunk)
  return h.hexdigest()
 
def write_jsonl(rows, out_path: Path) -> None:
  out_path.parent.mkdir(parents=True, exist_ok=True)
  with out_path.open("w", encoding="utf-8") as f:
    for r in rows:
      f.write(json.dumps(r, ensure_ascii=False) + "\n")
 
def main():
  raw_path = Path("data/raw/raw.json")  # 例:你的上游导出
  raw = json.loads(raw_path.read_text(encoding="utf-8"))
 
  # 将原始数据规范化为 prompt/response(保持幂等)
  rows = []
  for x in raw:
    rows.append({
      "id": x["id"],
      "prompt": x["prompt"].strip(),
      "response": x["response"].strip(),
    })
 
  # 固定切分(可改为按 doc_id/时间分桶等更稳健策略)
  n = len(rows)
  train, eval_ = rows[: int(n * 0.98)], rows[int(n * 0.98):]
 
  out_train = Path("data/processed/train.jsonl")
  out_eval = Path("data/processed/eval.jsonl")
  write_jsonl(train, out_train)
  write_jsonl(eval_, out_eval)
 
  meta = {
    "raw_path": str(raw_path),
    "raw_sha256": sha256_file(raw_path),
    "train_path": str(out_train),
    "eval_path": str(out_eval),
    "train_size": len(train),
    "eval_size": len(eval_),
    "schema": {"id": "str", "prompt": "str", "response": "str"},
  }
  Path("data/processed/dataset_meta.json").write_text(
    json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8"
  )
 
if __name__ == "__main__":
  main()
训练与 checkpoint(可复现、可恢复)

训练脚本的最小要求:固定输入数据与超参、支持断点续训、把 checkpoint 与元信息写入 run 目录,并在训练结束后产出一个“可部署入口”(best checkpoint 或导出的模型目录)。下面示例以 TRL 的 SFTTrainer + PEFT(LoRA)为主线。

依赖安装(训练侧)
Shell
1
pip install -U transformers accelerate datasets trl peft safetensors
train_sft.py(最小骨架:SFT + LoRA)
train/train_sft.py
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import json
import os
import subprocess
from datetime import datetime
from pathlib import Path
 
from datasets import load_dataset
from transformers import AutoTokenizer
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer
 
def git_head() -> str:
  try:
    return subprocess.check_output(["git", "rev-parse", "HEAD"]).decode().strip()
  except Exception:
    return "unknown"
 
def main():
  run_id = datetime.now().strftime("%Y-%m-%d_%H%M%S_sft")
  run_dir = Path("outputs/runs") / run_id
  ckpt_dir = run_dir / "checkpoints"
  run_dir.mkdir(parents=True, exist_ok=True)
  ckpt_dir.mkdir(parents=True, exist_ok=True)
 
  base_model = os.environ.get("BASE_MODEL", "Qwen/Qwen3-0.6B")
  train_path = "data/processed/train.jsonl"
  eval_path = "data/processed/eval.jsonl"
 
  tok = AutoTokenizer.from_pretrained(base_model, use_fast=True)
  if tok.pad_token is None:
    tok.pad_token = tok.eos_token
 
  ds_train = load_dataset("json", data_files=train_path, split="train")
  ds_eval = load_dataset("json", data_files=eval_path, split="train")
 
  def to_text(batch):
    text = []
    for p, r in zip(batch["prompt"], batch["response"]):
      text.append(f"### Instruction\n{p}\n\n### Response\n{r}")
    return {"text": text}
 
  ds_train = ds_train.map(to_text, batched=True, remove_columns=ds_train.column_names)
  ds_eval = ds_eval.map(to_text, batched=True, remove_columns=ds_eval.column_names)
 
  args = SFTConfig(
    output_dir=str(ckpt_dir),
    max_length=2048,
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    learning_rate=2e-4,
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=200,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    bf16=True,
    report_to="none",
  )
 
  peft_cfg = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
  )
 
  trainer = SFTTrainer(
    model=base_model,                 # TRL 支持传入模型 id
    args=args,
    train_dataset=ds_train,
    eval_dataset=ds_eval,
    dataset_text_field="text",
    tokenizer=tok,
    peft_config=peft_cfg,
  )
 
  trainer.train()
 
  # best checkpoint 路径由 TrainerState 给出;作为“可部署入口”的候选
  best = getattr(trainer.state, "best_model_checkpoint", None) or str(ckpt_dir)
 
  # 记录 run 元信息(用于追溯与回滚判因)
  meta = {
    "run_id": run_id,
    "base_model": base_model,
    "git_head": git_head(),
    "train_path": train_path,
    "eval_path": eval_path,
    "best_checkpoint": best,
    "training_args": args.to_dict() if hasattr(args, "to_dict") else vars(args),
  }
  (run_dir / "run_meta.json").write_text(json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8")
 
  # 产出一个 best 指针(便于下游导出/上线脚本读取)
  (run_dir / "best").write_text(best, encoding="utf-8")
 
if __name__ == "__main__":
  main()

多卡启动(最小形态)
Shell
1
accelerate launch train/train_sft.py
断点续训与 checkpoint 策略

断点续训的最小实现是:训练启动时检测 output_dir 下最近的 checkpoint,并把它作为 resume_from_checkpoint 输入。线上训练任务应固定 save_total_limit,避免磁盘被历史 checkpoint 填满导致任务失败。

导出与上线包(从 checkpoint 到“可部署模型目录”)

PEFT/LoRA 训练常见有两种上线形态:

  • 形态 A:部署 base + LoRA adapter(推理侧按请求或按版本加载 adapter),上线快、存储小。
  • 形态 B:把 LoRA 合并到 base(导出为纯模型目录),推理侧只加载一个目录,上线更简单。
export_merge_lora.py(形态 B:合并导出)
export/export_merge_lora.py
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import argparse
from pathlib import Path
 
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
 
def main():
  ap = argparse.ArgumentParser()
  ap.add_argument("--base_model", required=True)
  ap.add_argument("--adapter_dir", required=True)   # trainer 产出的 adapter checkpoint
  ap.add_argument("--out_dir", required=True)       # models/registry/model_vXXXX
  args = ap.parse_args()
 
  out_dir = Path(args.out_dir)
  out_dir.mkdir(parents=True, exist_ok=True)
 
  tok = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
  base = AutoModelForCausalLM.from_pretrained(
    args.base_model,
    torch_dtype="auto",
    device_map="cpu",
  )
  model = PeftModel.from_pretrained(base, args.adapter_dir)
  model = model.merge_and_unload()
 
  # safe_serialization=True -> safetensors
  model.save_pretrained(out_dir, safe_serialization=True)
  tok.save_pretrained(out_dir)
 
if __name__ == "__main__":
  main()
版本化与可回滚指针

上线包放在 models/registry/model_vXXXX,生产指针 models/prod 是一个符号链接。切换版本通过“原子替换 symlink”实现回滚。

上线/回滚(示例)
Shell
1
2
3
4
5
# 上线:切到新版本
ln -sfn "$(pwd)/models/registry/model_v0002" "$(pwd)/models/prod"
 
# 回滚:切回旧版本
ln -sfn "$(pwd)/models/registry/model_v0001" "$(pwd)/models/prod"
启动推理服务(vLLM OpenAI 兼容服务)

推理服务侧的最小目标是提供稳定的 HTTP API,并将“模型路径/版本”从代码里剥离出来(通过 models/prod 指针决定)。下面示例使用 vLLM 的 OpenAI-Compatible Server。

安装(推理侧)
Shell
1
pip install -U vllm
serve.sh(最小启动脚本)
serving/vllm/serve.sh
Shell
1
2
3
4
5
6
7
8
9
10
11
12
13
#!/usr/bin/env bash
set -euo pipefail
 
MODEL_DIR="$(cd "$(dirname "$0")/../.." && pwd)/models/prod"
 
export VLLM_LOGGING_LEVEL=INFO
 
vllm serve "$MODEL_DIR" \
  --host 0.0.0.0 \
  --port 8000 \
  --dtype auto \
  --served-model-name prod \
  --api-key "token-abc123"
健康检查与 smoke test
serving/vllm/healthcheck.sh
Shell
1
2
3
4
5
6
7
8
9
10
11
12
13
14
#!/usr/bin/env bash
set -euo pipefail
 
curl -sf http://127.0.0.1:8000/v1/models > /dev/null
 
# 简单生成测试(Chat Completions)
curl -sf http://127.0.0.1:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer token-abc123" \
  -d '{
    "model": "prod",
    "messages": [{"role": "user", "content": "Say hi in one sentence."}],
    "temperature": 0
  }' > /dev/null
客户端调用(OpenAI SDK 对接)

OpenAI 兼容服务的价值是客户端可复用:同一套调用代码既能访问云端,也能访问本地/自建的 vLLM 服务。下面示例用 OpenAI Python SDK 走 Chat Completions。

Shell
1
pip install -U openai

clients/call_openai.py
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from openai import OpenAI
 
client = OpenAI(
  base_url="http://127.0.0.1:8000/v1",
  api_key="token-abc123",
)
 
resp = client.chat.completions.create(
  model="prod",
  messages=[{"role": "user", "content": "Write a haiku about debugging."}],
  temperature=0,
  max_tokens=128,
)
 
print(resp.choices[0].message.content)
监控与回滚(最小可操作)

闭环的监控重点放在三类信号:服务可用性(health)、吞吐与延迟(QPS/TTFT/TPOT)、以及错误率(5xx/超时/OOM)。最小回滚流程必须是“切换模型指针 + 重启服务 + 运行 smoke test”。

指标采集点(建议最小集)
  • 服务级别: /v1/models 可用性,5xx 比例,请求超时比例。
  • 推理级别:首 token 延迟(TTFT)、每 token 延迟(TPOT)、生成长度分布。
  • 资源级别:GPU 显存占用、GPU utilization、CPU/内存、队列长度。
vLLM Prometheus metrics
Shell
1
curl -sf http://127.0.0.1:8000/metrics | head
回滚脚本(最小骨架)
serving/vllm/rollback.sh
Shell
1
2
3
4
5
6
7
8
9
10
11
12
13
14
#!/usr/bin/env bash
set -euo pipefail
 
TARGET="${1:?usage: rollback.sh model_vXXXX}"
ROOT="$(cd "$(dirname "$0")/../.." && pwd)"
 
ln -sfn "$ROOT/models/registry/$TARGET" "$ROOT/models/prod"
 
# 具体重启方式取决于你的进程管理器(systemd/docker/k8s)
# 这里仅给出最小形态:杀进程后重启
pkill -f "vllm serve" || true
nohup bash "$ROOT/serving/vllm/serve.sh" > "$ROOT/outputs/vllm_stdout.log" 2>&1 &
 
bash "$ROOT/serving/vllm/healthcheck.sh"
PyTorch 详解

PyTorch 是训练与推理编程栈的“最底层可控面”:Tensor、device、autograd、 nn.Module、数据加载、序列化、编译与分布式训练都在这一层完成。上层框架可以隐藏细节,但当你需要排查显存、吞吐、梯度同步、checkpoint 恢复、算子不确定性时,最终必须回到 PyTorch 的对象模型与 API 语义。

安装矩阵与快速验证

PyTorch 的安装需要同时匹配三件事:Python 版本、操作系统、计算平台(CPU/CUDA/ROCm/MPS)。实际工程里最稳妥的策略是:用官方安装页的选择器生成命令,然后将该命令固化到你的环境脚本或镜像构建中。

目标平台 常用安装方式(示例) 工程备注
CPU(Linux/macOS/Windows)
Shell
1
pip install -U torch torchvision
CPU-only 适合开发与单测;性能调优与显存问题需要在目标 GPU 上复现。
CUDA(NVIDIA GPU)
Shell
1
2
# 以官方安装页生成的命令为准;典型形态如下
pip install -U torch torchvision --index-url https://download.pytorch.org/whl/cu126
CUDA wheel 与机器驱动/运行时要匹配;多机训练应当在镜像层固定 CUDA 与 PyTorch 组合。
ROCm(AMD GPU)
Shell
1
2
# 以官方安装页生成的命令为准(ROCm 版本需匹配系统栈)
pip install -U torch torchvision --index-url https://download.pytorch.org/whl/rocm
ROCm 生态对内核/驱动版本更敏感,建议使用官方/社区维护的容器基镜像。
MPS(Apple Silicon)
Shell
1
pip install -U torch torchvision
设备为 mps;算子覆盖度与性能特征与 CUDA 不同。

最小验证覆盖三个断言:版本可读、Tensor 可算、目标加速器可见。

verify_torch.py
Python
1
2
3
4
5
6
7
8
9
import torch
 
print("torch:", torch.__version__)
print("cuda_available:", torch.cuda.is_available())
print("mps_available:", hasattr(torch.backends, "mps") and torch.backends.mps.is_available())
 
x = torch.randn(2, 3)
y = x @ x.T
print("ok:", y.shape)
Tensor、dtype 与 device

Tensor 的关键元信息是:形状(shape)、数值类型(dtype)、设备(device)、以及是否参与梯度(requires_grad)。训练代码里常见 bug 本质都是“不匹配”:输入和参数不在同一 device、label dtype 错、view/reshape 造成非 contiguous 导致算子退化,或无意间把需要梯度的张量带入无梯度区间。

创建、迁移与布局
API 用途 可直接复用的写法
torch.tensor 从 Python 对象创建张量(会拷贝)
Python
1
x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
torch.as_tensor 尽量不拷贝地包装已有数据
Python
1
2
3
import numpy as np
arr = np.zeros((2, 3), dtype=np.float32)
x = torch.as_tensor(arr)  # 可能与 numpy 共享内存
torch.from_numpy 从 numpy 创建(共享内存)
Python
1
x = torch.from_numpy(arr)  # 修改一侧会影响另一侧
Tensor.to 迁移 device/dtype(训练最常用)
Python
1
2
device = torch.device("cuda", 0)  # or "cpu"/"mps"
x = x.to(device=device, dtype=torch.bfloat16)
Tensor.contiguous 把非连续内存布局变成连续
Python
1
2
x = x.permute(0, 2, 1)     # 可能变成非 contiguous
x = x.contiguous()        # 需要时显式转回
多设备下的 device 选择

单机多卡训练通常按“每进程绑定一张卡”的方式组织。绑定的核心动作是:在进程启动后立即 torch.cuda.set_device(local_rank),并确保模型与 batch 都迁移到 cuda:local_rank。

Python
1
2
3
4
5
6
7
8
9
import os
import torch
 
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
 
model = model.to(device)
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
autograd:梯度模式与反向图

PyTorch 的 autograd 是胶带式自动求导:前向执行时记录算子与中间结果,反向时从标量 loss 回传梯度到叶子张量(leaf tensors)。工程上需要明确三类“梯度模式”:训练(需要梯度)、评估(无梯度)、推理(更强的 inference mode)。

训练:backward 与清梯度

一个稳定的训练 step 通常遵循固定模板:清梯度、前向、算 loss、反向、(可选)裁剪、优化器 step。清梯度推荐使用 set_to_none=True,这会让 PyTorch 用 None 表示“没有梯度”,减少写零开销。

Python
1
2
3
4
5
optimizer.zero_grad(set_to_none=True)
loss = model(**batch).loss
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
评估与推理:no_grad 与 inference_mode

model.eval() 只切换模块行为(例如 dropout/batchnorm),不影响 autograd。关闭梯度需要显式进入无梯度上下文:

  • torch.no_grad():关闭反向图记录,适用于评估。
  • torch.inference_mode():更激进的推理模式,额外禁用若干 autograd 相关开销;它不会自动调用 model.eval()。
Python
1
2
3
model.eval()
with torch.inference_mode():
    logits = model(x)
autograd.grad:只要梯度、不做参数更新

torch.autograd.grad 在实现自定义优化、梯度惩罚、或需要显式控制梯度张量生命周期时更直接。

Python
1
2
3
4
5
6
import torch
 
x = torch.randn(4, requires_grad=True)
y = (x ** 2).sum()
gx, = torch.autograd.grad(y, x, create_graph=False)
print(gx)
nn.Module、参数注册与 state_dict

nn.Module 提供两类关键能力:组织子模块并注册参数/缓冲区;提供可序列化的 state_dict,用于保存/恢复训练状态和做 warmstart。

参数(Parameter)与缓冲区(Buffer)

可训练权重应当是 nn.Parameter 或由标准层(Linear/Conv/Embedding 等)创建;非训练但需要随模型保存的状态(例如 batchnorm 的 running_mean)应注册为 buffer。

Python
1
2
3
4
5
6
7
8
9
10
11
import torch
import torch.nn as nn
 
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)
        self.register_buffer("scale", torch.tensor(1.0), persistent=True)
 
    def forward(self, x):
        return self.proj(x) * self.scale

buffer 是否进入 state_dict 由 persistent 决定:非持久 buffer 不会被保存,这常用于缓存中间结果或仅运行期有效的状态。

state_dict:保存/加载的工程契约

state_dict() 返回一个 Python dict,包含参数与持久化 buffer。加载时常用两种策略:

  • 严格恢复:结构完全一致,使用默认 strict=True。
  • warmstart:允许缺键/多键,使用 strict=False,并显式检查 missing/unexpected keys。
Python
1
2
3
4
state = torch.load("model.pt", map_location="cpu", weights_only=True)
missing, unexpected = model.load_state_dict(state, strict=False)
print("missing:", missing)
print("unexpected:", unexpected)
常用 Module API 速查
API 用途 要点
model.train() / model.eval() 切换训练/评估模式 只影响模块行为;不等价于启用/关闭梯度。
model.parameters() 优化器参数源 多参数组(weight_decay、lr)通常从这里拆分。
model.buffers() 迭代 buffer 排查 batchnorm 统计量、EMA、缓存状态常用。
model.state_dict() 提取可保存状态 推荐保存 state_dict,而不是直接保存整个 Module 对象。
model.load_state_dict(...) 恢复参数与 buffer 返回 missing/unexpected keys;warmstart 必须打印/检查。
数据加载:Dataset / DataLoader

训练吞吐的瓶颈经常不在 GPU,而在数据管线:解码、tokenize、增强、CPU 到 GPU 拷贝、以及 DataLoader 的多进程调度。DataLoader 的可调参数很多,但最关键的是:Dataset 类型(map-style/iterable-style)、worker 并发、pin memory、以及 batch 组装(collate)。

Map-style vs Iterable-style
map_style_dataset.py
Python
1
2
3
4
5
6
7
8
9
10
11
12
from torch.utils.data import Dataset
 
class MyDataset(Dataset):
    def __init__(self, items):
        self.items = items
 
    def __len__(self):
        return len(self.items)
 
    def __getitem__(self, idx):
        x, y = self.items[idx]
        return {"x": x, "y": y}

iterable_dataset.py
Python
1
2
3
4
5
6
7
from torch.utils.data import IterableDataset
 
class StreamDataset(IterableDataset):
    def __iter__(self):
        # 适合流式数据:数据库、消息队列、远端对象存储分片
        for i in range(1000000):
            yield {"x": i}
DataLoader 参数(工程常用)

DataLoader 的构造参数是你调吞吐的第一现场: num_workers 决定 CPU 并发、 pin_memory + non_blocking=True 影响 H2D 拷贝、 prefetch_factor / persistent_workers 影响 worker 生命周期与预取深度。

Python
1
2
3
4
5
6
7
8
9
10
11
12
from torch.utils.data import DataLoader
 
loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    prefetch_factor=2,
    persistent_workers=True,
    drop_last=True,
)

将 batch 移动到 GPU 时,pin memory 结合 non_blocking=True 才能发挥异步拷贝效果。

Python
1
2
def to_device(batch, device):
    return {k: v.to(device, non_blocking=True) for k, v in batch.items()}
分布式训练的数据切分(DistributedSampler)

DDP 下每个进程应读取不同数据子集。map-style 数据集通常配合 DistributedSampler,并在每个 epoch 调用 set_epoch 让 shuffle 可复现。

Python
1
2
3
4
5
6
7
8
9
from torch.utils.data.distributed import DistributedSampler
 
sampler = DistributedSampler(dataset, shuffle=True, drop_last=True)
loader = DataLoader(dataset, batch_size=32, sampler=sampler, num_workers=8, pin_memory=True)
 
for epoch in range(num_epochs):
    sampler.set_epoch(epoch)
    for batch in loader:
        ...
AMP:混合精度的工程写法

混合精度训练通常用 torch.amp.autocast 与 torch.amp.GradScaler 组合。旧的 torch.cuda.amp.autocast / torch.cpu.amp.autocast 已逐步迁移到统一入口。

amp_step.py
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch
 
device = torch.device("cuda", 0)
scaler = torch.amp.GradScaler("cuda")
 
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
 
for batch in loader:
    batch = to_device(batch, device)
 
    optimizer.zero_grad(set_to_none=True)
    with torch.amp.autocast("cuda", dtype=torch.bfloat16):
        loss = model(**batch).loss
 
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
Checkpoint:保存、恢复与安全加载

训练脚本里 checkpoint 的工程目标是两件事:可恢复(resume 时学习率/AMP/随机性都对齐),以及可复用(用于推理或 warmstart)。推荐把 checkpoint 组织成一个 dict:模型 state、优化器 state、调度器 state、AMP scaler state、当前步数/epoch、以及必要的 RNG 状态。

通用 checkpoint 结构
checkpoint_io.py
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import os
import torch
 
def save_checkpoint(path, *, model, optimizer, scheduler=None, scaler=None, step=0, epoch=0):
    ckpt = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict() if scheduler else None,
        "scaler": scaler.state_dict() if scaler else None,
        "step": int(step),
        "epoch": int(epoch),
        "rng_state": torch.get_rng_state(),
        "cuda_rng_state": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
    }
 
    tmp = path + ".tmp"
    torch.save(ckpt, tmp)
    os.replace(tmp, path)  # 原子替换,避免写到一半崩溃留下坏文件
 
def load_checkpoint(path, *, model, optimizer, scheduler=None, scaler=None, map_location="cpu"):
    ckpt = torch.load(path, map_location=map_location, weights_only=True)
    model.load_state_dict(ckpt["model"], strict=True)
    optimizer.load_state_dict(ckpt["optimizer"])
    if scheduler and ckpt.get("scheduler"):
        scheduler.load_state_dict(ckpt["scheduler"])
    if scaler and ckpt.get("scaler"):
        scaler.load_state_dict(ckpt["scaler"])
 
    step = int(ckpt.get("step", 0))
    epoch = int(ckpt.get("epoch", 0))
    return step, epoch
torch.load 的安全边界(weights_only)

torch.load 基于 pickle 反序列化,不能加载不可信来源的文件。加载权重/状态字典时优先使用 weights_only=True,把反序列化限定在 state_dict 等常见安全类型集合内。

torch.compile:编译加速与排错入口

torch.compile 会追踪(trace)你的 Python 代码中的张量计算并生成可优化的图。工程上它的常见收益来自两类:更少的 Python 开销、以及 Inductor 等后端生成的融合 kernel。无法追踪的代码会产生 graph break,这通常是性能损失,而不是静默错误。

compile_minimal.py
Python
1
2
3
4
5
6
7
import torch
 
model = model.to("cuda")
model = torch.compile(model)  # 最小改动:只包一次
 
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
    out = model(x)

当你需要确认 compile 到底 trace 了什么,可以打开日志来观察 traced graph(用于定位 graph break 与非预期的 Python 分支)。

Python
1
2
import torch
torch._logging.set_logs(graph_code=True)
分布式训练:torchrun + DDP 最小可用形态

DDP 的基本形态是“每进程一份模型副本 + 反向时梯度同步”。启动建议使用 torchrun,它会为每个进程设置 RANK/ LOCAL_RANK/ WORLD_SIZE 等环境变量,并负责 rendezvous。

启动命令(单机多卡)
Shell
1
torchrun --standalone --nproc_per_node=8 train_ddp.py --config config.yaml
DDP 训练脚本骨架(可直接复用)
train_ddp.py
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
 
def ddp_setup():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    return local_rank
 
def is_rank0():
    return int(os.environ.get("RANK", "0")) == 0
 
def main():
    local_rank = ddp_setup()
    device = torch.device("cuda", local_rank)
 
    model = MyModel(...).to(device)
    model = DDP(model, device_ids=[local_rank], output_device=local_rank, broadcast_buffers=True)
 
    dataset = MyDataset(...)
    sampler = DistributedSampler(dataset, shuffle=True, drop_last=True)
    loader = DataLoader(dataset, batch_size=32, sampler=sampler, num_workers=8, pin_memory=True)
 
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scaler = torch.amp.GradScaler("cuda")
 
    for epoch in range(10):
        sampler.set_epoch(epoch)
        model.train()
        for batch in loader:
            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
 
            optimizer.zero_grad(set_to_none=True)
            with torch.amp.autocast("cuda", dtype=torch.bfloat16):
                loss = model(**batch).loss
 
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
 
        if is_rank0():
            torch.save(model.module.state_dict(), f"model-ep{epoch}.pt")
 
    dist.destroy_process_group()
 
if __name__ == "__main__":
    main()
脚本组织方式(训练/推理共用)

可维护性来自“把易变部分隔离出来”:模型定义、数据定义、运行时策略(AMP/compile/DDP)、以及 I/O(checkpoint/logging)。一个简单但可扩展的组织方式如下:

1
2
3
4
5
6
7
8
9
project/
  src/
    models.py        # nn.Module 定义与构建函数
    data.py          # Dataset/Tokenizer/Collate
    train_step.py    # 单步训练逻辑(支持 AMP/compile)
    ddp.py           # 分布式初始化与 rank 工具函数
    ckpt.py          # save/load_checkpoint(含 weights_only 策略)
  train.py           # 单机/单卡入口
  train_ddp.py       # torchrun 入口
Transformers 详解

Transformers 在工程上提供了一套可组合的入口:模型与 tokenizer/processor 的加载与保存( from_pretrained/ save_pretrained)、架构无关的 Auto* 工厂、训练循环( Trainer/ TrainingArguments)、以及推理生成( generate)与对话模板(Chat Template)。这一节只讲“如何编程接入与部署落地”,不展开算法原理。

安装与最小依赖
Shell
1
2
3
4
pip install -U transformers
 
# 需要 device_map="auto" / offload / 分布式等能力时通常还需要
pip install -U accelerate
from_pretrained / save_pretrained(装载与交付)

from_pretrained 负责把“模型仓库或本地目录”解析成 Python 对象;save_pretrained 把对象序列化回一个可复用的目录。工程上把这个目录当作“可交付模型包”(artifact),它应当可被推理服务直接加载。

从 Hub 或本地目录加载
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
 
model_id_or_path = "Qwen/Qwen3-0.6B"   # 也可以是 ./models/prod 这类本地目录
 
tok = AutoTokenizer.from_pretrained(model_id_or_path, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
  model_id_or_path,
  torch_dtype="auto",
  device_map="auto",          # 需要 accelerate
)
 
model.eval()
with torch.inference_mode():
  out = model.generate(**tok("Hello", return_tensors="pt").to(model.device), max_new_tokens=32)
  print(tok.decode(out[0], skip_special_tokens=True))
保存到本地目录(模型包)
Python
1
2
3
4
5
6
7
from pathlib import Path
 
out_dir = Path("models/registry/model_v0001")
out_dir.mkdir(parents=True, exist_ok=True)
 
model.save_pretrained(out_dir, safe_serialization=True)  # 推荐 safetensors
tok.save_pretrained(out_dir)
本地权重目录结构(读写约定)

Transformers 的加载逻辑依赖“目录里有哪些标准文件”。同一个目录既可以来自 Hub 下载缓存,也可以来自 save_pretrained 导出。

1
2
3
4
5
6
7
8
9
10
11
model_dir/
  config.json
  generation_config.json                # 可选:生成参数默认值
  model.safetensors                     # 或 pytorch_model.bin
  model.safetensors.index.json          # 可选:分片索引(大模型常见)
  model-00001-of-00002.safetensors      # 可选:分片权重文件
  tokenizer.json                        # fast tokenizer 常见
  tokenizer_config.json
  special_tokens_map.json
  vocab.json / merges.txt               # BPE 类 tokenizer 常见
  spiece.model                          # SentencePiece tokenizer 常见
离线加载与缓存目录

离线/内网环境的最小做法是提前把模型仓库下载到本地目录,然后用本地路径调用 from_pretrained。需要让缓存落到指定盘符时,优先设置 Hugging Face Hub 的缓存环境变量(例如 HF_HUB_CACHE / HF_HOME)。

示例:改变缓存目录
Shell
1
2
export HF_HOME=/data/hf
export HF_HUB_CACHE=/data/hf/hub

只用本地文件(不触网)
Python
1
2
tok = AutoTokenizer.from_pretrained("./model_dir", local_files_only=True)
model = AutoModelForCausalLM.from_pretrained("./model_dir", local_files_only=True)
Auto* 家族(统一入口)

Auto* 是“按配置自动选择具体实现”的工厂。工程上把它当作跨架构的稳定入口:你不需要在代码里硬编码某个模型类名,尤其是在需要频繁替换基座模型时。

Auto 类 典型用途 最小用法(示例)
AutoConfig 读取/改写模型配置(层数、rope、token id 等)
Python
1
2
from transformers import AutoConfig
cfg = AutoConfig.from_pretrained("Qwen/Qwen3-0.6B")
AutoTokenizer 加载 tokenizer(文本 → input_ids/attention_mask)
Python
1
2
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", use_fast=True)
AutoProcessor 多模态 processor(文本+图像/音频等统一预处理)
Python
1
2
from transformers import AutoProcessor
proc = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
AutoModel 只要 backbone 表示(不带任务头)
Python
1
2
from transformers import AutoModel
m = AutoModel.from_pretrained("bert-base-uncased")
AutoModelForCausalLM Decoder-only 生成(LLM 推理/微调)
Python
1
2
from transformers import AutoModelForCausalLM
m = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")
AutoModelForSeq2SeqLM Encoder-Decoder 生成(翻译、摘要等)
Python
1
2
from transformers import AutoModelForSeq2SeqLM
m = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-small")
AutoModelForSequenceClassification 文本分类
Python
1
2
from transformers import AutoModelForSequenceClassification
m = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
AutoModelForTokenClassification 序列标注(NER/词性标注等)
Python
1
2
from transformers import AutoModelForTokenClassification
m = AutoModelForTokenClassification.from_pretrained("bert-base-uncased", num_labels=9)
Tokenizer 与 Processor(输入标准化)

Tokenizer/Processor 把“原始输入”变成模型可消费的张量字典。文本模型通常用 tokenizer;视觉/语音/多模态模型往往用 processor,它内部可能组合 tokenizer + image/audio processor。

Tokenizer 的返回结构
Python
1
2
3
4
5
6
7
8
inputs = tok(
  ["a", "b"],
  padding=True,
  truncation=True,
  max_length=128,
  return_tensors="pt",
)
# inputs 通常包含:input_ids, attention_mask(以及 token_type_ids 等,视模型而定)
Processor 的典型用法(以 CLIP 为例)
Python
1
2
3
4
5
6
7
8
9
10
from PIL import Image
import requests
from transformers import AutoProcessor, AutoModel
 
proc = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = AutoModel.from_pretrained("openai/clip-vit-base-patch32")
 
img = Image.open(requests.get("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/cat.jpg", stream=True).raw)
inputs = proc(text=["a photo of a cat"], images=[img], return_tensors="pt", padding=True)
out = model(**inputs)
Trainer / TrainingArguments(训练循环)

Trainer 把训练循环、评估、保存 checkpoint、日志与分布式协同做成统一入口。工程上最关键的是把 TrainingArguments 固化成可追溯的配置(写入 run_meta.json 或随 checkpoint 一起存档),并严格区分 “best checkpoint” 与 “last checkpoint”。

最小训练骨架(分类任务)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import numpy as np
from datasets import load_dataset
from transformers import (
  AutoTokenizer,
  AutoModelForSequenceClassification,
  DataCollatorWithPadding,
  Trainer,
  TrainingArguments,
)
 
ds = load_dataset("glue", "sst2")
tok = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
 
def tokenize(batch):
  return tok(batch["sentence"], truncation=True)
 
ds = ds.map(tokenize, batched=True)
collator = DataCollatorWithPadding(tokenizer=tok)
 
args = TrainingArguments(
  output_dir="out_sst2",
  per_device_train_batch_size=32,
  per_device_eval_batch_size=64,
  num_train_epochs=1,
  evaluation_strategy="steps",
  eval_steps=200,
  save_strategy="steps",
  save_steps=200,
  save_total_limit=3,
  load_best_model_at_end=True,
  metric_for_best_model="eval_loss",
  greater_is_better=False,
  report_to="none",
)
 
def compute_metrics(eval_pred):
  logits, labels = eval_pred
  preds = np.argmax(logits, axis=-1)
  return {"acc": (preds == labels).mean().item()}
 
trainer = Trainer(
  model=model,
  args=args,
  train_dataset=ds["train"],
  eval_dataset=ds["validation"],
  tokenizer=tok,
  data_collator=collator,
  compute_metrics=compute_metrics,
)
trainer.train()
断点续训与导出
Python
1
2
3
4
5
6
# 断点续训:resume_from_checkpoint 可以传具体 checkpoint 路径
trainer.train(resume_from_checkpoint=True)
 
# 导出最终模型包(建议使用 best checkpoint 对应的权重)
trainer.save_model("models/registry/model_v0001")
tok.save_pretrained("models/registry/model_v0001")
generate 与 GenerationConfig(推理生成)

generate 把“下一 token 分布 → 序列”的解码策略(贪心、beam、采样等)封装成统一入口。工程上建议把生成策略固化为 GenerationConfig(或写入服务端配置),避免在业务代码里散落大量参数。

最小生成示例
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
 
tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", torch_dtype="auto", device_map="auto")
 
gen_cfg = GenerationConfig(
  max_new_tokens=128,
  do_sample=True,
  temperature=0.7,
  top_p=0.9,
)
 
inputs = tok("Explain KV cache in one paragraph.", return_tensors="pt").to(model.device)
with torch.inference_mode():
  out = model.generate(**inputs, generation_config=gen_cfg)
print(tok.decode(out[0], skip_special_tokens=True))
常见参数与含义(部署侧最常用)
参数 作用 工程建议
max_new_tokens 限制生成 token 数 优先用它而不是 max_length(后者包含 prompt token)。
do_sample 采样开关 需要稳定输出时关闭采样,并把 temperature=0 或直接不用 temperature。
temperature / top_p 采样随机性与截断 线上服务通常把它们做成可配置策略,按业务风险控制随机性。
eos_token_id / pad_token_id 结束与 padding 的 token id Decoder-only 模型常需要显式设置 pad_token(一般等于 eos_token)。
Chat Template(对话模板接入)

Chat Template 把“messages 列表”转换成模型需要的 prompt 格式,并确保 role token、分隔符与结束符一致。工程上建议统一通过 apply_chat_template 生成输入,避免手工拼 prompt 导致格式漂移。

apply_chat_template + generate(最小骨架)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
 
tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", torch_dtype="auto", device_map="auto")
 
messages = [
  {"role": "system", "content": "You are a precise assistant."},
  {"role": "user", "content": "Summarize what gradient accumulation is."},
]
 
input_ids = tok.apply_chat_template(
  messages,
  add_generation_prompt=True,
  return_tensors="pt",
).to(model.device)
 
with torch.inference_mode():
  out = model.generate(input_ids, max_new_tokens=128)
print(tok.decode(out[0], skip_special_tokens=True))
训练数据的 chat template 对齐

如果模型的 tokenizer 自带 chat template,SFT 数据建议按同一模板构造训练样本;否则推理时的对话格式会与训练时不一致,表现为“角色混淆”“结束符异常”“输出风格漂移”。

device_map / torch_dtype(加载策略与显存治理)

大模型加载的关键旋钮是:把权重放在哪(GPU/CPU/磁盘)与用什么 dtype(fp32/fp16/bf16)。 device_map="auto" 会尝试把层自动分配到设备上,通常需要安装 accelerate; torch_dtype="auto" 会按权重与硬件能力选择合适 dtype。

Python
1
2
3
4
5
model = AutoModelForCausalLM.from_pretrained(
  "./model_dir",
  torch_dtype="auto",
  device_map="auto",
)
常见坑(高频报错与修复动作)
现象 根因 修复动作
ImportError: Using device_map requires Accelerate 启用了 device_map,但环境缺少 accelerate 安装 pip install -U accelerate,或移除 device_map 并手动 model.to(device)。
Decoder-only 推理报 padding 相关错误 tokenizer 没有 pad_token tok.pad_token = tok.eos_token,并设置 pad_token_id。
推理输出乱码或 EOS 提前结束 tokenizer 与模型不匹配,或 chat template 不一致 确保 tokenizer 与模型来自同一目录;推理统一用 apply_chat_template。
本地目录加载失败(找不到 config/tokenizer) 目录不是标准模型包结构 用 model.save_pretrained 与 tok.save_pretrained 导出;检查是否存在 config.json。
OOM 或极慢 dtype/设备放置策略不合理 优先 torch_dtype="auto" + device_map="auto";必要时启用更强的推理引擎(vLLM/TensorRT-LLM)。
加载第三方模型需要 trust_remote_code=True 模型仓库包含自定义 Python 代码 在受控环境中审计代码后再开启;离线导出时固定 commit hash,避免代码漂移。
PEFT 与微调技术详解

参数高效微调(Parameter-Efficient Fine-Tuning, PEFT)的工程核心是把“可训练参数”从完整模型权重中剥离出来:训练与分发只关心适配器(adapter)的小文件,线上推理再把适配器挂到同一个 base checkpoint 上复用。这样既减少训练时的显存与优化器状态开销,也把多任务/多域适配的存储成本压到可控范围。

安装与版本对齐

PEFT、Transformers、TRL 与 bitsandbytes 在接口上是强耦合组合。工程上以“同一套 requirements 锁定版本”作为默认策略,避免出现 PEFT 与 Transformers 的适配器注入逻辑不一致、或 TRL 的 Trainer 参数签名变化导致脚本失效。

Shell
1
2
3
4
5
# 训练/微调常见最小集合
pip install -U transformers accelerate datasets peft trl safetensors
 
# QLoRA / 4bit 量化微调需要
pip install -U bitsandbytes
PEFT 的对象模型:base 与 adapter

PEFT 的存储与加载分两层:

  • base:Transformers 模型原始 checkpoint(通常很大、可复用、版本需固定)。
  • adapter:PEFT 生成的小文件(含 adapter_config 与 adapter weights),可多份并存,用于不同任务/域。

标准做法是:训练输出目录只保存 adapter;上线推理时先加载 base,再加载 adapter。这样 adapter 目录可被当作“制品”(artifact)管理,支持灰度、回滚与多 adapter 切换。

核心 API 速查
对象 / 函数 用途 典型用法
LoraConfig LoRA/QLoRA 的配置对象
Python
1
2
3
4
5
6
7
8
from peft import LoraConfig, TaskType
cfg = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=8,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
get_peft_model 把 base model 包装成可训练的 PeftModel
Python
1
2
from peft import get_peft_model
peft_model = get_peft_model(base_model, cfg)
PeftModel.from_pretrained 给已加载的 base model 挂载某个 adapter
Python
1
2
from peft import PeftModel
model = PeftModel.from_pretrained(base_model, "adapter_dir")
model.save_pretrained 保存 adapter(不覆盖 base)
Python
1
model.save_pretrained("adapter_out")
model.merge_and_unload 把 adapter 合并进 base 权重并卸载 adapter(用于导出单体权重)
Python
1
merged = model.merge_and_unload()
model.print_trainable_parameters 自检:确认“只训练 adapter”而不是误训全参
Python
1
model.print_trainable_parameters()
LoRA:PEFT 的主力路径

LoRA(Low-Rank Adaptation)通过对线性层权重施加低秩增量,让可训练参数规模与显存开销显著下降。实际工程难点集中在两处:target_modules 怎么选,以及如何保存/加载/合并。

target_modules:如何定位注入点

target_modules 是 LoRA 注入的“模块名匹配规则”,用于指定 base 模型里哪些子模块会被替换/包裹。这一选择与“按层类型挑选线性层”不同:它依赖模型内部的命名约定。不同架构的模块命名差异很大(Llama 系列常见 q_proj/k_proj/v_proj/o_proj,也有模型用 Wq/Wk/Wv/Wo 或把投影层藏在自定义块里)。稳定做法是先枚举可疑线性层,再按名字筛选。

Python
1
2
3
4
5
6
7
8
9
10
11
12
import torch
 
def list_linear_module_names(model: torch.nn.Module):
    names = []
    for name, m in model.named_modules():
        if isinstance(m, torch.nn.Linear):
            names.append(name)
    return names
 
# 调试:先看前几十个,确定投影层的命名风格
for n in list_linear_module_names(model)[:50]:
    print(n)

LoRA 常见的注入策略是“注意力投影层优先”:先只覆盖 attention 的 Q/K/V/O,再根据效果与显存预算扩展到 FFN 的投影层(例如 gate/up/down)。

modules_to_save:适配器之外还要训练哪些层

adapter 之外偶尔需要训练额外模块(例如分类头、语言模型的 lm_head、或新加 token 的 embedding)。这类模块可以通过 modules_to_save 声明为可训练并随 adapter 一起保存,避免“训练时更新了头部,但保存的 adapter 不包含它”的上线故障。

LoRA 最小训练骨架(Transformers Trainer)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from peft import LoraConfig, TaskType, get_peft_model
 
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
base = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
 
cfg = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=8,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
model = get_peft_model(base, cfg)
model.print_trainable_parameters()
 
ds = load_dataset("trl-lib/Capybara", split="train")
 
def tokenize(example):
    # 这里只示意:真实项目通常先把 messages/prompt-completion 统一成文本,再 tokenize
    text = example["text"] if "text" in example else str(example)
    return tok(text, truncation=True, max_length=1024)
 
ds = ds.map(tokenize, remove_columns=ds.column_names)
 
args = TrainingArguments(
    output_dir="out_lora_adapter",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=2e-4,
    num_train_epochs=1,
    fp16=True,
    logging_steps=10,
    save_steps=200,
)
 
trainer = Trainer(model=model, args=args, train_dataset=ds)
trainer.train()
 
# 只保存 adapter(推荐默认产物)
model.save_pretrained("out_lora_adapter")
保存、加载与多 adapter 切换

PEFT 支持一个 base 上挂多份 adapter,并在推理时切换 active adapter。这种能力适合“同一底座,多业务域”的线上形态。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from transformers import AutoModelForCausalLM
from peft import PeftModel
 
base = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
 
# 加载第一个 adapter
model = PeftModel.from_pretrained(base, "adapter_a", adapter_name="a")
model.set_adapter("a")
 
# 再加载第二个 adapter(同一个 base 上叠加管理)
model.load_adapter("adapter_b", adapter_name="b")
 
# 推理时切换
model.set_adapter("b")
merge:合并成单体权重(用于导出/部署)

merge_and_unload() 用于把 LoRA 权重写回 base 权重,再移除 adapter 结构,得到“标准 Transformers 模型结构”。工程上这通常用于:

  • 导出到单体 checkpoint(例如给不支持 adapter 的推理引擎)。
  • 降低线上复杂度(不需要两段加载与 adapter 切换)。

合并前需要确认 base 权重处于可写入的浮点 dtype(fp16/bf16/fp32)。对 4-bit 量化权重,合并通常不作为默认路径。

Python
1
2
3
4
5
6
7
from transformers import AutoModelForCausalLM
from peft import PeftModel
 
base = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="float16", device_map="cpu")
model = PeftModel.from_pretrained(base, "adapter_out")
merged = model.merge_and_unload()
merged.save_pretrained("out_merged_full")
QLoRA:量化权重 + LoRA 的组合

QLoRA 的工程语义是:base 权重以 4-bit 量化形式加载并冻结,反向传播只更新 LoRA 参数;量化层负责把反向信号“传递”到 LoRA,而不是更新量化权重本身。

QLoRA 关键 API(BitsAndBytesConfig + prepare_model_for_kbit_training)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
 
model_id = "mistralai/Mistral-7B-v0.1"
 
bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
 
tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
base = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_cfg,
    device_map="auto",
)
 
# 让量化模型进入可训练形态(冻结 base、处理 dtype/层归一化等)
base = prepare_model_for_kbit_training(base)
 
cfg = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=8,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
model = get_peft_model(base, cfg)
model.print_trainable_parameters()
量化 + adapter 的工程约束
  • 合并策略:4-bit 量化权重通常不作为合并目标。需要单体权重时,常见流程是“重新加载 fp16/bf16 base → 挂载 adapter → merge → 导出”。
  • 训练开关:Decoder-only 模型训练时常需要关闭 use_cache,并配合 gradient checkpointing 控制显存。
  • 部署形态:adapter 目录天然适合做制品;量化 base 属于环境相关资产(与推理后端、算子实现与硬件强相关),需要独立版本管理。
IA3:更轻量的向量型适配器

IA3(Infused Adapter by Inhibiting and Amplifying Inner Activations)通过在注意力与前馈模块中注入少量可训练向量来缩放激活值。它的可训练参数通常比 LoRA 更少,适合“极低成本的快速适配”场景。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from peft import IA3Config, TaskType, get_peft_model
 
model_id = "bigscience/mt0-large"
tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
base = AutoModelForSeq2SeqLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
 
cfg = IA3Config(
    task_type=TaskType.SEQ_2_SEQ_LM,
    target_modules=["k", "v", "wo"],  # 示例:以实际模型命名为准
)
model = get_peft_model(base, cfg)
model.print_trainable_parameters()
Prompt Tuning:软提示词(soft prompt)

Prompt tuning 把“要学习的东西”压缩为一段可训练的虚拟 token embedding(virtual tokens),base 权重保持冻结。它的工程接口与 LoRA 不同:训练的是提示向量,不是注入线性层的低秩权重。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import (
    PromptTuningConfig,
    PromptTuningInit,
    TaskType,
    get_peft_model,
)
 
model_id = "bigscience/bloomz-560m"
tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
base = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
 
cfg = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    prompt_tuning_init=PromptTuningInit.TEXT,
    prompt_tuning_init_text="Classify if the tweet is a complaint or not:",
    num_virtual_tokens=8,
    tokenizer_name_or_path=model_id,
)
 
model = get_peft_model(base, cfg)
model.print_trainable_parameters()

Prompt tuning 的数据组织更接近“提示 + 输入 + 输出”的模板化任务。工程上需要明确:提示文本、虚拟 token 数量、以及 tokenizer 的特殊 token 处理方式,三者必须一致,否则会出现训练可收敛但推理表现异常的对齐问题。

与 Transformers / TRL 的集成方式
Transformers:PeftAdapterMixin

Transformers 在模型类上集成了适配器管理接口,典型能力包括 add_adapter、 load_adapter、 set_adapter 与适配器保存。LoRA/IA3/AdaLoRA 属于常见的“直接集成”方法;prompt tuning 等提示类方法通常直接用 PEFT 库完成更稳定。

TRL:把 PEFT 当作模型构造步骤

TRL 的 Trainer(SFT/DPO/GRPO/PPO 等)工程上更适合把 PEFT 当作“模型构造步骤”:先用 Transformers 加载 base,再用 PEFT 包一层 adapter,把得到的 PeftModel 直接传给 TRL Trainer。这样可以绕开不同 TRL 版本对 peft_config 参数支持度不一致的问题。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
 
train_ds = load_dataset("trl-lib/Capybara", split="train")
 
cfg = SFTConfig(
    output_dir="out_trl_lora",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=2e-4,
    logging_steps=10,
    save_steps=200,
    num_train_epochs=1,
)
 
trainer = SFTTrainer(
    model=model,       # 这里直接传 PEFT 包装后的 model
    tokenizer=tok,
    train_dataset=train_ds,
    args=cfg,
)
trainer.train()
model.save_pretrained("out_trl_lora_adapter")
DeepSpeed 详解

DeepSpeed 的工程接口由两部分构成:一部分是 launcher + 分布式初始化,负责把“单机脚本”变成“多进程多卡训练”;另一部分是 DeepSpeedEngine + JSON 配置,负责把显存分片(ZeRO)、offload、混合精度、梯度累积、checkpoint 等能力落在可复现的配置上。本节围绕安装、启动、 deepspeed.initialize、 ds_config.json、ZeRO Stage 1/2/3、offload、checkpoint,以及与 Transformers/Accelerate 的集成路径展开,重点给出配置与代码的对应关系。

安装与环境验证
基础安装

DeepSpeed 的最小安装路径是先安装 PyTorch,再安装 DeepSpeed。DeepSpeed 包含若干 C++/CUDA 扩展(ops),默认采用 JIT 方式在运行期编译加载,因此环境里通常需要可用的编译链与 ninja。

Shell
1
2
3
4
pip install deepspeed
 
# 可选:Transformers 侧一次性装好集成依赖
pip install "transformers[deepspeed]"
环境报告(ds_report)

安装完成后优先跑环境报告,确认“哪些 ops 可用、哪些会在运行时编译、CUDA/通信栈是否匹配”。这个步骤在排查安装或性能差异时比直接跑训练更高效。

Shell
1
2
3
4
ds_report
 
# 等价入口
python -m deepspeed.env_report
预编译 ops(可选)

默认 JIT 编译适合研发迭代;在固定镜像或需要减少“首次运行抖动”的场景里,可以在安装期预编译部分或全部 ops。DeepSpeed 提供一组 DS_BUILD_* 环境变量控制构建范围。

Shell
1
2
3
4
5
# 尝试构建所有 ops(只会构建与当前机器兼容的部分)
DS_BUILD_OPS=1 pip install deepspeed
 
# 只构建某一类 op(示例:FusedLamb)
DS_BUILD_FUSED_LAMB=1 pip install deepspeed

预编译全部 ops 可能耗时较长,可通过并行编译加速:

Shell
1
DS_BUILD_OPS=1 pip install deepspeed --global-option="build_ext" --global-option="-j8"
Launcher 与进程启动
单机多卡

DeepSpeed launcher 的默认约定是“一进程一 GPU”。launcher 会为脚本注入 --local_rank,脚本侧需要能解析这个参数并把当前进程绑定到对应 GPU。

Shell
1
2
# 单机 8 卡
deepspeed --num_gpus=8 train.py --deepspeed --deepspeed_config ds_config.json
多机

多机训练通常由 launcher 读取 hostfile(节点列表与每节点 slots),并在每个节点上拉起相同脚本。hostfile 格式依赖部署系统(裸机/Slurm/K8s),工程上常见做法是先让调度系统分配机器与 GPU,再由 DeepSpeed 或 torchrun 建立通信。

hostfile(示例)
1
2
node0 slots=8
node1 slots=8

Shell
1
deepspeed --hostfile=hostfile train.py --deepspeed --deepspeed_config ds_config.json
脚本侧参数解析

DeepSpeed 提供 deepspeed.add_config_arguments 把 --deepspeed 与 --deepspeed_config 等参数接入到自定义 argparse 中。

Python
1
2
3
4
5
6
7
import argparse
import deepspeed
 
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=-1)
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args()
deepspeed.initialize 与训练循环
最小接入骨架

deepspeed.initialize 是训练入口:负责(必要时)初始化 torch distributed,并返回一个可直接用于 forward/backward/step 的 DeepSpeedEngine。配置文件里的 optimizer/scheduler/dataloader 也可以被 DeepSpeed 构造与管理。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import argparse
import torch
import deepspeed
 
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=-1)
    parser = deepspeed.add_config_arguments(parser)
    args = parser.parse_args()
 
    model = MyModel()
 
    # optimizer 既可以在 ds_config.json 里声明,也可以由代码创建并传入
    model_engine, optimizer, _, lr_scheduler = deepspeed.initialize(
        args=args,
        model=model,
        model_parameters=model.parameters(),
    )
 
    for batch in train_loader:
        loss = model_engine(batch)
        model_engine.backward(loss)
        model_engine.step()
 
if __name__ == "__main__":
    main()
分布式初始化的边界

当脚本里已经显式调用了 torch.distributed.init_process_group,DeepSpeed 侧应改为 deepspeed.init_distributed 或直接移除显式初始化,让 deepspeed.initialize 自动完成分布式初始化。多重初始化是常见的 hang 根源。

配置与代码的覆盖规则

DeepSpeed 的核心覆盖规则是:配置文件定义默认行为,显式传入的 Python 对象覆盖配置。例如,当在 deepspeed.initialize 里传入 optimizer 时,会覆盖 ds_config.json 里 optimizer 段落的定义。

ds_config.json:把“训练策略”编码进配置
批大小三件套与推导关系

DeepSpeed 将 batch size 拆为三项参数:有效 batch( train_batch_size)、每卡 micro-batch( train_micro_batch_size_per_gpu)、梯度累积步数( gradient_accumulation_steps)。三者满足:

\[B = b \\times g \\times N\]

其中 \(B\) 对应 train_batch_size,\(b\) 对应 train_micro_batch_size_per_gpu,\(g\) 对应 gradient_accumulation_steps,\(N\) 是参与训练的 GPU 数量(即 world size)。

工程上通常只显式指定其中两个,剩下一个由 DeepSpeed 推导;这样可以减少多机扩容时的人工改动。

一份可跑的最小配置
ds_config_min.json
JSON
1
2
3
4
5
6
{
  "train_micro_batch_size_per_gpu": 1,
  "gradient_accumulation_steps": 8,
  "fp16": { "enabled": true },
  "zero_optimization": { "stage": 2 }
}
配置-代码对照表
配置项 DeepSpeed 行为 代码侧需要做什么
train_micro_batch_size_per_gpu 定义每次 forward/backward 的 micro-batch 大小 DataLoader 提供的 batch 必须与该值一致,或让上层框架(Trainer/Accelerate)保持一致
gradient_accumulation_steps 定义多少个 micro-step 后做一次参数更新 训练循环仍按 “每个 batch 一次 backward”,DeepSpeedEngine 内部按配置决定何时 step
fp16.enabled / bf16.enabled 启用混合精度与 loss scaling(若需要) 脚本不再手写 AMP 也能跑通;若与外部 AMP 同时启用,需明确由谁负责 autocast/scaler
optimizer DeepSpeed 构造优化器(可选) 如果在 deepspeed.initialize 显式传入 optimizer,则会覆盖该段配置
scheduler DeepSpeed 构造并在每步自动 step(可选) 当 scheduler 由 DeepSpeed 管理时,脚本不应额外调用 scheduler.step()
ZeRO:Stage 1/2/3 与 Offload
Stage 1:分片优化器状态

ZeRO Stage 1 将优化器状态(例如 Adam 的一阶/二阶动量与 FP32 master 权重)在 data-parallel ranks 之间分片,降低“优化器状态显存/内存”的重复开销。对模型参数量中等但 optimizer state 占用成为瓶颈的训练很直接。

ds_zero_stage1.json
JSON
1
2
3
4
5
6
{
  "train_micro_batch_size_per_gpu": 1,
  "gradient_accumulation_steps": 8,
  "bf16": { "enabled": true },
  "zero_optimization": { "stage": 1 }
}
Stage 2:再分片梯度

ZeRO Stage 2 在 Stage 1 的基础上将梯度也分片。它通常在“模型能放下,但训练状态占用过高”或“希望进一步扩大 batch/seq_len”时成为默认选择;相较 Stage 3,它在通信与实现复杂度上更温和。

ds_zero_stage2.json
JSON
1
2
3
4
5
6
7
8
9
10
{
  "train_micro_batch_size_per_gpu": 1,
  "gradient_accumulation_steps": 8,
  "fp16": { "enabled": true },
  "zero_optimization": {
    "stage": 2,
    "overlap_comm": true,
    "contiguous_gradients": true
  }
}
Stage 3:再分片参数(最省显存)

ZeRO Stage 3 进一步把模型参数也分片,使得“单卡显存”主要由激活与少量 shard 状态构成,从而把可训练模型尺度推到显存上限之外。代价是更多通信与参数聚合/分片的复杂性,checkpoint 与恢复也会更敏感。

ds_zero_stage3.json
JSON
1
2
3
4
5
6
7
8
9
10
{
  "train_micro_batch_size_per_gpu": 1,
  "gradient_accumulation_steps": 8,
  "bf16": { "enabled": true },
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "contiguous_gradients": true
  }
}
Offload:把状态搬到 CPU/NVMe

offload 的目标是继续压缩 GPU 显存占用。常见组合是:ZeRO-2 offload optimizer states(CPU)与 ZeRO-3 offload params/optimizer(CPU 或 NVMe)。offload 会把瓶颈从显存转移到带宽与延迟,因此通常需要配合更细的 micro-batch、更高的梯度累积与更强的通信/计算重叠。

ds_zero3_cpu_offload.json
JSON
1
2
3
4
5
6
7
8
9
10
{
  "train_micro_batch_size_per_gpu": 1,
  "gradient_accumulation_steps": 8,
  "fp16": { "enabled": true },
  "zero_optimization": {
    "stage": 3,
    "offload_param": { "device": "cpu", "pin_memory": true },
    "offload_optimizer": { "device": "cpu", "pin_memory": true }
  }
}
Checkpoint:保存、恢复与导出
Engine 级保存/恢复

DeepSpeedEngine 提供 save_checkpoint / load_checkpoint,用于保存与恢复模型、优化器、scheduler 以及自定义 client_state。工程要点是:所有 ranks 都必须调用 save_checkpoint,否则会在同步点 hang。ZeRO-3 下,保存后立刻在同一 engine 上 load(不重新初始化)是已知的不兼容用法。

Python
1
2
3
4
5
6
# 保存(所有进程都会参与)
model_engine.save_checkpoint("ckpt_dir", tag=f"global_step{global_step}", client_state={"step": global_step})
 
# 恢复(通常在初始化后尽早执行)
load_path, client_state = model_engine.load_checkpoint("ckpt_dir", tag=None)
global_step = client_state.get("step", 0)
ZeRO checkpoint 权重导出(fp32 合并)

ZeRO-2/3 的 checkpoint 是分片形态。需要“脱离 DeepSpeed 继续使用/分享权重”时,常用做法是把 ZeRO checkpoint 转为合并后的 fp32 state_dict。

Python
1
2
3
4
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
 
fp32_state_dict = get_fp32_state_dict_from_zero_checkpoint("ckpt_dir", tag=None)
torch.save(fp32_state_dict, "pytorch_model_fp32.bin")
与 Transformers 的集成
Trainer / TrainingArguments 接入

Transformers 的 Trainer 通过 TrainingArguments.deepspeed(或 CLI 的 --deepspeed)接入 DeepSpeed。工程上更稳定的做法是:把与 Trainer 重复的值在 ds_config 中写成 "auto",由 Trainer 统一灌入,避免“两边都写但不一致”。

ds_hf_auto.json
JSON
1
2
3
4
5
6
7
8
9
10
{
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto",
  "optimizer": {
    "type": "AdamW",
    "params": { "lr": "auto" }
  },
  "fp16": { "enabled": "auto" },
  "zero_optimization": { "stage": 2 }
}

Python
1
2
3
4
5
6
7
8
9
10
from transformers import TrainingArguments
 
args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=2e-5,
    fp16=True,
    deepspeed="ds_hf_auto.json",
)
与 Accelerate 的集成
用 DeepSpeedPlugin(代码内指定)

Accelerate 提供 DeepSpeedPlugin,把 ZeRO stage、梯度累积等关键项绑定到 Accelerator 的生命周期里。工程要点是:DeepSpeed 需要提前知道 gradient_accumulation_steps,因此插件与训练循环要对齐,梯度累积本身仍需要按常规方式在代码里实现。

Python
1
2
3
4
from accelerate import Accelerator, DeepSpeedPlugin
 
ds_plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=2)
accelerator = Accelerator(deepspeed_plugin=ds_plugin)
用 DeepSpeed 配置文件(更强的可控性)

当需要 ZeRO-3/offload/更细的 ZeRO knobs 时,Accelerate 通常通过“指定 DeepSpeed 配置文件”的方式接入。此时 ds_config.json 才是事实来源,代码侧只保留必要的 accelerator.prepare 与 accelerator.backward 语义,避免重复配置。

Shell
1
2
3
# 通过 accelerate config 生成运行配置后,再用 accelerate launch 运行训练脚本
accelerate config
accelerate launch train.py
vLLM 详解

vLLM 是面向服务化推理的运行时:围绕高吞吐调度、KV cache 管理、continuous batching、分布式并行与 OpenAI-compatible API 提供一体化推理栈。工程落地时可以把 vLLM 当作三条“入口路径”:离线推理的 LLM,可嵌入自建服务的 Engine,以及直接上线的 vllm serve。

安装路径与环境兼容

vLLM 的 wheel 包含大量编译好的 C++/GPU kernels。性能与兼容性高度依赖“vLLM wheel、PyTorch、驱动/运行时”三者的组合,工程上优先使用官方提供的预构建 wheel 或官方 Docker 镜像。

GPU 安装(推荐路径)

官方文档建议在新环境中安装 vLLM,并优先使用 wheel 自带的 PyTorch/依赖组合以减少二进制不兼容问题;此外,conda 安装的 PyTorch 可能静态链接 NCCL,容易在分布式/多进程场景引发问题。

Install vLLM (GPU): create env & install (pattern)
Shell
1
2
3
4
5
6
# create a clean env (example with uv)
uv venv --python 3.12 --seed
source .venv/bin/activate
 
# install vLLM
uv pip install vllm

基本自检可以用“导入 + 小模型离线生成”验证:

Sanity check: offline generate
Python
1
2
3
4
5
6
from vllm import LLM, SamplingParams
 
llm = LLM(model="facebook/opt-125m", enforce_eager=True)
params = SamplingParams(max_tokens=16, temperature=0.0)
out = llm.generate(["Hello, my name is"], params)
print(out[0].outputs[0].text)
Docker 安装(生产最常用)

生产系统更常直接使用官方镜像运行 OpenAI-compatible server。多进程与张量并行依赖共享内存,容器启动通常需要 --ipc=host 或显式配置 --shm-size。

vLLM official Docker image: run OpenAI-compatible server
Shell
1
2
3
4
5
6
7
8
9
export HF_TOKEN="<secret>"
 
docker run --gpus all \\
  -v ~/.cache/huggingface:/root/.cache/huggingface \\
  --env "HF_TOKEN=$HF_TOKEN" \\
  -p 8000:8000 \\
  --ipc=host \\
  vllm/vllm-openai:latest \\
  --model Qwen/Qwen3-0.6B

当宿主机驱动较旧时,官方镜像提供 CUDA compatibility 模式(只覆盖部分专业/数据中心 GPU 的兼容场景):

Docker: enable CUDA compatibility libraries (pattern)
Shell
1
2
3
4
5
6
7
8
docker run --gpus all \\
  -v ~/.cache/huggingface:/root/.cache/huggingface \\
  -p 8000:8000 \\
  --env "HF_TOKEN=$HF_TOKEN" \\
  --env "VLLM_ENABLE_CUDA_COMPATIBILITY=1" \\
  --ipc=host \\
  vllm/vllm-openai:latest \\
  --model Qwen/Qwen3-0.6B
何时需要从源码构建

当你的 CUDA/ROCm 版本、PyTorch 构建配置或硬件平台与官方 wheel 不匹配时,需要从源码构建。官方提供了以 VLLM_USE_PRECOMPILED=1 作为起点的可编辑安装,以及基于 CMake 的增量编译工作流用于迭代 kernels。

Build-from-source (editable) + incremental build toolchain (pattern)
Shell
1
2
3
4
5
6
7
8
9
10
11
git clone https://github.com/vllm-project/vllm.git
cd vllm
 
uv venv --python 3.12 --seed
source .venv/bin/activate
 
# editable install (use precompiled wheels where possible to speed up)
VLLM_USE_PRECOMPILED=1 uv pip install -U -e . --torch-backend=auto
 
# build toolchain for incremental compilation
uv pip install -r requirements/build.txt --torch-backend=auto
三条接口:LLM / Engine / vllm serve

vLLM 的 API 结构可以用“离线批处理推理”“可嵌入的引擎”“生产服务端”三条路径来理解。三者底层共享同一套 Engine 配置(EngineArgs),差异在于请求进入方式与生命周期管理。

接口 1:LLM(离线批处理推理)

vllm.LLM 适合离线批处理与数据集推理。它接受 prompts 列表并返回结构化输出,常用于离线评测、数据合成与批量生成。

LLM: batched offline inference
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from vllm import LLM, SamplingParams
 
prompts = [
    "Hello, my name is",
    "The capital of France is",
]
params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)
 
llm = LLM(model="facebook/opt-125m")
outputs = llm.generate(prompts, params)
 
for o in outputs:
    print(o.prompt)
    print(o.outputs[0].text)

采样参数的默认来源有两套:模型仓库里的 generation_config 与 vLLM 自己的默认值。若业务希望显式使用 vLLM 的默认采样参数,可以在创建 LLM 时设置:

LLM: control generation_config source (pattern)
Python
1
2
3
4
5
6
from vllm import LLM
 
llm = LLM(
    model="facebook/opt-125m",
    generation_config="vllm",
)
接口 2:Engine(嵌入式引擎,用于自建服务)

Engine 路线用于把 vLLM 嵌入到自建服务/作业系统中,获得“请求级 streaming + 细粒度生命周期控制”。当前主线接口是 V1 Engine( AsyncLLM),通过 AsyncEngineArgs 构建。

Engine: AsyncLLM streaming generate (pattern)
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import asyncio
 
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM
 
 
async def main() -> None:
    engine_args = AsyncEngineArgs(
        model="meta-llama/Llama-3.2-1B-Instruct",
        enforce_eager=True,  # examples: faster startup, lower peak perf
    )
    engine = AsyncLLM.from_engine_args(engine_args)
    try:
        params = SamplingParams(
            max_tokens=64,
            temperature=0.2,
            output_kind=RequestOutputKind.DELTA,  # only new tokens each iteration
        )
        async for out in engine.generate(
            request_id="req-1",
            prompt="Write a haiku about caching.",
            sampling_params=params,
        ):
            for c in out.outputs:
                if c.text:
                    print(c.text, end="", flush=True)
            if out.finished:
                break
    finally:
        engine.shutdown()
 
 
if __name__ == "__main__":
    asyncio.run(main())

在 Engine 路线里,“并发/显存预算”通常通过 EngineArgs 控制,应用侧需要自行处理:请求队列、超时/取消(abort)、重试、以及与外部网关的对接。

接口 3:vllm serve(生产服务端)

vllm serve 直接启动 OpenAI-compatible server,是最接近“拿来就用”的生产入口。典型端点包括 /v1/chat/completions 与 /v1/embeddings,并支持流式输出(SSE)。

vllm serve: start OpenAI-compatible server
Shell
1
2
3
4
5
vllm serve meta-llama/Meta-Llama-3-8B-Instruct \\
  --host 0.0.0.0 \\
  --port 8000 \\
  --dtype auto \\
  --api-key token-abc123

OpenAI-compatible request (generic curl)
Shell
1
2
3
4
5
6
7
8
curl http://localhost:8000/v1/chat/completions \\
  -H 'Content-Type: application/json' \\
  -H 'Authorization: Bearer token-abc123' \\
  -d '{
    "model": "meta-llama/Meta-Llama-3-8B-Instruct",
    "messages": [{"role":"user","content":"Hello!"}],
    "stream": true
  }'
OpenAI Python SDK(base_url 与 extra_body)

应用侧最常见的接入方式是 OpenAI Python SDK,把 base_url 指向自托管服务。部分参数在 OpenAI API 中不存在,但 vLLM 支持;这类扩展字段通常通过 extra_body 传入。

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
vLLM (pattern)">
from openai import OpenAI
 
client = OpenAI(base_url="http://localhost:8000/v1", api_key="token-abc123")
 
resp = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    messages=[{"role": "user", "content": "Hello!"}],
    temperature=0.2,
    max_tokens=128,
    extra_body={"top_k": 50},
)
print(resp.choices[0].message)
服务端配置文件(YAML)

vllm serve 支持从 YAML 配置文件加载参数。参数名使用长参数形式(long form)。CLI 与配置文件同时提供时,优先级为 CLI > config > defaults。

vLLM serve config.yaml (pattern)
YAML
1
2
3
4
5
6
7
8
9
10
model: meta-llama/Llama-3.1-8B-Instruct
host: "0.0.0.0"
port: 8000
uvicorn-log-level: "info"
api-key: "token-abc123"
dtype: "auto"
max-model-len: 8192
gpu-memory-utilization: 0.90
max-num-seqs: 64
enable-prefix-caching: true

vLLM serve: launch with config
Shell
1
vllm serve --config config.yaml
EngineArgs(核心配置面)

Engine arguments 控制 vLLM 的运行行为:离线推理时它们是 LLM(...) 的一部分参数;在线服务时它们是 vllm serve 的参数子集。工程上可以把 EngineArgs 按职责分为四类:模型与 tokenizer、并行与执行器、KV cache 与调度、以及安全/可观测性。

常用 EngineArgs(服务端视角)
参数 含义 工程后果
--max-model-len 最大上下文长度 直接决定 KV cache 的 token 预算;过大常导致并发下降或 OOM
--gpu-memory-utilization 显存预算比例 留出系统/碎片空间;过高会提升 OOM 风险
--max-num-batched-tokens 每步调度的 token 预算 影响吞吐与尾延迟,常与并发/显存一起调
--max-num-seqs 并发序列数上限 决定同卡并发,过高会导致排队抖动与尾延迟上升
--kv-cache-dtype KV cache 存储精度 影响显存/带宽;更激进精度需要评估质量与稳定性
--trust-remote-code 允许加载模型仓库的自定义代码 改变执行边界;仅在可信模型源启用
--download-dir 权重/缓存下载目录 容器化时用于挂载共享缓存,减少冷启动成本
EngineArgs 在 Python 侧的用法

当你希望把“服务端配置”复用到离线任务中,可以在 Python 里用 EngineArgs 组装配置,再把字段展开给 LLM:

Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
LLM (pattern)">
from dataclasses import asdict
 
from vllm import LLM
from vllm.engine.arg_utils import EngineArgs
 
engine_args = EngineArgs(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    max_model_len=8192,
    gpu_memory_utilization=0.90,
    enable_prefix_caching=True,
)
 
llm = LLM(**asdict(engine_args))
并行:Tensor Parallel / Pipeline Parallel / Data Parallel

分布式推理的目标是“把单模型副本放进足够多的 GPU 里,并把负载分摊出去”。vLLM 的并行策略可以分为三类:单副本的张量并行/流水并行,以及多副本的 data parallel(权重复制)。

Tensor Parallel(单机多卡)

当模型无法放进单卡但能放进单机多卡时,设置 tensor_parallel_size 为“每节点 GPU 数”是最常见策略。

vllm serve: tensor parallel (single node)
Shell
1
2
3
vllm serve meta-llama/Meta-Llama-3-8B-Instruct \\
  --tensor-parallel-size 4 \\
  --host 0.0.0.0 --port 8000
Pipeline Parallel(多机或单机不均匀切分)

当模型超过单机容量,需要组合张量并行与流水并行:把 tensor_parallel_size 设为“每节点 GPU 数”,把 pipeline_parallel_size 设为“节点数”。如果模型在单机可容纳,但 GPU 数无法均匀切分模型,也可以用 pipeline parallel 做不均匀切分:此时常见设置是 tensor_parallel_size=1, pipeline_parallel_size=GPU 数。

vllm serve: tensor+pipeline parallel (pattern)
Shell
1
2
3
4
vllm serve meta-llama/Meta-Llama-3-70B-Instruct \\
  --tensor-parallel-size 8 \\
  --pipeline-parallel-size 2 \\
  --host 0.0.0.0 --port 8000
Data Parallel(多副本 + 负载均衡)

data parallel 复制权重,让多个 GPU/进程独立处理请求,适合吞吐扩展。vLLM 支持“自包含 DP(一个对外端点,内部做 rank 级负载均衡)”与“外部负载均衡(每 rank 单独对外,外部 LB 路由)”。

vllm serve: data parallel (self-contained, pattern)
Shell
1
2
3
vllm serve meta-llama/Meta-Llama-3-8B-Instruct \\
  --data-parallel-size 4 \\
  --host 0.0.0.0 --port 8000
前缀缓存(Prefix Caching)

前缀缓存缓存“已 prefill 的前缀 KV blocks”,新请求与历史请求共享前缀时,可以复用缓存并跳过重复的 prefill 计算。它对“系统 prompt 固定、RAG 模板固定、长上下文重复”的场景收益很大。

开启方式与哈希策略

服务端通过 --enable-prefix-caching 开启前缀缓存。为多租户隔离与碰撞风险控制,前缀缓存提供可配置的哈希策略(例如使用 SHA256 族)。

vllm serve: enable prefix caching (pattern)
Shell
1
2
3
4
vllm serve meta-llama/Meta-Llama-3-8B-Instruct \\
  --enable-prefix-caching \\
  --prefix-caching-hash-algo sha256 \\
  --host 0.0.0.0 --port 8000
Speculative Decoding

speculative decoding 用“草稿模型提出多个候选 token + 目标模型验证并接收其中一部分”的方式减少目标模型 decode 步数,从而降低解码延迟。服务端需要同时加载目标模型与草稿模型,并为草稿模型设置独立的资源预算。

vllm serve: speculative decoding (CLI pattern)
Shell
1
2
3
4
5
6
vllm serve <target-model> \\
  --speculative-config '{
    "method": "draft_model",
    "model": "<draft-model>",
    "num_speculative_tokens": 5
  }'

speculative decoding 与某些并行策略(例如 pipeline parallel)可能存在兼容性限制,上线前需要在目标版本组合上做压测与回归。

监控、日志与部署注意项
/metrics 与 Prometheus

vLLM 的 OpenAI-compatible server 默认暴露 /metrics,可用于 Prometheus 抓取与容量规划。

Query /metrics
Shell
1
curl http://0.0.0.0:8000/metrics

容量规划时需要重点关注两类信息:KV cache 的 token 容量与“最大并发估计”。vLLM 启动日志通常会输出类似的估算信息(示例格式如下):

Example: concurrency estimate lines (concept)
1
2
GPU KV cache size: 643,232 tokens
Maximum concurrency for 40,960 tokens per request: 15.70x
健康检查与日志降噪

服务端通常提供健康检查端点(例如 /health、 /ping)。生产环境里这些端点会被 LB 高频调用,建议通过 --disable-access-log-for-endpoints 关闭对应 access logs,避免淹没有效日志:

Disable access logs for noisy endpoints
Shell
1
2
3
vllm serve meta-llama/Meta-Llama-3-8B-Instruct \\
  --disable-access-log-for-endpoints "/health,/metrics,/ping" \\
  --host 0.0.0.0 --port 8000
日志配置(环境变量)

vLLM 使用 Python 的 logging 配置体系,并提供环境变量控制默认日志行为。最常见的两类控制是:关闭 vLLM 的默认日志配置,以及提供自定义 JSON logging 配置文件路径。

Logging controls (pattern)
Shell
1
2
3
4
5
6
# disable vLLM logging configuration
export VLLM_CONFIGURE_LOGGING=0
 
# or provide a custom logging config file
export VLLM_CONFIGURE_LOGGING=1
export VLLM_LOGGING_CONFIG_PATH=/path/to/logging_config.json
部署前的稳定性检查清单
  • 显存预算:用 --max-model-len 与 --gpu-memory-utilization 先跑通,再逐步提高 --max-num-seqs 与 --max-num-batched-tokens 做压测。
  • 默认行为:明确 chat template、tokenizer 与 generation_config 的来源与优先级,避免升级后默认采样参数变化。
  • 权限边界: --trust-remote-code 只在可信模型源启用;容器中通过只读挂载、最小权限与镜像固化降低风险。
  • 日志与指标:确保 /metrics 可被抓取,健康检查端点与 access log 策略不会引发噪声或误报警。
附录
常见陷阱与排障速查

这一节只覆盖 ref-6 正文里已经出现过的栈:PyTorch / Transformers / Accelerate / PEFT / TRL / DeepSpeed / vLLM / RAG 向量组件,以及常见的日志与实验跟踪工具(TensorBoard、W&B、MLflow、Langfuse)。每条都给“现象→快速检查→修复动作”。

环境与依赖(安装层)
现象 快速检查 常见根因 修复动作
ImportError: Using device_map requires Accelerate 看代码是否启用了 device_map="auto" Transformers 需要 accelerate 提供 device map / offload 运行时 pip install -U accelerate,或移除 device_map 并手动 model.to(device)
DeepSpeed 安装成功但首次训练很慢 跑环境报告/查看是否在编译 ops DeepSpeed ops 走 JIT 编译,首次运行会编译 CUDA/C++ 扩展 把编译链(gcc/g++/ninja)与 CUDA toolkit 固定到镜像;或使用预编译/缓存编译产物
bitsandbytes/FlashAttention/xFormers 安装失败 核对 torch.__version__ 与 CUDA/驱动 二进制 wheel 与 CUDA/torch 组合不匹配,退回源码编译 优先选“官方支持矩阵”内的 torch+CUDA 组合;必要时换到对应 wheel 或统一用容器镜像
huggingface_hub 下载慢/缓存爆盘 检查缓存路径与磁盘配额 默认缓存落在 home 盘;大模型/多版本重复下载 设置 HF_HOME/ HF_HUB_CACHE 到大盘;固定模型版本,避免反复下载
训练侧(脚本、分布式与 checkpoint)
现象 快速检查 常见根因 修复动作
训练 loss 正常,但 eval 指标不动或波动异常 检查 eval 集是否固定、是否数据泄漏、是否用错 metric 数据切分不稳定;指标与 loss 不一致(例如二分类用 F1 优先) 固定 eval 集与随机种子;按任务选择 monitor(生成任务常用 token acc/下游指标;分类任务多用 F1/acc)
同样配置多次训练结果差异大 检查是否固定 seed;是否启用非确定性算子 并行归约顺序、不同 kernel、随机种子未固定 固定 seed;能用 deterministic kernel 的框架启用 deterministic;记录环境/依赖版本到 run_meta
断点续训后 learning rate/step 计数异常 检查是否从正确的 checkpoint 恢复 optimizer/scheduler 只恢复了模型权重,没恢复优化器状态;或切换了 batch size/accumulation 统一用框架提供的 resume 机制;恢复后不随意改动 batch/accumulation;把超参固化在 run_meta
LoRA 训练完线上加载无效果 确认线上是否真的挂载了 adapter;是否用对 base base checkpoint 不一致;adapter 没加载或没 set_active 把 base_model_id 写入 adapter 元信息;上线前做“base+adapter”一致性 smoke test
DeepSpeed/多机训练 hang 或极慢 打开 NCCL 日志;检查网卡选择 NCCL 走错网络接口;IB/PCIe 拓扑或防火墙 设置 NCCL 环境变量并固定网卡;必要时禁用 IB/P2P 作为定位手段
NCCL 诊断(最小集)
Shell
1
2
3
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=INIT,NET
export NCCL_SOCKET_IFNAME=eth0
推理侧(vLLM / OpenAI 兼容服务)
现象 快速检查 常见根因 修复动作
服务启动成功但客户端 404 / model not found GET /v1/models 看 model 名 服务端模型名与客户端 model= 不一致 vLLM 启动加 --served-model-name,客户端统一用该名称
服务 QPS 低或频繁 OOM 看 /metrics,关注 KV cache 与并发 context 过长、KV cache 预算过大、GPU mem utilization 过高 降低最大上下文/并发;保留显存余量;把流量打到多副本;按模型大小重新估算 token budget
temperature=0 仍有轻微不一致 确认是否完全禁用采样;是否跨硬件/多实例 服务端并行归约/调度带来非确定性 禁用所有采样相关选项;尽量固定推理硬件与 kernel;把“强确定性”作为服务 SLA 单独约束
端口占用 / 启动失败 lsof -i :8000 端口被旧进程占用 停止旧进程或换端口;把启动/回滚写成脚本并纳入进程管理器
最小健康检查
Shell
1
2
curl -sf http://127.0.0.1:8000/v1/models
curl -sf http://127.0.0.1:8000/metrics | head
RAG 与向量组件(FAISS / pgvector / Qdrant / Milvus / TCVectorDB)
现象 快速检查 常见根因 修复动作
召回结果明显变差(无规律) 检查是否混入不同 embedding 模型版本 同一 collection 混用不同 embedding space 把 embedding_model_id 写入 metadata;版本迁移时重建或双写新集合
cosine 相似度排序不符合预期 检查向量是否归一化;检查 metric 设置 cosine 与 inner product/l2 使用不一致 统一策略:归一化 + inner product(或显式 cosine metric);写入与查询必须一致
pgvector 查询报错:vector 扩展不存在 SELECT extname FROM pg_extension; 未启用扩展 CREATE EXTENSION IF NOT EXISTS vector;
Qdrant 可用但安全风险 检查是否启用鉴权,端口是否暴露公网 默认 Docker QuickStart 无认证 启用 API key/鉴权;仅暴露内网;生产环境用 Helm/Cloud 并加网络策略
Milvus/TCVectorDB 连接失败 检查 endpoint/token/TLS 网络不可达、鉴权不匹配、TLS 配置问题 先用最小 client 读写验证;把 endpoint/token/TLS 作为运行时配置(环境变量/密钥管理)
框架/引擎选型原则(工程视角)
需求 优先选 理由 不适合时的信号
想要最短路径微调 LLM(SFT/DPO/GRPO) TRL + (Transformers + PEFT) + Accelerate 方法流程化、接口稳定、与 HF 生态耦合最深 需要高度自定义训练循环/复杂多任务调度,且团队已有自研训练框架
通用训练循环(分类/NER/CV)且团队多人复用 Transformers Trainer 或 PyTorch Lightning 或 MMEngine 训练工程样板收敛到统一入口;callbacks/loggers/checkpoint 更标准 训练逻辑高度非标准(例如特殊采样/复杂图结构),框架抽象反而增加阻力
大模型训练显存吃紧,需要参数/优化器分片 DeepSpeed ZeRO 或 PyTorch FSDP(常经由 Accelerate) 把显存压力从“模型/优化器/梯度”三个维度拆开 集群/网络不稳定;团队缺少分布式排障经验;checkpoint 迁移频繁
需要高吞吐服务化推理(OpenAI API 兼容) vLLM( vllm serve) continuous batching、KV cache 管理、指标/观测与并行策略更贴近生产 只需离线小批推理且依赖纯 Transformers;或模型结构/算子不被 vLLM 支持
RAG 召回(单机/单租户/延迟极低) FAISS 部署简单、延迟低、索引可控 需要多租户、过滤、持久化、高可用与在线扩缩容
RAG 召回(已有 Postgres 体系) pgvector 事务/权限/JOIN/备份沿用 Postgres,治理成本低 超大规模向量 + 高 QPS;需要专用 ANN 系统能力
RAG 召回(专用向量数据库能力) Qdrant / Milvus / 托管向量库(TCVectorDB) 过滤、持久化、分布式扩展与运维能力更完整 团队没有运维能力(自建风险高);对迁移性要求极强(托管绑定风险)
训练与推理栈术语对照
术语 训练侧含义 推理/服务侧含义 落点(代码/命令)
base / adapter base 是大权重 checkpoint;adapter 是 LoRA 等增量参数 服务端加载 base 后挂载 adapter,或提前 merge 导出 PEFT: PeftModel.from_pretrained / merge_and_unload
checkpoint 训练过程中的可恢复状态(含 optimizer/scheduler 视框架而定) 上线时通常只需要“可加载权重目录”(artifact) Transformers/TRL: output_dir/checkpoint-*
artifact(模型包) 训练导出的可交付目录 推理服务直接加载的路径 save_pretrained 目录结构
device_map 训练/推理加载时的设备放置策略 服务端决定权重落 GPU/CPU/offload 的方式 Transformers: device_map="auto"
torch_dtype 训练计算 dtype(fp16/bf16/fp32) 推理加载 dtype(影响显存与速度) Transformers: torch_dtype="auto"
TTFT / TPOT 训练不直接出现 首 token 延迟 / 每 token 延迟,衡量推理体验 vLLM: /metrics + 业务侧统计
topK / rerank 训练侧用于召回模型/排序模型的训练 RAG 检索阶段:ANN 召回 topK,reranker 取 topN 向量库 search + Cross-Encoder rerank
目录/产物/命令速查表
常见目录与产物
条目 建议路径/命名 由谁产出 给谁消费
数据集 jsonl data/processed/train.jsonl prepare_data.py datasets.load_dataset / DataLoader
数据集元信息 data/processed/dataset_meta.json prepare_data.py 训练/回溯/审计
训练 run 目录 outputs/runs/<run_id>/ 训练脚本 排障、对比、复现
checkpoint outputs/runs/<run_id>/checkpoints/checkpoint-* Trainer/TRL/DeepSpeed 断点续训、best 导出
可部署模型包 models/registry/model_vXXXX/ save_pretrained / merge 导出脚本 vLLM/Transformers 推理
生产指针(回滚开关) models/prod -> models/registry/model_vXXXX 上线/回滚脚本 serve.sh
常用命令(最小可复制)
目的 命令 备注
多卡启动训练
Shell
1
2
accelerate config
accelerate launch train.py
Accelerate 负责进程编排;脚本内部用同一份 PyTorch 代码
DeepSpeed 启动训练
Shell
1
deepspeed --num_gpus=8 train.py --deepspeed ds_config.json
依赖 ds_config.json 固化 ZeRO/offload/AMP 等
启动 vLLM 服务
Shell
1
2
3
4
vllm serve /abs/path/to/models/prod \
  --host 0.0.0.0 --port 8000 \
  --served-model-name prod \
  --api-key token-abc123
OpenAI-compatible server;用 served-model-name 统一客户端 model 名
检查服务与指标
Shell
1
2
curl -sf http://127.0.0.1:8000/v1/models
curl -sf http://127.0.0.1:8000/metrics | head
最小 smoke test
启动 Qdrant(开发)
Shell
1
docker run -p 6333:6333 -v "$(pwd)/qdrant_storage:/qdrant/storage:z" qdrant/qdrant
默认无鉴权;生产需加安全配置
启用 pgvector
Shell
1
psql -d your_db -c "CREATE EXTENSION IF NOT EXISTS vector;"
扩展启用后再建表/建索引
安装 OpenMMLab 训练底座
Shell
1
2
pip install -U openmim
mim install mmengine
OpenMMLab 生态多仓库复用
← 人工智能知识 - 智能体

Leave a Reply Cancel reply

Your email address will not be published. Required fields are marked *

You may use these HTML tags and attributes: <a href="" title=""> <abbr title=""> <acronym title=""> <b> <blockquote cite=""> <cite> <code class="" title="" data-url=""> <del datetime=""> <em> <i> <q cite=""> <strike> <strong> <pre class="" title="" data-url=""> <span class="" title="" data-url="">

Related Posts

  • 2019年12月江南

Recent Posts

  • 人工智能知识 - 编程
  • 人工智能知识 - 智能体
  • 人工智能知识 - Transformers和大模型
  • 人工智能知识 - 主要应用领域
  • 人工智能知识 - 算法和机器学习
ABOUT ME

汪震 | Alex Wong

江苏淮安人,现居北京。目前供职于腾讯云,专注国际售后AI落地。

GitHub:gmemcc

Git:git.gmem.cc

Email:gmemjunk@gmem.cc@me.com

ABOUT GMEM

绿色记忆是我的个人网站,域名gmem.cc中G是Green的简写,MEM是Memory的简写,CC则是我的小天使彩彩名字的简写。

我在这里记录自己的工作与生活,同时和大家分享一些编程方面的知识。

GMEM HISTORY
v2.00:微风
v1.03:单车旅行
v1.02:夏日版
v1.01:未完成
v0.10:彩虹天堂
v0.01:阳光海岸
MIRROR INFO
Meta
  • Log in
  • Entries RSS
  • Comments RSS
  • WordPress.org
Recent Posts
  • 人工智能知识 - 编程
    这一篇专门处理 AI 训练、微调、推理与部署中的编程栈问题。前几篇分别讲了机器学习基础、任务版图、Transfo ...
  • 人工智能知识 - 智能体
    这一篇处理模型之外的系统层问题,包括上下文工程、Harness Engineering、检索增强生成(RAG)与 ...
  • 人工智能知识 - Transformers和大模型
    这一篇聚焦现代大模型主线,内容从 Transformer 架构出发,延伸到语言模型、多模态模型、预训练与微调,以 ...
  • 人工智能知识 - 主要应用领域
    这一篇从任务视角进入现代 AI 的几个核心应用方向,重点讨论自然语言处理、计算机视觉、语音和音频处理、搜索/推荐 ...
  • 人工智能知识 - 算法和机器学习
    这一篇从常用算法进入机器学习基础概念、经典机器学习与神经网络,重点讨论“模型如何被构造、训练、评估与正则化”。前 ...
  • 人工智能知识 - 数学基础
    这一篇整理 AI 所需的数学基础,包括基础数学、线性代数、微积分与概率论统计。它回答的核心问题是:模型里的向量、 ...
  • 人工智能知识 - 简介
    这一篇作为整套 AI 总纲的导论,先不进入公式和具体模型细节,而是回答更根本的问题:什么叫智能,人工智能究竟在试 ...
  • 多语言敏感信息检测模型训练日志
    这篇文章记录一个多语言敏感信息识别项目的完整训练日志。它关注的是工程路径本身:原始 AI 合成语料如何被清洗成可 ...
  • DevPod on Kubernetes: turning devcontainer.json into a persistent remote workspace
    DevPod is an open source workspace manager ...
  • OpenClaw: Architecture, Components, and Deployment Notes
    Four Months, 343,000 Stars On November 24, 2025, ...
  • Replacing Docker Desktop with Colima on macOS
    Colima is one of the cleanest ways ...
  • Kubernetes GPU Sharing
    GPU sharing in Kubernetes depends on what ...
  • Investigating and Solving the Issue of Failed Certificate Request with ZeroSSL and Cert-Manager
    In this blog post, I will walk ...
  • A Comprehensive Study of Kotlin for Java Developers
    Introduction Purpose of the Study Understanding the Mo ...
  • LangChain: Architecture, LCEL, Agents, LangGraph, Retrieval, and Production Patterns
    LangChain is no longer best understood as ...
  • Kubernetes Migration
    Migrating a Kubernetes cluster from one cloud ...
  • Terraform: a practical guide to infrastructure as code
    Terraform is an infrastructure-as-code tool. You describ ...
  • 草缸2021
    经过四个多月的努力,我的小小荷兰景到达极致了状态。

TOPLINKS
  • Zitahli's blue 91 people like this
  • 梦中的婚礼 64 people like this
  • 汪静好 61 people like this
  • 那年我一岁 36 people like this
  • 为了爱 28 people like this
  • 小绿彩 26 people like this
  • 杨梅坑 6 people like this
  • 亚龙湾之旅 1 people like this
  • 汪昌博 people like this
  • 彩虹姐姐的笑脸 24 people like this
  • 2013年11月香山 10 people like this
  • 2013年7月秦皇岛 6 people like this
  • 2013年6月蓟县盘山 5 people like this
  • 2013年2月梅花山 2 people like this
  • 2013年淮阴自贡迎春灯会 3 people like this
  • 2012年镇江金山游 1 people like this
  • 2012年徽杭古道 9 people like this
  • 2011年清明节后扬州行 1 people like this
  • 2008年十一云龙公园 5 people like this
  • 2008年之秋忆 7 people like this
  • 老照片 13 people like this
  • 火一样的六月 16 people like this
  • 发黄的相片 3 people like this
  • Cesium学习笔记 90 people like this
  • IntelliJ IDEA知识集锦 59 people like this
  • Bazel学习笔记 38 people like this
  • 基于Kurento搭建WebRTC服务器 38 people like this
  • PhoneGap学习笔记 32 people like this
  • NaCl学习笔记 32 people like this
  • 使用Oracle Java Mission Control监控JVM运行状态 29 people like this
  • Ceph学习笔记 27 people like this
  • 基于Calico的CNI 27 people like this
  • Three.js学习笔记 24 people like this
Tag Cloud
ActiveMQ AspectJ CDT Ceph Chrome CNI Command Cordova Coroutine CXF Cygwin DNS Docker eBPF Eclipse ExtJS F7 FAQ Groovy Hibernate HTTP IntelliJ IO编程 IPVS JacksonJSON JMS JSON JVM K8S kernel LB libvirt Linux知识 Linux编程 LOG Maven MinGW Mock Monitoring Multimedia MVC MySQL netfs Netty Nginx NIO Node.js NoSQL Oracle PDT PHP Redis RPC Scheduler ServiceMesh SNMP Spring SSL svn Tomcat TSDB Ubuntu WebGL WebRTC WebService WebSocket wxWidgets XDebug XML XPath XRM ZooKeeper 亚龙湾 单元测试 学习笔记 实时处理 并发编程 彩姐 性能剖析 性能调优 文本处理 新特性 架构模式 系统编程 网络编程 视频监控 设计模式 远程调试 配置文件 齐塔莉
Recent Comments
  • 杨松涛 on snmp4j学习笔记
  • kaka on Cilium学习笔记
  • JackZhouMine on Cesium学习笔记
  • 陈黎 on 通过自定义资源扩展Kubernetes
  • qg on Istio中的透明代理问题
  • heao on 基于本地gRPC的Go插件系统
  • 黄豆豆 on Ginkgo学习笔记
  • cloud on OpenStack学习笔记
  • 5dragoncon on Cilium学习笔记
  • Archeb on 重温iptables
  • C/C++编程:WebSocketpp(Linux + Clion + boostAsio) – 源码巴士 on 基于C/C++的WebSocket库
  • jerbin on eBPF学习笔记
  • point on Istio中的透明代理问题
  • G on Istio中的透明代理问题
  • 绿色记忆:Go语言单元测试和仿冒 on Ginkgo学习笔记
  • point on Istio中的透明代理问题
  • 【Maven】maven插件开发实战 – IT汇 on Maven插件开发
  • chenlx on eBPF学习笔记
  • Alex on eBPF学习笔记
  • CFC4N on eBPF学习笔记
  • 李运田 on 念爷爷
  • yongman on 记录一次KeyDB缓慢的定位过程
  • Alex on Istio中的透明代理问题
©2005-2026 Gmem.cc | Powered by WordPress | 京ICP备18007345号-2