Pytorch 7 :Memory Optimization(Freeing GPU/NPU Memory Early)

导言

  • 对于不使用的python对象,如何释放?
  • python 的对象管理机制
  • del,empty_cache , gc_collect的原理

Python 的自动内存管理主要基于引用计数(Reference Counting),辅以循环垃圾回收器(Cycle Garbage Collector)处理循环引用。
引用计数的核心思想是:每个对象维护一个计数器,记录有多少变量(或容器)引用它;当计数降为 0 时,对象立即被销毁并释放内存


一、引用计数机制简述

  • 每次一个对象被引用(赋值、传参、放入容器等),其 ob_refcnt +1。
  • 每次引用被删除(变量离开作用域、被覆盖、从容器中移除等),ob_refcnt -1。
  • ob_refcnt == 0,Python 调用该对象的析构函数(如果有的话)并释放内存。

你可使用 sys.getrefcount(obj) 查看当前引用计数(注意:该函数本身会临时增加一次引用,所以返回值比实际多 1)。


结合 fp32bfloat16(bf16)转换举例

虽然 Python 本身不直接操作 bf16(这是底层硬件/框架如 PyTorch/TensorFlow 支持的数据类型),但我们可以用 NumPy 或 PyTorch 张量来模拟对象生命周期,观察引用计数变化。

注意:bf16 在 PyTorch 中通过 .to(torch.bfloat16) 实现,但其底层张量对象仍是 Python 对象,受引用计数管理。


代码示例:观察引用计数变化(以 PyTorch 张量为例)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
import sys

# 创建一个 fp32 张量
x_fp32 = torch.randn(1000, dtype=torch.float32)
print(f"创建 x_fp32 后引用计数: {sys.getrefcount(x_fp32) - 1}") # -1 排除 getrefcount 自身引用

# 转换为 bf16,生成新张量(注意:这是新对象!)
x_bf16 = x_fp32.to(torch.bfloat16)
print(f"x_bf16 创建后,x_fp32 引用计数: {sys.getrefcount(x_fp32) - 1}")
print(f"x_bf16 的引用计数: {sys.getrefcount(x_bf16) - 1}")

# 删除原始 fp32 张量
del x_fp32
# 此时 x_fp32 对象的引用计数降为 0(假设无其他引用),内存被立即回收
print("已删除 x_fp32")

# 仍可使用 x_bf16
print(f"x_bf16 数据类型: {x_bf16.dtype}")

# 删除 x_bf16
del x_bf16
# 现在 x_bf16 对象也被回收

输出(典型情况):

1
2
3
4
5
创建 x_fp32 后引用计数: 1
x_bf16 创建后,x_fp32 引用计数: 1
x_bf16 的引用计数: 1
已删除 x_fp32
x_bf16 数据类型: torch.bfloat16

四、关键点说明

  1. to() 不是原地操作x_fp32.to(torch.bfloat16) 返回新张量对象,与原 x_fp32 无共享内存(除非显式使用 viewas_strided,但 bf16 与 fp32 位宽不同,无法直接 view)。
  2. 两个独立对象x_fp32x_bf16 各自有自己的引用计数。
  3. del 降低引用计数del x_fp32 使 x_fp32 的引用计数减 1,若变为 0,则立即触发 __del__(如果定义了)并释放内存。
  4. NPU/GPU 内存也受此机制管理:PyTorch 张量在 CPU/NPU/GPU 上的数据由 Python 对象持有,当对象被回收,其底层设备内存也会被释放(通过张量的析构函数)。

函数调用返回

引用计数 不是全局参数,而是 每个 Python 对象自身的一个属性(存储在 PyObject 结构体的 ob_refcnt 字段中)。它跟踪的是 有多少个“引用”指向该对象,与变量作用域(局部/全局)无关,只与“引用关系”有关。

易混淆点:


1. 引用计数是全局参数吗?还是不同函数的局部变量?

  • 都不是
    引用计数属于 对象本身,不是变量的属性,也不是全局或局部变量。

    例如:

    1
    2
    a = [1, 2, 3]  # 创建一个 list 对象,引用计数 = 1
    b = a # 同一个 list 对象,引用计数 = 2

    这里的引用计数是 [1,2,3] 这个 list 对象的属性,ab 只是两个名字(引用)。


2. 如果一个函数使用了一个变量,其引用计数会增加吗?

会,但要看“使用”的方式

情况 A:将对象作为参数传入函数

1
2
3
4
5
6
7
8
9
10
import sys

def f(x):
print("函数内引用计数:", sys.getrefcount(x) - 1) # -1 因为 getrefcount 自身加1
return x

