Pytorch 7 :Memory Optimization(Freeing GPU/NPU Memory Early)
Python 的自动内存管理主要基于引用计数(Reference Counting),辅以循环垃圾回收器(Cycle Garbage Collector)处理循环引用。
引用计数的核心思想是:每个对象维护一个计数器,记录有多少变量(或容器)引用它;当计数降为 0 时,对象立即被销毁并释放内存。
一、引用计数机制简述
- 每次一个对象被引用(赋值、传参、放入容器等),其
ob_refcnt+1。 - 每次引用被删除(变量离开作用域、被覆盖、从容器中移除等),
ob_refcnt-1。 - 当
ob_refcnt == 0,Python 调用该对象的析构函数(如果有的话)并释放内存。
你可使用 sys.getrefcount(obj) 查看当前引用计数(注意:该函数本身会临时增加一次引用,所以返回值比实际多 1)。
函数调用返回
引用计数 不是全局参数,而是 每个 Python 对象自身的一个属性(存储在 PyObject 结构体的 ob_refcnt 字段中)。它跟踪的是 有多少个“引用”指向该对象,与变量作用域(局部/全局)无关,只与“引用关系”有关。
易混淆点:
1. 引用计数是全局参数吗?还是不同函数的局部变量?
都不是。
引用计数属于 对象本身,不是变量的属性,也不是全局或局部变量。例如:
1
2a = [1, 2, 3] # 创建一个 list 对象,引用计数 = 1
b = a # 同一个 list 对象,引用计数 = 2这里的引用计数是
[1,2,3]这个 list 对象的属性,a和b只是两个名字(引用)。
2. 如果一个函数使用了一个变量,其引用计数会增加吗?
会,但要看“使用”的方式:
情况 A:将对象作为参数传入函数
1 | import sys |
- 在函数调用时,形参
x会绑定到obj所指向的对象 → 引用计数 临时 +1。 - 函数返回后,形参
x超出作用域 → 引用计数 -1,恢复原状。
✅ 结论:函数参数传递会临时增加引用计数,但函数结束时会自动减少。
情况 B:函数内部创建新对象
1 | def g(): |
y是局部变量,指向新列表。return y将引用传递给调用者(z),不是复制对象。- 函数结束时,
y被销毁(引用 -1),但z接管了引用(总引用数保持 1)。
3. 函数返回后,局部变量和全局变量的引用计数如何变化?
(1)局部变量:
- 函数执行结束时,所有局部变量(如
x,y)从局部命名空间中移除。 - 这会导致它们所引用的对象 引用计数 -1。
- 如果计数变为 0,对象立即被回收。
(2)全局变量:
- 如果函数内部读取全局变量(如
global_list),会临时增加其引用计数(因为函数栈帧中有一个引用)。 - 函数结束后,这个临时引用消失,计数恢复。
- 如果函数内部修改全局变量(
global global_list),则可能增加/减少引用,取决于操作(如赋值新对象)。
总结
| 场景 | 引用计数变化 |
|---|---|
| 函数参数传入对象 | +1(临时),函数返回后 -1 |
| 函数内局部变量引用对象 | 作用域结束时 -1 |
return obj |
调用者获得引用,对象生命周期延长 |
| 全局变量被函数读取 | 临时 +1,函数结束 -1 |
| 对象引用计数归零 | 立即析构(__del__ + 内存释放) |
这正是 Python 能高效管理内存(尤其在训练循环中频繁创建中间张量)的基础机制之一。
del 的使用
非常好的问题!我们来精确澄清 del 的作用:
del x并不是直接将引用计数归零,而是将变量x对其所指向对象的引用移除,从而导致该对象的引用计数减 1(refcnt -= 1)。
一、del 的本质:解除一个引用
在 Python 中,变量(如 x)本质上是名字(name)到对象的引用(reference)。
执行 del x 的含义是:
- **从当前作用域(如局部或全局命名空间)中删除名字
x**; - 这会导致
x原本指向的对象失去一次引用; - 该对象的引用计数减 1;
- 是否归零,取决于是否有其他引用存在。
二、举例说明
✅ 情况 1:只有一个引用 → del 后引用计数归零
1 | import sys |
→ 这里 del a 使引用计数从 1 → 0,对象销毁。
✅ 情况 2:多个引用 → del 后引用计数减 1,但不归零
1 | import sys |
→ del a 仅移除了 a 这个引用,对象仍被 b 持有。
✅ 情况 3:函数局部变量(自动 del)
1 | import sys |
→ 函数结束时,局部变量 x 被隐式 del,引用计数 -1,但因 return 被 obj 接收,总引用数不变。
三、与 PyTorch / NPU 的关联
1 | x = torch.randn(1000, device='npu') # 引用计数 = 1 |
✅ 所以,**
del只是减 1,是否归零取决于其他引用是否存在**。
四、常见误解纠正
| 误解 | 正确理解 |
|---|---|
“del x 会删除对象” |
❌ del x 只删除名字 x,对象是否删除看引用计数 |
“del 后内存立刻释放” |
❌ 仅当引用计数归零时才释放 Python 对象;设备内存还受框架缓存控制 |
“del = free()” |
❌ Python 是垃圾回收语言,del 不是直接内存释放指令 |
总结
del x→ 移除变量x对对象的引用 → 对象引用计数减 1。- 是否归零? → 取决于是否还有其他变量/容器/闭包等引用该对象。
- 设备内存释放? → 需要:
引用计数归零→Tensor 析构→PyTorch 缓存池标记为可回收→empty_cache()(可选,归还给驱动)。
empty_cache
torch.npu.empty_cache())
PyTorch 不会在 Tensor 被销毁时立即释放 GPU/NPU 内存,而是:
- 将释放的内存块放入 内存池(caching allocator)。
- 后续分配相同大小的内存时,直接复用缓存块,避免频繁调用
cudaFree(昂贵操作)。
empty_cache() 的作用:
将缓存中未被使用的内存块真正归还给设备驱动(如 CUDA driver 或 NPU driver)。
gc.collect()
- 作用:运行 Python 的 循环垃圾回收器(Generational GC),检测并回收引用计数无法释放的循环引用对象。
- 触发条件:
- 自动:每分配一定数量对象后触发(可配置)。
- 手动:调用
gc.collect()。
- 与引用计数关系:
- 补充机制:只处理引用计数“漏掉”的对象(主要是循环引用)。
- 不管理设备内存:即使回收了 Tensor 的 Python 对象,若 PyTorch 仍持有其底层数据指针,显存/NPU 内存不会释放。
1 | import gc |
❌
gc.collect()不能释放 GPU/NPU 显存,除非它成功回收了 Tensor 对象,且该 Tensor 是最后一个持有设备内存的引用。
计算图引用:batch数据被保存到反向
这是一个在 PyTorch 训练循环中极易被忽视但影响显著的内存问题。确实,DataLoader 通过 for batch in dataloader 产出每个 batch,但如果处理不当,一个 batch 的数据可能被意外持有多轮(甚至直到反向传播结束),导致 GPU/NPU 内存无法及时释放,尤其在 VLM、MoE 等大模型场景中会严重限制 batch size 或引发 OOM。
🔍 问题本质:谁“持有”了 batch?
1 | for batch in dataloader: |
表面上看,batch 是 for 循环的局部变量,每轮应自动释放。
但问题出在计算图(computation graph)和 PyTorch 的自动微分机制。
🧠 核心原因:计算图保留了对输入 Tensor 的引用
当 requires_grad=True(或模型有可训练参数)时,PyTorch 会构建反向传播所需的计算图。
该图会 隐式持有对所有参与 forward 的叶节点(leaf tensors)的引用,包括:
pixel_valuesinput_ids- 以及其他输入张量
✅ 即使你在
forward结尾del pixel_values,只要计算图存在,PyTorch 仍会保持对其底层数据的引用 → 设备内存无法释放。
这个引用会一直持续到 loss.backward() 执行完毕(或手动 loss.grad_fn 被断开)。
✅ 验证:为什么 batch 会被持到反向?
1 | import torch |
- 在
backward()之前,x被计算图引用 → 无法释放。 backward()后,计算图被销毁(除非retain_graph=True)→x的“图引用”消失。- 如果
x没有其他 Python 引用(如变量、容器),则其内存可被回收。
🚫 为什么 del batch 在循环内无效?
1 | for batch in dataloader: |
batch是 CPU 上的字典,del batch仅释放 CPU 内存。x是独立的 GPU Tensor,其生命周期由:- Python 引用(
x变量) - 计算图引用(关键!)
- Python 引用(
共同决定。
✅ 正确解决方案
✅ 方案 1:在 backward() 后立即清理(最常用)
1 | for i, batch in enumerate(dataloader): |
✅
backward()后计算图销毁 → 张量仅剩 Python 引用 → 出循环时自动释放。
✅ 方案 2:使用 no_grad() 或 torch.inference_mode()(推理/评估时)
1 | with torch.inference_mode(): # 或 torch.no_grad() |
- 推理时无反向传播 → 无计算图 →
del或作用域退出即可释放。
✅ 方案 3:避免在循环外意外持有引用
常见陷阱:
1 | all_losses = [] |
→ 不要将 loss/Tensor 存入全局 list,除非 .item() 或 .detach():
1 | all_losses.append(loss.item()) # ✅ 只存标量,不保留图 |
✅ 方案 4:使用 model(...).detach()(谨慎)
如果某些中间输出需保留但不参与反向:
1 | features = model.vision_encoder(pixel_values).detach() |
.detach()创建一个与计算图断开的新 Tensor,不持有图引用。
📊 内存释放时机总结
| 阶段 | 是否可释放 pixel_values 的设备内存? |
原因 |
|---|---|---|
forward 中(反向前) |
❌ 不能 | 计算图持有引用 |
del pixel_values(反向前) |
❌ 不能 | 图引用仍存在 |
loss.backward() 后 |
✅ 可以 | 计算图已销毁 |
下一轮 for 循环开始前 |
✅ 自动释放 | Python 引用退出 + 无图引用 |
💡 针对你 NPU/VLM 场景的建议
- 不要尝试在
forward中提前释放pixel_values以节省反向阶段内存 —— 这是徒劳的,因为计算图锁住了它。 - 真正有效的优化方向:
- 减少 计算图规模(如 gradient checkpointing)
- 使用
torch.compile或 融合算子(如 GMM) 降低中间激活内存 - 避免 意外保留 Tensor(如 logger、debug 变量)
- 监控工具:
1
2print(torch.npu.memory_allocated()) # 活跃内存
print(torch.npu.memory_reserved()) # 总保留内存(含缓存)
✅ 结论
DataLoader的 batch 被持到反向是合理的,因为 PyTorch 的自动微分机制需要它。
这不是 bug,而是 反向传播的必要代价。
你无法也不应该在backward()前释放参与计算图的输入 Tensor。
正确做法是:确保反向后无冗余引用,并优化计算图本身(如 activation checkpointing)。
Pytorch 7 :Memory Optimization(Freeing GPU/NPU Memory Early)