Appearance
CS 336 Lecture 2: PyTorch 与资源统计 (Resource Accounting)
内存统计:数据类型 (Memory Counting: Datatypes)
在深度学习中,显存的占用很大程度上取决于所选用的数据类型(Precision)。
- float32 (fp32 / 单精度)
- 占用 4 Bytes。
- 结构:1 bit 符号位,8 bits 指数位,23 bits 尾数位。
- 特点:深度学习中的最高精度标准,通常用于权重更新的核心计算。
- float16 (fp16 / 半精度)
- 占用 2 Bytes。
- 结构:1 bit 符号位,5 bits 指数位,10 bits 尾数位。
- 特点:节省显存,但表示范围较窄,容易出现数值溢出(Overflow)。
- bfloat16 (bf16 / 大脑半精度)
- 占用 2 Bytes。
- 结构:1 bit 符号位,8 bits 指数位,7 bits 尾数位。
- 特点:由 Google 提出,指数位与 fp32 一致,虽然精度稍低,但动态范围与 fp32 相同,训练稳定性远好于 fp16。
- float8 (fp8)
- 占用 1 Byte。
- 变体:E4M3(高精度,用于前向/反向传播)和 E5M2(宽范围,用于梯度/状态)。
- 混合精度训练 (Mixed Precision Training)
- 由于低精度训练可能带来不稳定性,通常采用前向传播使用低精度(如 bf16/fp8),而权重更新和某些核心累加使用高精度(fp32)的方案。
计算统计:Einops 库 (Compute Counting: Einops)
einops(Einstein Summation Notation)提供了一种简洁的方式来标注和操作 PyTorch 张量的维度信息。
einsum:爱因斯坦求和约定
相比传统的转置与矩阵乘法,einsum 能够自动处理维度的求和与广播。
python
x: Float[torch.Tensor, "batch seq1 hidden"] = torch.ones(2,3,4)
y: Float[torch.Tensor, "batch seq2 hidden"] = torch.ones(2,3,4)
# 传统方式
z = x @ y.transpose(-2,-1) # [batch, seq1, seq2]
# einops 方式
# einsum 会对结果中没有提及的维度(hidden)自动求和
z = einsum(x, y, "batch seq1 hidden, batch seq2 hidden -> batch seq1 seq2")
# 使用 ... 代表广播任意维
z = einsum(x, y, "... seq1 hidden, ... seq2 hidden -> ... seq1 seq2")reduce:张量降维
reduce 可以直接对指定维度进行聚合操作(如 mean, sum, max)。
python
x: Float[torch.Tensor, "batch seq hidden"] = torch.ones(2,3,4)
# 传统方式
y = x.mean(dim = -1)
# einops 方式
y = reduce(x, "... hidden -> ...", "mean")rearrange:维度变换与重塑
这是 einops 最强大的功能,可以直观地处理维度的拆分与合并。
python
x: Float[torch.Tensor, "batch seq1 total_hidden"] = torch.ones(2,3,8)
w: Float[torch.Tensor, "hidden1 hidden2"] = torch.ones(4,4)
# 将 total_hidden 拆分为多头形式
x = rearrange(x, "... (heads hidden1) -> ... heads hidden1", heads=2)
# 进行矩阵变换
x = einsum(x, w, "... hidden1, hidden1 hidden2 -> ... hidden2")
# 将多维合并回一维
x = rearrange(x, "... heads hidden1 -> ... (heads hidden1)")FLOPs 统计 (FLOPs Counting)
张量乘法的前向传播
对于矩阵乘法
在神经网络(如 MLP)中,若
计算时间与利用率 (MFU)
- FLOPS:每秒浮点运算次数,取决于硬件性能和数据类型。例如 H100 在 FP32 下仅为
。 - MFU (Model FLOPs Utilization):模型 FLOPs 利用率。
- 定义:
- 经验值:MFU
被视为非常优秀,通信开销和内存墙是主要的性能限制。
- 定义:

梯度反向传播 (Gradients FLOPs)
在反向传播中,我们需要计算:
- 权重梯度 (Weight Grad):计算复杂度约为
。 - 激活梯度 (Activation Grad):计算复杂度约为
。
因此,反向传播的 FLOPs 是前向传播的 2 倍:
训练总 FLOPs 统计:
训练资源管理与优化
模型初始化 (Model Initialization)
为了防止输出在多层叠加后数值爆炸或消失,通常使用 Xavier 初始化(或其他缩放方案):
python
w = nn.Parameter(torch.randn(input_dim, hidden_dim) / np.sqrt(input_dim))优化器内存占用 (Optimizer Memory)
在训练结束后,可以通过设置 set_to_none=True 来释放梯度占用的显存:
python
optimizer.zero_grad(set_to_none=True)显存记账公式 (Memory Accounting)
总显存消耗取决于参数、激活值、梯度以及优化器状态的总和:
(例如 Adam 需要 2 份状态)
总内存估算:
精度选择策略
- 前向传播 (Forward):建议使用 {bf16, fp8} 以减少显存和加速计算。
- 反向传播与权重更新 (Backward & Update):建议在关键累加环节使用 {float32} 以保证数值稳定性。
- 策略总结:低精度训练难度较大(需要 Loss Scaling 等技巧),但推理时的量化(Quantization)相对容易。