obj = [1, 2, 3]
print("调用前引用计数:", sys.getrefcount(obj) - 1) # 通常是 1
f(obj)
print("调用后引用计数:", sys.getrefcount(obj) - 1) # 仍是 1
  • 在函数调用时,形参 x 会绑定到 obj 所指向的对象 → 引用计数 临时 +1
  • 函数返回后,形参 x 超出作用域 → 引用计数 -1,恢复原状。

结论:函数参数传递会临时增加引用计数,但函数结束时会自动减少。

情况 B:函数内部创建新对象

1
2
3
4
5
def g():
y = [4, 5, 6] # 新对象,引用计数 = 1(仅 y 引用它)
return y

z = g() # z 接收返回值 → 对象引用计数仍为 1(y 已销毁,z 引用它)
  • y 是局部变量,指向新列表。
  • return y 将引用传递给调用者(z),不是复制对象
  • 函数结束时,y 被销毁(引用 -1),但 z 接管了引用(总引用数保持 1)。

3. 函数返回后,局部变量和全局变量的引用计数如何变化?

(1)局部变量

  • 函数执行结束时,所有局部变量(如 x, y从局部命名空间中移除
  • 这会导致它们所引用的对象 引用计数 -1
  • 如果计数变为 0,对象立即被回收。

(2)全局变量

  • 如果函数内部读取全局变量(如 global_list),会临时增加其引用计数(因为函数栈帧中有一个引用)。
  • 函数结束后,这个临时引用消失,计数恢复。
  • 如果函数内部修改全局变量(global global_list),则可能增加/减少引用,取决于操作(如赋值新对象)。

完整示例:观察函数调用中的引用计数变化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import sys

global_obj = [10, 20]

def example(local_obj):
print("进入函数时 global_obj 引用计数:", sys.getrefcount(global_obj) - 1)
print("进入函数时 local_obj 引用计数:", sys.getrefcount(local_obj) - 1)

temp = local_obj # 引用 +1
print("赋值 temp 后 local_obj 引用计数:", sys.getrefcount(local_obj) - 1)

return temp

my_list = [1, 2]
print("调用前 my_list 引用计数:", sys.getrefcount(my_list) - 1) # 1

result = example(my_list)

print("函数返回后 my_list 引用计数:", sys.getrefcount(my_list) - 1) # 2(my_list + result)
print("函数返回后 global_obj 引用计数:", sys.getrefcount(global_obj) - 1) # 1(恢复)

del result # 手动删除 result
print("删除 result 后 my_list 引用计数:", sys.getrefcount(my_list) - 1) # 1

关键点

  • 函数参数、局部变量都是引用,不是对象本身。
  • 函数栈帧存在期间,会对所引用的对象保持一次引用。
  • 函数返回时,局部变量销毁 → 引用计数减少
  • 如果 return 了某个对象,调用者接收到引用,对象不会被销毁。

特别注意:在 PyTorch / NumPy 中的含义

在你关注的 AI 框架中(如 PyTorch 张量):

  • def forward(x): ... return y 中的 xy 都是张量对象的引用。
  • 函数结束时,局部变量(如中间激活 h = relu(x))若未被返回或保存,引用计数降为 0,张量对象被回收,其 GPU/NPU 内存也可能释放(取决于框架内存池策略)。
  • 但若计算图需要(如 requires_grad=True),PyTorch 会额外持有引用用于反向传播,此时即使局部变量销毁,张量也不会被回收。

总结

场景 引用计数变化
函数参数传入对象 +1(临时),函数返回后 -1
函数内局部变量引用对象 作用域结束时 -1
return obj 调用者获得引用,对象生命周期延长
全局变量被函数读取 临时 +1,函数结束 -1
对象引用计数归零 立即析构__del__ + 内存释放)

这正是 Python 能高效管理内存(尤其在训练循环中频繁创建中间张量)的基础机制之一。

del 的使用

非常好的问题!我们来精确澄清 del 的作用

del x 并不是直接将引用计数归零,而是将变量 x 对其所指向对象的引用移除,从而导致该对象的引用计数减 1(refcnt -= 1


一、del 的本质:解除一个引用

在 Python 中,变量(如 x)本质上是名字(name)到对象的引用(reference)
执行 del x 的含义是:

  • **从当前作用域(如局部或全局命名空间)中删除名字 x**;
  • 这会导致 x 原本指向的对象失去一次引用
  • 该对象的引用计数减 1
  • 是否归零,取决于是否有其他引用存在

二、举例说明

✅ 情况 1:只有一个引用 → del 后引用计数归零

1
2
3
4
5
6
7
8
import sys

a = [1, 2, 3]
print(sys.getrefcount(a) - 1) # 输出 1(只有 a 引用它)

del a
# 此时 [1,2,3] 对象的引用计数变为 0,立即被回收
# 再访问 a 会报 NameError: name 'a' is not defined

→ 这里 del a 使引用计数从 1 → 0,对象销毁。


✅ 情况 2:多个引用 → del 后引用计数减 1,但不归零

1
2
3
4
5
6
7
8
9
10
import sys

a = [1, 2, 3]
b = a # b 也引用同一个 list
print(sys.getrefcount(a) - 1) # 输出 2(a 和 b)

del a
print(sys.getrefcount(b) - 1) # 输出 1(仅剩 b)
# 对象未被销毁,仍可通过 b 访问
print(b) # [1, 2, 3]

del a 仅移除了 a 这个引用,对象仍被 b 持有。


✅ 情况 3:函数局部变量(自动 del

1
2
3
4
5
6
7
8
9
10
import sys

def f():
x = [1]
print("函数内引用计数:", sys.getrefcount(x) - 1) # 1
return x

obj = f() # 函数返回后,局部变量 x 自动被 del
# x 的引用消失,但 obj 接管了引用,所以对象未销毁
print("函数外引用计数:", sys.getrefcount(obj) - 1) # 1

→ 函数结束时,局部变量 x 被隐式 del,引用计数 -1,但因 returnobj 接收,总引用数不变。


三、与 PyTorch / NPU 的关联

1
2
3
4
5
6
7
8
9
x = torch.randn(1000, device='npu')  # 引用计数 = 1
y = x # 引用计数 = 2

del x # 引用计数 = 1(y 仍持有)
# 此时 NPU 内存不会释放!

del y # 引用计数 = 0
# Tensor 对象销毁 → PyTorch 将其设备内存归还到缓存池
# 调用 torch.npu.empty_cache() 才会真正释放给 NPU 驱动

✅ 所以,**del 只是减 1,是否归零取决于其他引用是否存在**。


四、常见误解纠正

误解 正确理解
del x 会删除对象” del x 只删除名字 x,对象是否删除看引用计数
del 后内存立刻释放” ❌ 仅当引用计数归零时才释放 Python 对象;设备内存还受框架缓存控制
del = free()” ❌ Python 是垃圾回收语言,del 不是直接内存释放指令

总结

  • del x移除变量 x 对对象的引用对象引用计数减 1
  • 是否归零? → 取决于是否还有其他变量/容器/闭包等引用该对象。
  • 设备内存释放? → 需要:引用计数归零Tensor 析构PyTorch 缓存池标记为可回收empty_cache()(可选,归还给驱动)。

empty_cache

torch.npu.empty_cache()

PyTorch 不会在 Tensor 被销毁时立即释放 GPU/NPU 内存,而是:

  • 将释放的内存块放入 内存池(caching allocator)
  • 后续分配相同大小的内存时,直接复用缓存块,避免频繁调用 cudaFree(昂贵操作)。

empty_cache() 的作用:

将缓存中未被使用的内存块真正归还给设备驱动(如 CUDA driver 或 NPU driver)。

gc.collect()

  • 作用:运行 Python 的 循环垃圾回收器(Generational GC),检测并回收引用计数无法释放的循环引用对象
  • 触发条件
    • 自动:每分配一定数量对象后触发(可配置)。
    • 手动:调用 gc.collect()
  • 与引用计数关系
    • 补充机制:只处理引用计数“漏掉”的对象(主要是循环引用)。
    • 不管理设备内存:即使回收了 Tensor 的 Python 对象,若 PyTorch 仍持有其底层数据指针,显存/NPU 内存不会释放
1
2
3
4
5
6
import gc
a = []
b = [a]
a.append(b) # 循环引用:a → b → a
del a, b # 引用计数不会归零!
gc.collect() # 此时才能回收

gc.collect() 不能释放 GPU/NPU 显存,除非它成功回收了 Tensor 对象,且该 Tensor 是最后一个持有设备内存的引用。

计算图引用:batch数据被保存到反向

这是一个在 PyTorch 训练循环中极易被忽视但影响显著的内存问题。确实,DataLoader 通过 for batch in dataloader 产出每个 batch,但如果处理不当,一个 batch 的数据可能被意外持有多轮(甚至直到反向传播结束),导致 GPU/NPU 内存无法及时释放,尤其在 VLM、MoE 等大模型场景中会严重限制 batch size 或引发 OOM。


🔍 问题本质:谁“持有”了 batch?

1
2
3
4
5
6
7
8
9
for batch in dataloader:
pixel_values = batch["pixel_values"].to(device)
input_ids = batch["input_ids"].to(device)

outputs = model(pixel_values, input_ids)
loss = outputs.loss

loss.backward() # ← 反向时,batch 数据仍被持有?
optimizer.step()

表面上看,batchfor 循环的局部变量,每轮应自动释放。
问题出在计算图(computation graph)和 PyTorch 的自动微分机制


🧠 核心原因:计算图保留了对输入 Tensor 的引用

requires_grad=True(或模型有可训练参数)时,PyTorch 会构建反向传播所需的计算图
该图会 隐式持有对所有参与 forward 的叶节点(leaf tensors)的引用,包括:

  • pixel_values
  • input_ids
  • 以及其他输入张量

即使你在 forward 结尾 del pixel_values,只要计算图存在,PyTorch 仍会保持对其底层数据的引用设备内存无法释放

这个引用会一直持续到 loss.backward() 执行完毕(或手动 loss.grad_fn 被断开)。


✅ 验证:为什么 batch 会被持到反向?

1
2
3
4
5
6
7
8
9
import torch

x = torch.randn(1000, requires_grad=True).cuda()
y = x * 2
loss = y.sum()

print("Before backward:", torch.cuda.memory_allocated()) # 高
loss.backward()
print("After backward:", torch.cuda.memory_allocated()) # 明显下降(x 的 grad 保留,但计算图释放)
  • backward() 之前,x 被计算图引用 → 无法释放。
  • backward() 后,计算图被销毁(除非 retain_graph=True)→ x 的“图引用”消失。
  • 如果 x 没有其他 Python 引用(如变量、容器),则其内存可被回收。

🚫 为什么 del batch 在循环内无效?

1
2
3
4
for batch in dataloader:
x = batch["x"].to("cuda")
del batch # ← 这只能删除 CPU 上的原始 dict,不释放 x!
...
  • batch 是 CPU 上的字典,del batch 仅释放 CPU 内存。
  • x 是独立的 GPU Tensor,其生命周期由:
    • Python 引用(x 变量)
    • 计算图引用(关键!)

共同决定。


✅ 正确解决方案

✅ 方案 1:backward() 后立即清理(最常用)

1
2
3
4
5
6
7
8
9
10
11
12
13
for i, batch in enumerate(dataloader):
pixel_values = batch["pixel_values"].to(device)
input_ids = batch["input_ids"]...to(device)

outputs = model(pixel_values, input_ids)
loss = outputs.loss

loss.backward()
optimizer.step()
optimizer.zero_grad()

# 可选:显式 del(虽非必须,但更安全)
del pixel_values, input_ids, outputs, loss

backward() 后计算图销毁 → 张量仅剩 Python 引用 → 出循环时自动释放

✅ 方案 2:使用 no_grad()torch.inference_mode()(推理/评估时)

1
2
3
4
with torch.inference_mode():  # 或 torch.no_grad()
for batch in dataloader:
...
# 无计算图 → 输入 Tensor 用完即可释放
  • 推理时无反向传播 → 无计算图 → del 或作用域退出即可释放。

✅ 方案 3:避免在循环外意外持有引用

常见陷阱:

1
2
3
4
all_losses = []
for batch in dataloader:
loss = model(...)
all_losses.append(loss) # ← 保留了 loss,间接保留整个计算图!

不要将 loss/Tensor 存入全局 list,除非 .item().detach()

1
all_losses.append(loss.item())  # ✅ 只存标量,不保留图

✅ 方案 4:使用 model(...).detach()(谨慎)

如果某些中间输出需保留但不参与反向:

1
features = model.vision_encoder(pixel_values).detach()
  • .detach() 创建一个与计算图断开的新 Tensor,不持有图引用。

📊 内存释放时机总结

阶段 是否可释放 pixel_values 的设备内存? 原因
forward 中(反向前) ❌ 不能 计算图持有引用
del pixel_values(反向前) ❌ 不能 图引用仍存在
loss.backward() ✅ 可以 计算图已销毁
下一轮 for 循环开始前 ✅ 自动释放 Python 引用退出 + 无图引用

💡 针对你 NPU/VLM 场景的建议

  1. 不要尝试在 forward 中提前释放 pixel_values 以节省反向阶段内存 —— 这是徒劳的,因为计算图锁住了它。
  2. 真正有效的优化方向
    • 减少 计算图规模(如 gradient checkpointing)
    • 使用 torch.compile融合算子(如 GMM) 降低中间激活内存
    • 避免 意外保留 Tensor(如 logger、debug 变量)
  3. 监控工具
    1
    2
    print(torch.npu.memory_allocated())  # 活跃内存
    print(torch.npu.memory_reserved()) # 总保留内存(含缓存)

✅ 结论

DataLoader 的 batch 被持到反向是合理的,因为 PyTorch 的自动微分机制需要它
这不是 bug,而是 反向传播的必要代价
你无法也不应该在 backward() 前释放参与计算图的输入 Tensor。
正确做法是:确保反向后无冗余引用,并优化计算图本身(如 activation checkpointing)

避免保存原生特征

要注意 rollout buffer 中是否存储了原始 pixel_values —— 应只存 视觉特征(feature)token IDs,而非原始像素,这才是根本的内存优化。

Author

Shaojie Tan

Posted on

2025-11-26

Updated on

2025-11-28

Licensed under